From a647fc51863e5997db87c988d5ed58b1fda7bc4d Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 23 Feb 2023 22:47:33 -0500 Subject: [PATCH] Free some additional structures from winapi properly --- src/device_notification_client.rs | 12 ++--- src/main.rs | 77 ++++++++++++++++++------------- src/pid_to_exe.rs | 27 ++++++----- src/session_notification.rs | 12 ++--- src/sm_session_notifier.rs | 36 +++++++++++---- src/window_change.rs | 52 ++++++++++++--------- 6 files changed, 122 insertions(+), 94 deletions(-) diff --git a/src/device_notification_client.rs b/src/device_notification_client.rs index 4b65e5d..b379ea8 100644 --- a/src/device_notification_client.rs +++ b/src/device_notification_client.rs @@ -1,18 +1,12 @@ -use windows::{ - Win32::{ - Media::Audio::{ - IMMNotificationClient, IMMNotificationClient_Impl, - }, - }, -}; +use windows::Win32::Media::Audio::{IMMNotificationClient, IMMNotificationClient_Impl}; pub trait DeviceNotificationObserver { - fn add_device(&self, device_id : &windows::core::PCWSTR); + fn add_device(&self, device_id: &windows::core::PCWSTR); } #[windows::core::implement(IMMNotificationClient)] pub(crate) struct DeviceNotificationClient { - pub observer: Box + pub observer: Box, } impl IMMNotificationClient_Impl for DeviceNotificationClient { diff --git a/src/main.rs b/src/main.rs index bd0314b..e06ea69 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,9 +6,10 @@ mod window_change; use std::{ collections::HashSet, + env, error::Error, ffi::OsString, - fs::{File, self}, + fs::{self, File}, io::{BufRead, BufReader}, path::Path, ptr::null_mut, @@ -19,11 +20,21 @@ use std::{ use sm_session_notifier::SMSessionNotifierThread; use window_change::WindowChangeMonitor; use windows::{ - core::{Interface}, + core::Interface, + w, Win32::{ + Foundation::HANDLE, Media::Audio::{IAudioSessionControl2, ISimpleAudioVolume}, - System::{Com::{CoInitializeEx, COINIT_MULTITHREADED}, Threading::WaitForSingleObject, WindowsProgramming::INFINITE}, Storage::FileSystem::{FindFirstChangeNotificationW, FILE_NOTIFY_CHANGE_LAST_WRITE, FindCloseChangeNotification, FindNextChangeNotification}, Foundation::HANDLE, - }, w, + Storage::FileSystem::{ + FindCloseChangeNotification, FindFirstChangeNotificationW, FindNextChangeNotification, + FILE_NOTIFY_CHANGE_LAST_WRITE, + }, + System::{ + Com::{CoInitializeEx, COINIT_MULTITHREADED}, + Threading::WaitForSingleObject, + WindowsProgramming::INFINITE, + }, + }, }; use crate::pid_to_exe::pid_to_exe_path; @@ -40,8 +51,8 @@ struct SessionMuter { sessions: Vec, mute_executables: HashSet, mute_flag: bool, - _sn: SMSessionNotifierThread, - _wn: WindowChangeMonitor, + _session_notifier: SMSessionNotifierThread, + _win_change_mon: WindowChangeMonitor, rx: Receiver, } @@ -55,13 +66,13 @@ impl SessionMuter { sessions: Vec::new(), mute_executables, mute_flag: true, - _sn: { + _session_notifier: { let tx = tx.clone(); SMSessionNotifierThread::new(Box::new(move |session| { tx.send(MuterMessage::AddSession(session)).unwrap(); })) }, - _wn: { + _win_change_mon: { WindowChangeMonitor::start(Box::new(move |s| { tx.send(MuterMessage::WindowChange(s.to_owned())).unwrap(); })) @@ -99,32 +110,32 @@ impl SessionMuter { Ok(()) } - unsafe fn set_mute_all(self: &mut SessionMuter, mute: bool) { - let results = self - .sessions - .iter() - .map(|session_control2| session_control2.cast::()) - .map(|vol_result| vol_result.map(|volume| volume.SetMute(mute, null_mut()))); - for err in results.filter_map(|x| x.err()) { - println!("Error muting a session: {:?}", err); + fn set_mute_all(self: &mut SessionMuter, mute: bool) { + unsafe { + let results = self + .sessions + .iter() + .map(|session_control2| session_control2.cast::()) + .map(|vol_result| vol_result.map(|volume| volume.SetMute(mute, null_mut()))); + for err in results.filter_map(|x| x.err()) { + println!("Error muting a session: {:?}", err); + } } } fn notify_window_changed(self: &mut SessionMuter, path: &str) { - unsafe { - let binding = Path::new(path) - .file_name() - .expect("failed to extract filename from path"); - let file_name = binding.to_os_string().to_string_lossy().to_string(); - let mute_flag = !self.mute_executables.contains(&file_name); - if mute_flag != self.mute_flag { - self.mute_flag = mute_flag; - self.set_mute_all(self.mute_flag); - println!( - "Mute set to {} due to foreground window: {}", - self.mute_flag, file_name - ); - } + let binding = Path::new(path) + .file_name() + .expect("failed to extract filename from path"); + let file_name = binding.to_os_string().to_string_lossy().to_string(); + let mute_flag = !self.mute_executables.contains(&file_name); + if mute_flag != self.mute_flag { + self.mute_flag = mute_flag; + self.set_mute_all(self.mute_flag); + println!( + "Mute set to {} due to foreground window: {}", + self.mute_flag, file_name + ); } } @@ -195,9 +206,9 @@ fn main() { unsafe { CoInitializeEx(None, COINIT_MULTITHREADED).unwrap(); } - let mute_file = "mute.txt"; + let mute_file: String = env::args().nth(1).unwrap_or("mute.txt".to_string()); loop { - let mut _mt = MuterThread::new(load_mute_txt(mute_file)); - await_file_change(mute_file).unwrap(); + let mut _mt = MuterThread::new(load_mute_txt(&mute_file)); + await_file_change(&mute_file).unwrap(); } } diff --git a/src/pid_to_exe.rs b/src/pid_to_exe.rs index 3118eb9..b5dc398 100644 --- a/src/pid_to_exe.rs +++ b/src/pid_to_exe.rs @@ -1,26 +1,29 @@ use std::error::Error; use windows::Win32::{ - Foundation::MAX_PATH, + Foundation::{CloseHandle, MAX_PATH}, System::Threading::{ OpenProcess, QueryFullProcessImageNameW, PROCESS_QUERY_INFORMATION, PROCESS_VM_READ, }, }; -pub unsafe fn pid_to_exe_path(pid: u32) -> Result> { +pub fn pid_to_exe_path(pid: u32) -> Result> { let mut exe_name: Vec = Vec::with_capacity(MAX_PATH as usize); let mut size: u32 = exe_name.capacity().try_into().unwrap(); - let process = OpenProcess(PROCESS_QUERY_INFORMATION | PROCESS_VM_READ, false, pid); - if !QueryFullProcessImageNameW( - process?, - Default::default(), - windows::core::PWSTR(exe_name.as_mut_ptr()), - &mut size, - ).as_bool() - { - return Err(Box::new(windows::core::Error::from_win32())); + unsafe { + let process = OpenProcess(PROCESS_QUERY_INFORMATION | PROCESS_VM_READ, false, pid)?; + QueryFullProcessImageNameW( + process, + Default::default(), + windows::core::PWSTR(exe_name.as_mut_ptr()), + &mut size, + ) + .ok()?; + + CloseHandle(process); + + exe_name.set_len(size.try_into().unwrap()); } - exe_name.set_len(size.try_into().unwrap()); let process_name = String::from_utf16_lossy(&exe_name); Ok(process_name) } diff --git a/src/session_notification.rs b/src/session_notification.rs index b56c391..473c559 100644 --- a/src/session_notification.rs +++ b/src/session_notification.rs @@ -1,16 +1,14 @@ use windows::{ - core::{Interface}, - Win32::{ - Media::Audio::{ - IAudioSessionControl, IAudioSessionControl2, - IAudioSessionNotification, IAudioSessionNotification_Impl, - }, + core::Interface, + Win32::Media::Audio::{ + IAudioSessionControl, IAudioSessionControl2, IAudioSessionNotification, + IAudioSessionNotification_Impl, }, }; #[windows::core::implement(IAudioSessionNotification)] pub(crate) struct SessionNotification { - pub(crate) observer: Box + pub(crate) observer: Box, } pub trait SessionObserver { diff --git a/src/sm_session_notifier.rs b/src/sm_session_notifier.rs index 22f7a5d..a7ede57 100644 --- a/src/sm_session_notifier.rs +++ b/src/sm_session_notifier.rs @@ -16,7 +16,9 @@ use windows::{ eRender, IAudioSessionControl2, IAudioSessionManager2, IAudioSessionNotification, IMMDeviceEnumerator, IMMNotificationClient, MMDeviceEnumerator, DEVICE_STATE_ACTIVE, }, - System::Com::{CoCreateInstance, CLSCTX_ALL, STGM_READ}, + System::Com::{ + CoCreateInstance, StructuredStorage::PropVariantClear, CLSCTX_ALL, STGM_READ, + }, }, }; @@ -78,9 +80,7 @@ impl SMSessionNotifier { }), }), session_notification: IAudioSessionNotification::from(SessionNotification { - observer: Box::new(SessionToMessage { - sender - }), + observer: Box::new(SessionToMessage { sender }), }), notification_function: callback, receiver, @@ -98,7 +98,6 @@ impl SMSessionNotifier { let mmdevice = device?; self.add_device(mmdevice)?; } - println!("All devices initialized."); Ok(()) } } @@ -117,9 +116,9 @@ impl SMSessionNotifier { self: &mut SMSessionNotifier, device: windows::Win32::Media::Audio::IMMDevice, ) -> Result<(), Box> { - let sm: IAudioSessionManager2 = device.Activate(CLSCTX_ALL, None)?; - sm.RegisterSessionNotification(&self.session_notification)?; - let session_enumerator = sm.GetSessionEnumerator()?; + let session_manager: IAudioSessionManager2 = device.Activate(CLSCTX_ALL, None)?; + session_manager.RegisterSessionNotification(&self.session_notification)?; + let session_enumerator = session_manager.GetSessionEnumerator()?; let session_count = session_enumerator.GetCount()?; let device_sessions = (0..session_count) .map(|idx| { @@ -132,13 +131,15 @@ impl SMSessionNotifier { (self.notification_function)(session); } let prop_store = device.OpenPropertyStore(STGM_READ)?; - let prop_var = prop_store.GetValue(&PKEY_Device_FriendlyName)?; + let mut prop_var = prop_store.GetValue(&PKEY_Device_FriendlyName)?; println!( "Device Added: {} Existing Sessions: {}", prop_var.Anonymous.Anonymous.Anonymous.pwszVal.to_string()?, session_count ); - self.session_managers.push(sm); + PropVariantClear(&mut prop_var)?; + + self.session_managers.push(session_manager); Ok(()) } @@ -161,6 +162,21 @@ impl SMSessionNotifier { } } +impl Drop for SMSessionNotifier { + fn drop(&mut self) { + unsafe { + self.session_managers + .drain(..) + .map(|x| x.UnregisterSessionNotification(&self.session_notification)) + .collect::>() + .unwrap(); + self.device_enumerator + .UnregisterEndpointNotificationCallback(&self.device_notification_client) + .unwrap(); + } + } +} + pub struct SMSessionNotifierThread { handle: Option>, sender: Sender, diff --git a/src/window_change.rs b/src/window_change.rs index 4f4d3d8..ac9bf15 100644 --- a/src/window_change.rs +++ b/src/window_change.rs @@ -1,15 +1,19 @@ -use std::{sync::{Mutex, atomic::AtomicU32, Arc}, thread::{self, JoinHandle}}; +use std::{ + sync::{atomic::AtomicU32, Arc, Mutex}, + thread::{self, JoinHandle}, +}; use windows::Win32::{ - Foundation::{HINSTANCE, HWND, WPARAM, LPARAM}, + Foundation::{HINSTANCE, HWND, LPARAM, WPARAM}, + System::Threading::GetCurrentThreadId, UI::{ - Accessibility::{SetWinEventHook, HWINEVENTHOOK, UnhookWinEvent}, + Accessibility::{SetWinEventHook, UnhookWinEvent, HWINEVENTHOOK}, WindowsAndMessaging::{ DispatchMessageW, GetForegroundWindow, GetMessageW, GetWindowThreadProcessId, - TranslateMessage, EVENT_SYSTEM_FOREGROUND, EVENT_SYSTEM_MINIMIZEEND, MSG, - WINEVENT_OUTOFCONTEXT, WINEVENT_SKIPOWNPROCESS, PostThreadMessageW, WM_QUIT, + PostThreadMessageW, TranslateMessage, EVENT_SYSTEM_FOREGROUND, + EVENT_SYSTEM_MINIMIZEEND, MSG, WINEVENT_OUTOFCONTEXT, WINEVENT_SKIPOWNPROCESS, WM_QUIT, }, - }, System::Threading::GetCurrentThreadId, + }, }; use crate::pid_to_exe::pid_to_exe_path; @@ -33,13 +37,7 @@ unsafe extern "system" fn win_event_proc( let hwnd = GetForegroundWindow(); GetWindowThreadProcessId(hwnd, Some(&mut pid)); pid_to_exe_path(pid) - .map(|path| - WIN_CHANGE_CALLBACK - .lock() - .unwrap() - .as_ref() - .unwrap() - (&path)) + .map(|path| WIN_CHANGE_CALLBACK.lock().unwrap().as_ref().unwrap()(&path)) .unwrap_or_else(|err| { println!( "Error finding process with pid {} for hwnd: {:?}: {:?}", @@ -49,7 +47,6 @@ unsafe extern "system" fn win_event_proc( } } - pub fn await_win_change_events(callback: WinCallback) { *WIN_CHANGE_CALLBACK.lock().unwrap() = Some(callback); unsafe { @@ -84,38 +81,47 @@ pub fn await_win_change_events(callback: WinCallback) { pub struct WindowChangeMonitor { join_handle: Option>, - win_thread_id: Arc + win_thread_id: Arc, } impl Drop for WindowChangeMonitor { fn drop(&mut self) { if let Some(join_handle) = self.join_handle.take() { - let tid : u32 = self.win_thread_id.load(std::sync::atomic::Ordering::Relaxed); + let tid: u32 = self + .win_thread_id + .load(std::sync::atomic::Ordering::Relaxed); if tid != 0 { unsafe { - PostThreadMessageW(tid, WM_QUIT, WPARAM(0), LPARAM(0)).ok().unwrap(); + PostThreadMessageW(tid, WM_QUIT, WPARAM(0), LPARAM(0)) + .ok() + .unwrap(); } } - - join_handle.join().expect("Unable to terminate window change thread"); + + join_handle + .join() + .expect("Unable to terminate window change thread"); } } } impl WindowChangeMonitor { - pub(crate) fn start(f : WinCallback) -> WindowChangeMonitor { + pub(crate) fn start(f: WinCallback) -> WindowChangeMonitor { let win_thread_id = Arc::new(AtomicU32::new(0)); let join_handle = { let win_thread_id = win_thread_id.clone(); thread::spawn(move || { - win_thread_id.store(unsafe { GetCurrentThreadId() }, std::sync::atomic::Ordering::Relaxed); + win_thread_id.store( + unsafe { GetCurrentThreadId() }, + std::sync::atomic::Ordering::Relaxed, + ); await_win_change_events(f); }) }; WindowChangeMonitor { join_handle: Some(join_handle), - win_thread_id + win_thread_id, } } -} \ No newline at end of file +}