Free some additional structures from winapi properly

This commit is contained in:
Your Name
2023-02-23 22:47:33 -05:00
parent 847631d94e
commit a647fc5186
6 changed files with 122 additions and 94 deletions

View File

@@ -1,10 +1,4 @@
use windows::{ use windows::Win32::Media::Audio::{IMMNotificationClient, IMMNotificationClient_Impl};
Win32::{
Media::Audio::{
IMMNotificationClient, IMMNotificationClient_Impl,
},
},
};
pub trait DeviceNotificationObserver { pub trait DeviceNotificationObserver {
fn add_device(&self, device_id: &windows::core::PCWSTR); fn add_device(&self, device_id: &windows::core::PCWSTR);
@@ -12,7 +6,7 @@ pub trait DeviceNotificationObserver {
#[windows::core::implement(IMMNotificationClient)] #[windows::core::implement(IMMNotificationClient)]
pub(crate) struct DeviceNotificationClient { pub(crate) struct DeviceNotificationClient {
pub observer: Box<dyn DeviceNotificationObserver> pub observer: Box<dyn DeviceNotificationObserver>,
} }
impl IMMNotificationClient_Impl for DeviceNotificationClient { impl IMMNotificationClient_Impl for DeviceNotificationClient {

View File

@@ -6,9 +6,10 @@ mod window_change;
use std::{ use std::{
collections::HashSet, collections::HashSet,
env,
error::Error, error::Error,
ffi::OsString, ffi::OsString,
fs::{File, self}, fs::{self, File},
io::{BufRead, BufReader}, io::{BufRead, BufReader},
path::Path, path::Path,
ptr::null_mut, ptr::null_mut,
@@ -19,11 +20,21 @@ use std::{
use sm_session_notifier::SMSessionNotifierThread; use sm_session_notifier::SMSessionNotifierThread;
use window_change::WindowChangeMonitor; use window_change::WindowChangeMonitor;
use windows::{ use windows::{
core::{Interface}, core::Interface,
w,
Win32::{ Win32::{
Foundation::HANDLE,
Media::Audio::{IAudioSessionControl2, ISimpleAudioVolume}, Media::Audio::{IAudioSessionControl2, ISimpleAudioVolume},
System::{Com::{CoInitializeEx, COINIT_MULTITHREADED}, Threading::WaitForSingleObject, WindowsProgramming::INFINITE}, Storage::FileSystem::{FindFirstChangeNotificationW, FILE_NOTIFY_CHANGE_LAST_WRITE, FindCloseChangeNotification, FindNextChangeNotification}, Foundation::HANDLE, Storage::FileSystem::{
}, w, FindCloseChangeNotification, FindFirstChangeNotificationW, FindNextChangeNotification,
FILE_NOTIFY_CHANGE_LAST_WRITE,
},
System::{
Com::{CoInitializeEx, COINIT_MULTITHREADED},
Threading::WaitForSingleObject,
WindowsProgramming::INFINITE,
},
},
}; };
use crate::pid_to_exe::pid_to_exe_path; use crate::pid_to_exe::pid_to_exe_path;
@@ -40,8 +51,8 @@ struct SessionMuter {
sessions: Vec<IAudioSessionControl2>, sessions: Vec<IAudioSessionControl2>,
mute_executables: HashSet<String>, mute_executables: HashSet<String>,
mute_flag: bool, mute_flag: bool,
_sn: SMSessionNotifierThread, _session_notifier: SMSessionNotifierThread,
_wn: WindowChangeMonitor, _win_change_mon: WindowChangeMonitor,
rx: Receiver<MuterMessage>, rx: Receiver<MuterMessage>,
} }
@@ -55,13 +66,13 @@ impl SessionMuter {
sessions: Vec::new(), sessions: Vec::new(),
mute_executables, mute_executables,
mute_flag: true, mute_flag: true,
_sn: { _session_notifier: {
let tx = tx.clone(); let tx = tx.clone();
SMSessionNotifierThread::new(Box::new(move |session| { SMSessionNotifierThread::new(Box::new(move |session| {
tx.send(MuterMessage::AddSession(session)).unwrap(); tx.send(MuterMessage::AddSession(session)).unwrap();
})) }))
}, },
_wn: { _win_change_mon: {
WindowChangeMonitor::start(Box::new(move |s| { WindowChangeMonitor::start(Box::new(move |s| {
tx.send(MuterMessage::WindowChange(s.to_owned())).unwrap(); tx.send(MuterMessage::WindowChange(s.to_owned())).unwrap();
})) }))
@@ -99,7 +110,8 @@ impl SessionMuter {
Ok(()) Ok(())
} }
unsafe fn set_mute_all(self: &mut SessionMuter, mute: bool) { fn set_mute_all(self: &mut SessionMuter, mute: bool) {
unsafe {
let results = self let results = self
.sessions .sessions
.iter() .iter()
@@ -109,9 +121,9 @@ impl SessionMuter {
println!("Error muting a session: {:?}", err); println!("Error muting a session: {:?}", err);
} }
} }
}
fn notify_window_changed(self: &mut SessionMuter, path: &str) { fn notify_window_changed(self: &mut SessionMuter, path: &str) {
unsafe {
let binding = Path::new(path) let binding = Path::new(path)
.file_name() .file_name()
.expect("failed to extract filename from path"); .expect("failed to extract filename from path");
@@ -126,7 +138,6 @@ impl SessionMuter {
); );
} }
} }
}
fn session_to_filename( fn session_to_filename(
self: &mut SessionMuter, self: &mut SessionMuter,
@@ -195,9 +206,9 @@ fn main() {
unsafe { unsafe {
CoInitializeEx(None, COINIT_MULTITHREADED).unwrap(); CoInitializeEx(None, COINIT_MULTITHREADED).unwrap();
} }
let mute_file = "mute.txt"; let mute_file: String = env::args().nth(1).unwrap_or("mute.txt".to_string());
loop { loop {
let mut _mt = MuterThread::new(load_mute_txt(mute_file)); let mut _mt = MuterThread::new(load_mute_txt(&mute_file));
await_file_change(mute_file).unwrap(); await_file_change(&mute_file).unwrap();
} }
} }

View File

@@ -1,26 +1,29 @@
use std::error::Error; use std::error::Error;
use windows::Win32::{ use windows::Win32::{
Foundation::MAX_PATH, Foundation::{CloseHandle, MAX_PATH},
System::Threading::{ System::Threading::{
OpenProcess, QueryFullProcessImageNameW, PROCESS_QUERY_INFORMATION, PROCESS_VM_READ, OpenProcess, QueryFullProcessImageNameW, PROCESS_QUERY_INFORMATION, PROCESS_VM_READ,
}, },
}; };
pub unsafe fn pid_to_exe_path(pid: u32) -> Result<String, Box<dyn Error>> { pub fn pid_to_exe_path(pid: u32) -> Result<String, Box<dyn Error>> {
let mut exe_name: Vec<u16> = Vec::with_capacity(MAX_PATH as usize); let mut exe_name: Vec<u16> = Vec::with_capacity(MAX_PATH as usize);
let mut size: u32 = exe_name.capacity().try_into().unwrap(); let mut size: u32 = exe_name.capacity().try_into().unwrap();
let process = OpenProcess(PROCESS_QUERY_INFORMATION | PROCESS_VM_READ, false, pid); unsafe {
if !QueryFullProcessImageNameW( let process = OpenProcess(PROCESS_QUERY_INFORMATION | PROCESS_VM_READ, false, pid)?;
process?, QueryFullProcessImageNameW(
process,
Default::default(), Default::default(),
windows::core::PWSTR(exe_name.as_mut_ptr()), windows::core::PWSTR(exe_name.as_mut_ptr()),
&mut size, &mut size,
).as_bool() )
{ .ok()?;
return Err(Box::new(windows::core::Error::from_win32()));
} CloseHandle(process);
exe_name.set_len(size.try_into().unwrap()); exe_name.set_len(size.try_into().unwrap());
}
let process_name = String::from_utf16_lossy(&exe_name); let process_name = String::from_utf16_lossy(&exe_name);
Ok(process_name) Ok(process_name)
} }

View File

@@ -1,16 +1,14 @@
use windows::{ use windows::{
core::{Interface}, core::Interface,
Win32::{ Win32::Media::Audio::{
Media::Audio::{ IAudioSessionControl, IAudioSessionControl2, IAudioSessionNotification,
IAudioSessionControl, IAudioSessionControl2, IAudioSessionNotification_Impl,
IAudioSessionNotification, IAudioSessionNotification_Impl,
},
}, },
}; };
#[windows::core::implement(IAudioSessionNotification)] #[windows::core::implement(IAudioSessionNotification)]
pub(crate) struct SessionNotification { pub(crate) struct SessionNotification {
pub(crate) observer: Box<dyn SessionObserver> pub(crate) observer: Box<dyn SessionObserver>,
} }
pub trait SessionObserver { pub trait SessionObserver {

View File

@@ -16,7 +16,9 @@ use windows::{
eRender, IAudioSessionControl2, IAudioSessionManager2, IAudioSessionNotification, eRender, IAudioSessionControl2, IAudioSessionManager2, IAudioSessionNotification,
IMMDeviceEnumerator, IMMNotificationClient, MMDeviceEnumerator, DEVICE_STATE_ACTIVE, IMMDeviceEnumerator, IMMNotificationClient, MMDeviceEnumerator, DEVICE_STATE_ACTIVE,
}, },
System::Com::{CoCreateInstance, CLSCTX_ALL, STGM_READ}, System::Com::{
CoCreateInstance, StructuredStorage::PropVariantClear, CLSCTX_ALL, STGM_READ,
},
}, },
}; };
@@ -78,9 +80,7 @@ impl SMSessionNotifier {
}), }),
}), }),
session_notification: IAudioSessionNotification::from(SessionNotification { session_notification: IAudioSessionNotification::from(SessionNotification {
observer: Box::new(SessionToMessage { observer: Box::new(SessionToMessage { sender }),
sender
}),
}), }),
notification_function: callback, notification_function: callback,
receiver, receiver,
@@ -98,7 +98,6 @@ impl SMSessionNotifier {
let mmdevice = device?; let mmdevice = device?;
self.add_device(mmdevice)?; self.add_device(mmdevice)?;
} }
println!("All devices initialized.");
Ok(()) Ok(())
} }
} }
@@ -117,9 +116,9 @@ impl SMSessionNotifier {
self: &mut SMSessionNotifier, self: &mut SMSessionNotifier,
device: windows::Win32::Media::Audio::IMMDevice, device: windows::Win32::Media::Audio::IMMDevice,
) -> Result<(), Box<dyn Error>> { ) -> Result<(), Box<dyn Error>> {
let sm: IAudioSessionManager2 = device.Activate(CLSCTX_ALL, None)?; let session_manager: IAudioSessionManager2 = device.Activate(CLSCTX_ALL, None)?;
sm.RegisterSessionNotification(&self.session_notification)?; session_manager.RegisterSessionNotification(&self.session_notification)?;
let session_enumerator = sm.GetSessionEnumerator()?; let session_enumerator = session_manager.GetSessionEnumerator()?;
let session_count = session_enumerator.GetCount()?; let session_count = session_enumerator.GetCount()?;
let device_sessions = (0..session_count) let device_sessions = (0..session_count)
.map(|idx| { .map(|idx| {
@@ -132,13 +131,15 @@ impl SMSessionNotifier {
(self.notification_function)(session); (self.notification_function)(session);
} }
let prop_store = device.OpenPropertyStore(STGM_READ)?; let prop_store = device.OpenPropertyStore(STGM_READ)?;
let prop_var = prop_store.GetValue(&PKEY_Device_FriendlyName)?; let mut prop_var = prop_store.GetValue(&PKEY_Device_FriendlyName)?;
println!( println!(
"Device Added: {} Existing Sessions: {}", "Device Added: {} Existing Sessions: {}",
prop_var.Anonymous.Anonymous.Anonymous.pwszVal.to_string()?, prop_var.Anonymous.Anonymous.Anonymous.pwszVal.to_string()?,
session_count session_count
); );
self.session_managers.push(sm); PropVariantClear(&mut prop_var)?;
self.session_managers.push(session_manager);
Ok(()) Ok(())
} }
@@ -161,6 +162,21 @@ impl SMSessionNotifier {
} }
} }
impl Drop for SMSessionNotifier {
fn drop(&mut self) {
unsafe {
self.session_managers
.drain(..)
.map(|x| x.UnregisterSessionNotification(&self.session_notification))
.collect::<Result<(), _>>()
.unwrap();
self.device_enumerator
.UnregisterEndpointNotificationCallback(&self.device_notification_client)
.unwrap();
}
}
}
pub struct SMSessionNotifierThread { pub struct SMSessionNotifierThread {
handle: Option<JoinHandle<()>>, handle: Option<JoinHandle<()>>,
sender: Sender<SMMessage>, sender: Sender<SMMessage>,

View File

@@ -1,15 +1,19 @@
use std::{sync::{Mutex, atomic::AtomicU32, Arc}, thread::{self, JoinHandle}}; use std::{
sync::{atomic::AtomicU32, Arc, Mutex},
thread::{self, JoinHandle},
};
use windows::Win32::{ use windows::Win32::{
Foundation::{HINSTANCE, HWND, WPARAM, LPARAM}, Foundation::{HINSTANCE, HWND, LPARAM, WPARAM},
System::Threading::GetCurrentThreadId,
UI::{ UI::{
Accessibility::{SetWinEventHook, HWINEVENTHOOK, UnhookWinEvent}, Accessibility::{SetWinEventHook, UnhookWinEvent, HWINEVENTHOOK},
WindowsAndMessaging::{ WindowsAndMessaging::{
DispatchMessageW, GetForegroundWindow, GetMessageW, GetWindowThreadProcessId, DispatchMessageW, GetForegroundWindow, GetMessageW, GetWindowThreadProcessId,
TranslateMessage, EVENT_SYSTEM_FOREGROUND, EVENT_SYSTEM_MINIMIZEEND, MSG, PostThreadMessageW, TranslateMessage, EVENT_SYSTEM_FOREGROUND,
WINEVENT_OUTOFCONTEXT, WINEVENT_SKIPOWNPROCESS, PostThreadMessageW, WM_QUIT, EVENT_SYSTEM_MINIMIZEEND, MSG, WINEVENT_OUTOFCONTEXT, WINEVENT_SKIPOWNPROCESS, WM_QUIT,
},
}, },
}, System::Threading::GetCurrentThreadId,
}; };
use crate::pid_to_exe::pid_to_exe_path; use crate::pid_to_exe::pid_to_exe_path;
@@ -33,13 +37,7 @@ unsafe extern "system" fn win_event_proc(
let hwnd = GetForegroundWindow(); let hwnd = GetForegroundWindow();
GetWindowThreadProcessId(hwnd, Some(&mut pid)); GetWindowThreadProcessId(hwnd, Some(&mut pid));
pid_to_exe_path(pid) pid_to_exe_path(pid)
.map(|path| .map(|path| WIN_CHANGE_CALLBACK.lock().unwrap().as_ref().unwrap()(&path))
WIN_CHANGE_CALLBACK
.lock()
.unwrap()
.as_ref()
.unwrap()
(&path))
.unwrap_or_else(|err| { .unwrap_or_else(|err| {
println!( println!(
"Error finding process with pid {} for hwnd: {:?}: {:?}", "Error finding process with pid {} for hwnd: {:?}: {:?}",
@@ -49,7 +47,6 @@ unsafe extern "system" fn win_event_proc(
} }
} }
pub fn await_win_change_events(callback: WinCallback) { pub fn await_win_change_events(callback: WinCallback) {
*WIN_CHANGE_CALLBACK.lock().unwrap() = Some(callback); *WIN_CHANGE_CALLBACK.lock().unwrap() = Some(callback);
unsafe { unsafe {
@@ -84,21 +81,27 @@ pub fn await_win_change_events(callback: WinCallback) {
pub struct WindowChangeMonitor { pub struct WindowChangeMonitor {
join_handle: Option<JoinHandle<()>>, join_handle: Option<JoinHandle<()>>,
win_thread_id: Arc<AtomicU32> win_thread_id: Arc<AtomicU32>,
} }
impl Drop for WindowChangeMonitor { impl Drop for WindowChangeMonitor {
fn drop(&mut self) { fn drop(&mut self) {
if let Some(join_handle) = self.join_handle.take() { if let Some(join_handle) = self.join_handle.take() {
let tid : u32 = self.win_thread_id.load(std::sync::atomic::Ordering::Relaxed); let tid: u32 = self
.win_thread_id
.load(std::sync::atomic::Ordering::Relaxed);
if tid != 0 { if tid != 0 {
unsafe { unsafe {
PostThreadMessageW(tid, WM_QUIT, WPARAM(0), LPARAM(0)).ok().unwrap(); PostThreadMessageW(tid, WM_QUIT, WPARAM(0), LPARAM(0))
.ok()
.unwrap();
} }
} }
join_handle.join().expect("Unable to terminate window change thread"); join_handle
.join()
.expect("Unable to terminate window change thread");
} }
} }
} }
@@ -109,13 +112,16 @@ impl WindowChangeMonitor {
let join_handle = { let join_handle = {
let win_thread_id = win_thread_id.clone(); let win_thread_id = win_thread_id.clone();
thread::spawn(move || { thread::spawn(move || {
win_thread_id.store(unsafe { GetCurrentThreadId() }, std::sync::atomic::Ordering::Relaxed); win_thread_id.store(
unsafe { GetCurrentThreadId() },
std::sync::atomic::Ordering::Relaxed,
);
await_win_change_events(f); await_win_change_events(f);
}) })
}; };
WindowChangeMonitor { WindowChangeMonitor {
join_handle: Some(join_handle), join_handle: Some(join_handle),
win_thread_id win_thread_id,
} }
} }
} }