diff --git a/src/device_notification_client.rs b/src/device_notification_client.rs new file mode 100644 index 0000000..b1f9a9d --- /dev/null +++ b/src/device_notification_client.rs @@ -0,0 +1,56 @@ +use windows::{ + Win32::{ + Media::Audio::{ + IMMNotificationClient, IMMNotificationClient_Impl, + }, + }, +}; + +#[windows::core::implement(IMMNotificationClient)] +pub(crate) struct DeviceNotificationClient { + pub(crate) callback: Box +} + +impl Drop for DeviceNotificationClient { + fn drop(&mut self) { + println!("DNC drop"); + } +} + +impl IMMNotificationClient_Impl for DeviceNotificationClient { + fn OnDeviceStateChanged( + &self, + _pwstrdeviceid: &windows::core::PCWSTR, + _dwnewstate: u32, + ) -> windows::core::Result<()> { + Ok(()) + } + + fn OnDeviceAdded(&self, pwstrdeviceid: &windows::core::PCWSTR) -> windows::core::Result<()> { + (self.callback)(pwstrdeviceid); + Ok(()) + } + + fn OnDeviceRemoved(&self, _pwstrdeviceid: &windows::core::PCWSTR) -> windows::core::Result<()> { + println!("OnDeviceRemoved"); + // TODO: Remove device and all its sessions + Ok(()) + } + + fn OnDefaultDeviceChanged( + &self, + _flow: windows::Win32::Media::Audio::EDataFlow, + _role: windows::Win32::Media::Audio::ERole, + _pwstrdefaultdeviceid: &windows::core::PCWSTR, + ) -> windows::core::Result<()> { + Ok(()) + } + + fn OnPropertyValueChanged( + &self, + _pwstrdeviceid: &windows::core::PCWSTR, + _key: &windows::Win32::UI::Shell::PropertiesSystem::PROPERTYKEY, + ) -> windows::core::Result<()> { + Ok(()) + } +} diff --git a/src/main.rs b/src/main.rs index c4325bc..42f5053 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,9 @@ +mod device_notification_client; mod pid_to_exe; +mod session_notification; +mod sm_session_notifier; +mod window_change; - -use once_cell::sync::Lazy; use std::{ collections::HashSet, error::Error, @@ -10,174 +12,25 @@ use std::{ io::{BufRead, BufReader}, path::Path, ptr::null_mut, - sync::{Arc, Mutex}, + sync::{Arc, Mutex, mpsc::{self, Sender, Receiver}}, thread, }; -use widestring::U16CStr; +use sm_session_notifier::SMSessionNotifier; +use window_change::{win_event_hook_loop, WIN_CHANGE_CHANNEL_TX}; use windows::{ - core::{Interface, PCWSTR}, + core::Interface, Win32::{ - Foundation::{HINSTANCE, HWND}, - Media::Audio::{ - eRender, IAudioSessionControl, IAudioSessionControl2, IAudioSessionManager2, - IAudioSessionNotification, IAudioSessionNotification_Impl, IMMDeviceEnumerator, - IMMNotificationClient, IMMNotificationClient_Impl, ISimpleAudioVolume, - MMDeviceEnumerator, DEVICE_STATE_ACTIVE, - }, - System::Com::{CoCreateInstance, CoInitializeEx, CLSCTX_ALL, COINIT_MULTITHREADED, STGM_READ}, - UI::{ - Accessibility::{SetWinEventHook, HWINEVENTHOOK}, - WindowsAndMessaging::{ - DispatchMessageW, GetForegroundWindow, GetMessageW, GetWindowThreadProcessId, - TranslateMessage, EVENT_SYSTEM_FOREGROUND, EVENT_SYSTEM_MINIMIZEEND, MSG, - WINEVENT_OUTOFCONTEXT, WINEVENT_SKIPOWNPROCESS, - }, - }, Devices::FunctionDiscovery::PKEY_Device_FriendlyName, + Media::Audio::{IAudioSessionControl2, ISimpleAudioVolume}, + System::Com::{CoInitializeEx, COINIT_MULTITHREADED}, }, }; use crate::pid_to_exe::pid_to_exe_path; -unsafe extern "system" fn win_event_proc( - _hook: HWINEVENTHOOK, - event: u32, - _hwnd: HWND, - _id_object: i32, - _id_child: i32, - _dw_event_thread: u32, - _dwms_event_time: u32, -) { - if event == EVENT_SYSTEM_FOREGROUND || event == EVENT_SYSTEM_MINIMIZEEND { - let mut pid: u32 = 0; - // Instead of using the hwnd passed to us in the wineventproc, call GetForegroundWindow again because Hollow Knight does something weird - // and immediately fires again with the hwnd for explorer? - let hwnd = GetForegroundWindow(); - GetWindowThreadProcessId(hwnd, Some(&mut pid)); - pid_to_exe_path(pid) - .map(|path| MUTER.lock().unwrap().notify_window_changed(&path)) - .unwrap_or_else(|err| { - println!( - "Error finding process with pid {} for hwnd: {:?}: {:?}", - pid, hwnd, err - ) - }); - } -} - -fn win_event_hook_loop() { - unsafe { - SetWinEventHook( - EVENT_SYSTEM_FOREGROUND, - EVENT_SYSTEM_FOREGROUND, - HINSTANCE::default(), - Some(win_event_proc), - 0, - 0, - WINEVENT_OUTOFCONTEXT | WINEVENT_SKIPOWNPROCESS, - ); - SetWinEventHook( - EVENT_SYSTEM_MINIMIZEEND, - EVENT_SYSTEM_MINIMIZEEND, - HINSTANCE::default(), - Some(win_event_proc), - 0, - 0, - WINEVENT_OUTOFCONTEXT | WINEVENT_SKIPOWNPROCESS, - ); - - let mut msg: MSG = MSG::default(); - while GetMessageW(&mut msg, HWND::default(), 0, 0).as_bool() { - TranslateMessage(&msg); - DispatchMessageW(&msg); - } - } -} - -#[windows::core::implement(IAudioSessionNotification)] -struct SessionNotification {} - -impl IAudioSessionNotification_Impl for SessionNotification { - fn OnSessionCreated( - self: &SessionNotification, - newsession: &core::option::Option, - ) -> windows::core::Result<()> { - let ses: IAudioSessionControl2 = newsession.as_ref().unwrap().cast().unwrap(); - MUTER - .lock() - .unwrap() - .add_session(ses) - .unwrap(); - Ok(()) - } -} - -impl Drop for SessionNotification { - fn drop(&mut self) { - println!("SN drop"); - } -} - -#[windows::core::implement(IMMNotificationClient)] -struct DeviceNotificationClient {} - -impl Drop for DeviceNotificationClient { - fn drop(&mut self) { - println!("DNC drop"); - } -} - -impl IMMNotificationClient_Impl for DeviceNotificationClient { - fn OnDeviceStateChanged( - &self, - _pwstrdeviceid: &windows::core::PCWSTR, - _dwnewstate: u32, - ) -> windows::core::Result<()> { - Ok(()) - } - - fn OnDeviceAdded(&self, pwstrdeviceid: &windows::core::PCWSTR) -> windows::core::Result<()> { - MUTER - .lock() - .unwrap() - .add_device_by_id(pwstrdeviceid) - .unwrap_or_else(|error| { - println!("Error adding device: {:?}", error); - }); - Ok(()) - } - - fn OnDeviceRemoved(&self, _pwstrdeviceid: &windows::core::PCWSTR) -> windows::core::Result<()> { - println!("OnDeviceRemoved"); - // TODO: Remove device and all its sessions - Ok(()) - } - - fn OnDefaultDeviceChanged( - &self, - _flow: windows::Win32::Media::Audio::EDataFlow, - _role: windows::Win32::Media::Audio::ERole, - _pwstrdefaultdeviceid: &windows::core::PCWSTR, - ) -> windows::core::Result<()> { - Ok(()) - } - - fn OnPropertyValueChanged( - &self, - _pwstrdeviceid: &windows::core::PCWSTR, - _key: &windows::Win32::UI::Shell::PropertiesSystem::PROPERTYKEY, - ) -> windows::core::Result<()> { - Ok(()) - } -} - struct SessionMuter { - sessions: Arc>>, - device_enumerator: IMMDeviceEnumerator, - device_notification_client: IMMNotificationClient, - session_notification: IAudioSessionNotification, + sessions: Vec, mute_executables: HashSet, mute_flag: bool, - session_managers: Vec, } fn load_mute_txt() -> HashSet { @@ -187,15 +40,8 @@ fn load_mute_txt() -> HashSet { impl SessionMuter { fn new() -> SessionMuter { - let s = Arc::new(Mutex::new(Vec::new())); SessionMuter { - session_managers: Vec::new(), - sessions: s.clone(), - device_enumerator: unsafe { - CoCreateInstance(&MMDeviceEnumerator, None, CLSCTX_ALL).unwrap() - }, - device_notification_client: IMMNotificationClient::from(DeviceNotificationClient {}), - session_notification: IAudioSessionNotification::from(SessionNotification {}), + sessions: Vec::new(), mute_executables: load_mute_txt(), mute_flag: true, } @@ -212,28 +58,26 @@ impl SessionMuter { } match self.session_to_filename(&session) { Ok(file_name) => { - let fn_str = file_name - .to_string_lossy() - .to_string(); + let fn_str = file_name.to_string_lossy().to_string(); if self.mute_executables.contains(&fn_str) { println!("Adding session from: {:?}", fn_str); unsafe { let volume: ISimpleAudioVolume = session.cast()?; volume.SetMute(self.mute_flag, null_mut())?; } - self.sessions.lock().unwrap().push(session); + self.sessions.push(session); } - }, + } Err(err) => { println!("Unable to get filename for session {:?}", err); - }, + } } Ok(()) } unsafe fn set_mute_all(self: &mut SessionMuter, mute: bool) { - let sessions = self.sessions.lock().unwrap(); - let results = sessions + let results = self + .sessions .iter() .map(|session_control2| session_control2.cast::()) .map(|vol_result| vol_result.map(|volume| volume.SetMute(mute, null_mut()))); @@ -242,45 +86,21 @@ impl SessionMuter { } } - unsafe fn notify_window_changed(self: &mut SessionMuter, path: &str) { - let binding = Path::new(path) - .file_name() - .expect("failed to extract filename from path"); - let file_name = binding.to_os_string().to_string_lossy().to_string(); - let mute_flag = !self.mute_executables.contains(&file_name); - if mute_flag != self.mute_flag { - self.mute_flag = mute_flag; - self.set_mute_all(self.mute_flag); - println!( - "Mute set to {} due to foreground window: {}", - self.mute_flag, file_name - ); - } - } - - fn boot_devices(self: &mut SessionMuter) -> Result<(), Box> { + fn notify_window_changed(self: &mut SessionMuter, path: &str) { unsafe { - self.device_enumerator - .RegisterEndpointNotificationCallback(&self.device_notification_client)?; - let device_collection = self - .device_enumerator - .EnumAudioEndpoints(eRender, DEVICE_STATE_ACTIVE)?; - for device in (0..device_collection.GetCount()?).map(|x| device_collection.Item(x)) { - let mmdevice = device?; - self.add_device(mmdevice)?; + let binding = Path::new(path) + .file_name() + .expect("failed to extract filename from path"); + let file_name = binding.to_os_string().to_string_lossy().to_string(); + let mute_flag = !self.mute_executables.contains(&file_name); + if mute_flag != self.mute_flag { + self.mute_flag = mute_flag; + self.set_mute_all(self.mute_flag); + println!( + "Mute set to {} due to foreground window: {}", + self.mute_flag, file_name + ); } - println!("All devices initialized."); - return Ok(()); - } - } - - fn add_device_by_id(self: &mut SessionMuter, id: &PCWSTR) -> Result<(), Box> { - unsafe { - let device_enumerator: IMMDeviceEnumerator = - CoCreateInstance(&MMDeviceEnumerator, None, CLSCTX_ALL)?; - let device = device_enumerator.GetDevice(id)?; - self.add_device(device)?; - Ok(()) } } @@ -297,45 +117,35 @@ impl SessionMuter { return Ok(file_name.to_os_string()); } } - - unsafe fn add_device( - self: &mut SessionMuter, - device: windows::Win32::Media::Audio::IMMDevice, - ) -> Result<(), Box> { - let sm: IAudioSessionManager2 = device.Activate(CLSCTX_ALL, None)?; - sm.RegisterSessionNotification(&self.session_notification)?; - let session_enumerator = sm.GetSessionEnumerator()?; - let session_count = session_enumerator.GetCount()?; - let device_sessions = (0..session_count) - .map(|idx| { - session_enumerator - .GetSession(idx) - .and_then(|session| session.cast()) - }) - .collect::, _>>()?; - for session in device_sessions { - self.add_session(session)?; - } - let prop_store = device.OpenPropertyStore(STGM_READ)?; - // this definition isn't in windows-rs yet apparently :( - let prop_var = prop_store.GetValue(&PKEY_Device_FriendlyName)?; - let str = U16CStr::from_ptr_str(prop_var.Anonymous.Anonymous.Anonymous.pwszVal.0).to_string_lossy(); - println!("Device Added: {} Existing Sessions: {}", str, session_count); - self.session_managers.push(sm); - Ok(()) - } } -unsafe impl Send for SessionMuter {} -unsafe impl Sync for SessionMuter {} - -static MUTER: Lazy>> = - Lazy::new(|| Arc::new(Mutex::new(SessionMuter::new()))); - fn main() { unsafe { CoInitializeEx(None, COINIT_MULTITHREADED).unwrap(); } - MUTER.lock().unwrap().boot_devices().expect("failed to initialize devices"); - win_event_hook_loop(); + let muter = Arc::new(Mutex::new(SessionMuter::new())); + let muter2 = muter.clone(); + let sn = SMSessionNotifier::new(Box::new(move |session| { + muter2.lock().unwrap().add_session(session).unwrap() + })); + + sn.lock() + .unwrap() + .as_mut() + .unwrap() + .boot_devices() + .expect("failed to get initial audio devices and sessions"); + + let (tx, rx) : (Sender, Receiver) = mpsc::channel(); + *WIN_CHANGE_CHANNEL_TX.lock().unwrap().borrow_mut() = Some(tx); + + thread::spawn(move || { + win_event_hook_loop(); + }); + + loop { + let path = rx.recv().unwrap(); + muter.lock().unwrap().notify_window_changed(&path); + } + } diff --git a/src/session_notification.rs b/src/session_notification.rs new file mode 100644 index 0000000..ccc0e57 --- /dev/null +++ b/src/session_notification.rs @@ -0,0 +1,31 @@ +use windows::{ + core::{Interface}, + Win32::{ + Media::Audio::{ + IAudioSessionControl, IAudioSessionControl2, + IAudioSessionNotification, IAudioSessionNotification_Impl, + }, + }, +}; + +#[windows::core::implement(IAudioSessionNotification)] +pub(crate) struct SessionNotification { + pub(crate) callback: Box +} + +impl IAudioSessionNotification_Impl for SessionNotification { + fn OnSessionCreated( + self: &SessionNotification, + newsession: &core::option::Option, + ) -> windows::core::Result<()> { + let ses: IAudioSessionControl2 = newsession.as_ref().unwrap().cast().unwrap(); + (self.callback)(ses); + Ok(()) + } +} + +impl Drop for SessionNotification { + fn drop(&mut self) { + println!("SN drop"); + } +} diff --git a/src/sm_session_notifier.rs b/src/sm_session_notifier.rs new file mode 100644 index 0000000..cd6822e --- /dev/null +++ b/src/sm_session_notifier.rs @@ -0,0 +1,116 @@ +use std::{ + error::Error, + sync::{Arc, Mutex}, +}; + +use crate::{ + device_notification_client::DeviceNotificationClient, session_notification::SessionNotification, +}; +use widestring::U16CStr; +use windows::{ + core::{Interface, PCWSTR}, + Win32::{ + Devices::FunctionDiscovery::PKEY_Device_FriendlyName, + Media::Audio::{ + eRender, IAudioSessionControl2, IAudioSessionManager2, IAudioSessionNotification, + IMMDeviceEnumerator, IMMNotificationClient, MMDeviceEnumerator, DEVICE_STATE_ACTIVE, + }, + System::Com::{CoCreateInstance, CLSCTX_ALL, STGM_READ}, + }, +}; +pub(crate) struct SMSessionNotifier { + device_enumerator: IMMDeviceEnumerator, + device_notification_client: IMMNotificationClient, + session_notification: IAudioSessionNotification, + session_managers: Vec, + notification_function: Box, +} + +impl SMSessionNotifier { + pub(crate) fn new( + callback: Box, + ) -> Arc>> { + let session_notifier_arc: Arc>> = + Arc::new(Mutex::new(None)); + let sn2 = session_notifier_arc.clone(); + let sn3 = session_notifier_arc.clone(); + let notifier = SMSessionNotifier { + session_managers: Vec::new(), + device_enumerator: unsafe { + CoCreateInstance(&MMDeviceEnumerator, None, CLSCTX_ALL).unwrap() + }, + device_notification_client: IMMNotificationClient::from(DeviceNotificationClient { + callback: Box::new(move |id| { + match sn2.lock().unwrap().as_mut().unwrap().add_device_by_id(id) { + Ok(_) => {}, + Err(e) => { + println!("Failed to add new device: {:?}", e); + }, + } + }), + }), + session_notification: IAudioSessionNotification::from(SessionNotification { + callback: Box::new(move |session| { + (sn3.lock().unwrap().as_mut().unwrap().notification_function)(session) + }), + }), + notification_function: callback, + }; + *session_notifier_arc.lock().unwrap() = Some(notifier); + return session_notifier_arc.clone(); + } + + pub(crate) fn boot_devices(self: &mut SMSessionNotifier) -> Result<(), Box> { + unsafe { + self.device_enumerator + .RegisterEndpointNotificationCallback(&self.device_notification_client)?; + let device_collection = self + .device_enumerator + .EnumAudioEndpoints(eRender, DEVICE_STATE_ACTIVE)?; + for device in (0..device_collection.GetCount()?).map(|x| device_collection.Item(x)) { + let mmdevice = device?; + self.add_device(mmdevice)?; + } + println!("All devices initialized."); + return Ok(()); + } + } + + fn add_device_by_id(self: &mut SMSessionNotifier, id: &PCWSTR) -> Result<(), Box> { + unsafe { + let device_enumerator: IMMDeviceEnumerator = + CoCreateInstance(&MMDeviceEnumerator, None, CLSCTX_ALL)?; + let device = device_enumerator.GetDevice(id)?; + self.add_device(device)?; + Ok(()) + } + } + + unsafe fn add_device( + self: &mut SMSessionNotifier, + device: windows::Win32::Media::Audio::IMMDevice, + ) -> Result<(), Box> { + let sm: IAudioSessionManager2 = device.Activate(CLSCTX_ALL, None)?; + sm.RegisterSessionNotification(&self.session_notification)?; + let session_enumerator = sm.GetSessionEnumerator()?; + let session_count = session_enumerator.GetCount()?; + let device_sessions = (0..session_count) + .map(|idx| { + session_enumerator + .GetSession(idx) + .and_then(|session| session.cast()) + }) + .collect::, _>>()?; + for session in device_sessions { + (self.notification_function)(session); + } + let prop_store = device.OpenPropertyStore(STGM_READ)?; + // this definition isn't in windows-rs yet apparently :( + let prop_var = prop_store.GetValue(&PKEY_Device_FriendlyName)?; + let str = U16CStr::from_ptr_str(prop_var.Anonymous.Anonymous.Anonymous.pwszVal.0) + .to_string_lossy(); + println!("Device Added: {} Existing Sessions: {}", str, session_count); + self.session_managers.push(sm); + Ok(()) + } +} diff --git a/src/window_change.rs b/src/window_change.rs new file mode 100644 index 0000000..c57cdb3 --- /dev/null +++ b/src/window_change.rs @@ -0,0 +1,86 @@ +use std::{ + cell::RefCell, + sync::{mpsc::Sender, Arc, Mutex}, +}; + +use once_cell::sync::Lazy; +use windows::Win32::{ + Foundation::{HINSTANCE, HWND}, + UI::{ + Accessibility::{SetWinEventHook, HWINEVENTHOOK}, + WindowsAndMessaging::{ + DispatchMessageW, GetForegroundWindow, GetMessageW, GetWindowThreadProcessId, + TranslateMessage, EVENT_SYSTEM_FOREGROUND, EVENT_SYSTEM_MINIMIZEEND, MSG, + WINEVENT_OUTOFCONTEXT, WINEVENT_SKIPOWNPROCESS, + }, + }, +}; + +use crate::pid_to_exe::pid_to_exe_path; + +pub static WIN_CHANGE_CHANNEL_TX: Lazy>>>>> = + Lazy::new(|| Arc::new(Mutex::new(RefCell::new(None)))); + +unsafe extern "system" fn win_event_proc( + _hook: HWINEVENTHOOK, + event: u32, + _hwnd: HWND, + _id_object: i32, + _id_child: i32, + _dw_event_thread: u32, + _dwms_event_time: u32, +) { + if event == EVENT_SYSTEM_FOREGROUND || event == EVENT_SYSTEM_MINIMIZEEND { + let mut pid: u32 = 0; + // Instead of using the hwnd passed to us in the wineventproc, call GetForegroundWindow again because Hollow Knight does something weird + // and immediately fires again with the hwnd for explorer? + let hwnd = GetForegroundWindow(); + GetWindowThreadProcessId(hwnd, Some(&mut pid)); + pid_to_exe_path(pid) + .and_then(|path| { + WIN_CHANGE_CHANNEL_TX + .lock() + .unwrap() + .borrow() + .as_ref() + .unwrap() + .send(path) + .map_err(|err| err.into()) + }) + .unwrap_or_else(|err| { + println!( + "Error finding process with pid {} for hwnd: {:?}: {:?}", + pid, hwnd, err + ) + }); + } +} + +pub fn win_event_hook_loop() { + unsafe { + SetWinEventHook( + EVENT_SYSTEM_FOREGROUND, + EVENT_SYSTEM_FOREGROUND, + HINSTANCE::default(), + Some(win_event_proc), + 0, + 0, + WINEVENT_OUTOFCONTEXT | WINEVENT_SKIPOWNPROCESS, + ); + SetWinEventHook( + EVENT_SYSTEM_MINIMIZEEND, + EVENT_SYSTEM_MINIMIZEEND, + HINSTANCE::default(), + Some(win_event_proc), + 0, + 0, + WINEVENT_OUTOFCONTEXT | WINEVENT_SKIPOWNPROCESS, + ); + + let mut msg: MSG = MSG::default(); + while GetMessageW(&mut msg, HWND::default(), 0, 0).as_bool() { + TranslateMessage(&msg); + DispatchMessageW(&msg); + } + } +} diff --git a/todo.txt b/todo.txt new file mode 100644 index 0000000..e158405 --- /dev/null +++ b/todo.txt @@ -0,0 +1 @@ +- create session-focused API - should provide a callback for all existing sessions and any time a new sessions is created \ No newline at end of file