Code cleanup part 2: Huge ass refactor edition

This commit is contained in:
Your Name
2023-02-19 18:14:37 -05:00
parent f2d5263244
commit 9eb15a4a5f
6 changed files with 346 additions and 246 deletions

View File

@@ -0,0 +1,56 @@
use windows::{
Win32::{
Media::Audio::{
IMMNotificationClient, IMMNotificationClient_Impl,
},
},
};
#[windows::core::implement(IMMNotificationClient)]
pub(crate) struct DeviceNotificationClient {
pub(crate) callback: Box<dyn Fn(&windows::core::PCWSTR)>
}
impl Drop for DeviceNotificationClient {
fn drop(&mut self) {
println!("DNC drop");
}
}
impl IMMNotificationClient_Impl for DeviceNotificationClient {
fn OnDeviceStateChanged(
&self,
_pwstrdeviceid: &windows::core::PCWSTR,
_dwnewstate: u32,
) -> windows::core::Result<()> {
Ok(())
}
fn OnDeviceAdded(&self, pwstrdeviceid: &windows::core::PCWSTR) -> windows::core::Result<()> {
(self.callback)(pwstrdeviceid);
Ok(())
}
fn OnDeviceRemoved(&self, _pwstrdeviceid: &windows::core::PCWSTR) -> windows::core::Result<()> {
println!("OnDeviceRemoved");
// TODO: Remove device and all its sessions
Ok(())
}
fn OnDefaultDeviceChanged(
&self,
_flow: windows::Win32::Media::Audio::EDataFlow,
_role: windows::Win32::Media::Audio::ERole,
_pwstrdefaultdeviceid: &windows::core::PCWSTR,
) -> windows::core::Result<()> {
Ok(())
}
fn OnPropertyValueChanged(
&self,
_pwstrdeviceid: &windows::core::PCWSTR,
_key: &windows::Win32::UI::Shell::PropertiesSystem::PROPERTYKEY,
) -> windows::core::Result<()> {
Ok(())
}
}

View File

@@ -1,7 +1,9 @@
mod device_notification_client;
mod pid_to_exe;
mod session_notification;
mod sm_session_notifier;
mod window_change;
use once_cell::sync::Lazy;
use std::{
collections::HashSet,
error::Error,
@@ -10,174 +12,25 @@ use std::{
io::{BufRead, BufReader},
path::Path,
ptr::null_mut,
sync::{Arc, Mutex},
sync::{Arc, Mutex, mpsc::{self, Sender, Receiver}}, thread,
};
use widestring::U16CStr;
use sm_session_notifier::SMSessionNotifier;
use window_change::{win_event_hook_loop, WIN_CHANGE_CHANNEL_TX};
use windows::{
core::{Interface, PCWSTR},
core::Interface,
Win32::{
Foundation::{HINSTANCE, HWND},
Media::Audio::{
eRender, IAudioSessionControl, IAudioSessionControl2, IAudioSessionManager2,
IAudioSessionNotification, IAudioSessionNotification_Impl, IMMDeviceEnumerator,
IMMNotificationClient, IMMNotificationClient_Impl, ISimpleAudioVolume,
MMDeviceEnumerator, DEVICE_STATE_ACTIVE,
},
System::Com::{CoCreateInstance, CoInitializeEx, CLSCTX_ALL, COINIT_MULTITHREADED, STGM_READ},
UI::{
Accessibility::{SetWinEventHook, HWINEVENTHOOK},
WindowsAndMessaging::{
DispatchMessageW, GetForegroundWindow, GetMessageW, GetWindowThreadProcessId,
TranslateMessage, EVENT_SYSTEM_FOREGROUND, EVENT_SYSTEM_MINIMIZEEND, MSG,
WINEVENT_OUTOFCONTEXT, WINEVENT_SKIPOWNPROCESS,
},
}, Devices::FunctionDiscovery::PKEY_Device_FriendlyName,
Media::Audio::{IAudioSessionControl2, ISimpleAudioVolume},
System::Com::{CoInitializeEx, COINIT_MULTITHREADED},
},
};
use crate::pid_to_exe::pid_to_exe_path;
unsafe extern "system" fn win_event_proc(
_hook: HWINEVENTHOOK,
event: u32,
_hwnd: HWND,
_id_object: i32,
_id_child: i32,
_dw_event_thread: u32,
_dwms_event_time: u32,
) {
if event == EVENT_SYSTEM_FOREGROUND || event == EVENT_SYSTEM_MINIMIZEEND {
let mut pid: u32 = 0;
// Instead of using the hwnd passed to us in the wineventproc, call GetForegroundWindow again because Hollow Knight does something weird
// and immediately fires again with the hwnd for explorer?
let hwnd = GetForegroundWindow();
GetWindowThreadProcessId(hwnd, Some(&mut pid));
pid_to_exe_path(pid)
.map(|path| MUTER.lock().unwrap().notify_window_changed(&path))
.unwrap_or_else(|err| {
println!(
"Error finding process with pid {} for hwnd: {:?}: {:?}",
pid, hwnd, err
)
});
}
}
fn win_event_hook_loop() {
unsafe {
SetWinEventHook(
EVENT_SYSTEM_FOREGROUND,
EVENT_SYSTEM_FOREGROUND,
HINSTANCE::default(),
Some(win_event_proc),
0,
0,
WINEVENT_OUTOFCONTEXT | WINEVENT_SKIPOWNPROCESS,
);
SetWinEventHook(
EVENT_SYSTEM_MINIMIZEEND,
EVENT_SYSTEM_MINIMIZEEND,
HINSTANCE::default(),
Some(win_event_proc),
0,
0,
WINEVENT_OUTOFCONTEXT | WINEVENT_SKIPOWNPROCESS,
);
let mut msg: MSG = MSG::default();
while GetMessageW(&mut msg, HWND::default(), 0, 0).as_bool() {
TranslateMessage(&msg);
DispatchMessageW(&msg);
}
}
}
#[windows::core::implement(IAudioSessionNotification)]
struct SessionNotification {}
impl IAudioSessionNotification_Impl for SessionNotification {
fn OnSessionCreated(
self: &SessionNotification,
newsession: &core::option::Option<IAudioSessionControl>,
) -> windows::core::Result<()> {
let ses: IAudioSessionControl2 = newsession.as_ref().unwrap().cast().unwrap();
MUTER
.lock()
.unwrap()
.add_session(ses)
.unwrap();
Ok(())
}
}
impl Drop for SessionNotification {
fn drop(&mut self) {
println!("SN drop");
}
}
#[windows::core::implement(IMMNotificationClient)]
struct DeviceNotificationClient {}
impl Drop for DeviceNotificationClient {
fn drop(&mut self) {
println!("DNC drop");
}
}
impl IMMNotificationClient_Impl for DeviceNotificationClient {
fn OnDeviceStateChanged(
&self,
_pwstrdeviceid: &windows::core::PCWSTR,
_dwnewstate: u32,
) -> windows::core::Result<()> {
Ok(())
}
fn OnDeviceAdded(&self, pwstrdeviceid: &windows::core::PCWSTR) -> windows::core::Result<()> {
MUTER
.lock()
.unwrap()
.add_device_by_id(pwstrdeviceid)
.unwrap_or_else(|error| {
println!("Error adding device: {:?}", error);
});
Ok(())
}
fn OnDeviceRemoved(&self, _pwstrdeviceid: &windows::core::PCWSTR) -> windows::core::Result<()> {
println!("OnDeviceRemoved");
// TODO: Remove device and all its sessions
Ok(())
}
fn OnDefaultDeviceChanged(
&self,
_flow: windows::Win32::Media::Audio::EDataFlow,
_role: windows::Win32::Media::Audio::ERole,
_pwstrdefaultdeviceid: &windows::core::PCWSTR,
) -> windows::core::Result<()> {
Ok(())
}
fn OnPropertyValueChanged(
&self,
_pwstrdeviceid: &windows::core::PCWSTR,
_key: &windows::Win32::UI::Shell::PropertiesSystem::PROPERTYKEY,
) -> windows::core::Result<()> {
Ok(())
}
}
struct SessionMuter {
sessions: Arc<Mutex<Vec<IAudioSessionControl2>>>,
device_enumerator: IMMDeviceEnumerator,
device_notification_client: IMMNotificationClient,
session_notification: IAudioSessionNotification,
sessions: Vec<IAudioSessionControl2>,
mute_executables: HashSet<String>,
mute_flag: bool,
session_managers: Vec<IAudioSessionManager2>,
}
fn load_mute_txt() -> HashSet<String> {
@@ -187,15 +40,8 @@ fn load_mute_txt() -> HashSet<String> {
impl SessionMuter {
fn new() -> SessionMuter {
let s = Arc::new(Mutex::new(Vec::new()));
SessionMuter {
session_managers: Vec::new(),
sessions: s.clone(),
device_enumerator: unsafe {
CoCreateInstance(&MMDeviceEnumerator, None, CLSCTX_ALL).unwrap()
},
device_notification_client: IMMNotificationClient::from(DeviceNotificationClient {}),
session_notification: IAudioSessionNotification::from(SessionNotification {}),
sessions: Vec::new(),
mute_executables: load_mute_txt(),
mute_flag: true,
}
@@ -212,28 +58,26 @@ impl SessionMuter {
}
match self.session_to_filename(&session) {
Ok(file_name) => {
let fn_str = file_name
.to_string_lossy()
.to_string();
let fn_str = file_name.to_string_lossy().to_string();
if self.mute_executables.contains(&fn_str) {
println!("Adding session from: {:?}", fn_str);
unsafe {
let volume: ISimpleAudioVolume = session.cast()?;
volume.SetMute(self.mute_flag, null_mut())?;
}
self.sessions.lock().unwrap().push(session);
self.sessions.push(session);
}
},
}
Err(err) => {
println!("Unable to get filename for session {:?}", err);
},
}
}
Ok(())
}
unsafe fn set_mute_all(self: &mut SessionMuter, mute: bool) {
let sessions = self.sessions.lock().unwrap();
let results = sessions
let results = self
.sessions
.iter()
.map(|session_control2| session_control2.cast::<ISimpleAudioVolume>())
.map(|vol_result| vol_result.map(|volume| volume.SetMute(mute, null_mut())));
@@ -242,45 +86,21 @@ impl SessionMuter {
}
}
unsafe fn notify_window_changed(self: &mut SessionMuter, path: &str) {
let binding = Path::new(path)
.file_name()
.expect("failed to extract filename from path");
let file_name = binding.to_os_string().to_string_lossy().to_string();
let mute_flag = !self.mute_executables.contains(&file_name);
if mute_flag != self.mute_flag {
self.mute_flag = mute_flag;
self.set_mute_all(self.mute_flag);
println!(
"Mute set to {} due to foreground window: {}",
self.mute_flag, file_name
);
}
}
fn boot_devices(self: &mut SessionMuter) -> Result<(), Box<dyn Error>> {
fn notify_window_changed(self: &mut SessionMuter, path: &str) {
unsafe {
self.device_enumerator
.RegisterEndpointNotificationCallback(&self.device_notification_client)?;
let device_collection = self
.device_enumerator
.EnumAudioEndpoints(eRender, DEVICE_STATE_ACTIVE)?;
for device in (0..device_collection.GetCount()?).map(|x| device_collection.Item(x)) {
let mmdevice = device?;
self.add_device(mmdevice)?;
let binding = Path::new(path)
.file_name()
.expect("failed to extract filename from path");
let file_name = binding.to_os_string().to_string_lossy().to_string();
let mute_flag = !self.mute_executables.contains(&file_name);
if mute_flag != self.mute_flag {
self.mute_flag = mute_flag;
self.set_mute_all(self.mute_flag);
println!(
"Mute set to {} due to foreground window: {}",
self.mute_flag, file_name
);
}
println!("All devices initialized.");
return Ok(());
}
}
fn add_device_by_id(self: &mut SessionMuter, id: &PCWSTR) -> Result<(), Box<dyn Error>> {
unsafe {
let device_enumerator: IMMDeviceEnumerator =
CoCreateInstance(&MMDeviceEnumerator, None, CLSCTX_ALL)?;
let device = device_enumerator.GetDevice(id)?;
self.add_device(device)?;
Ok(())
}
}
@@ -297,45 +117,35 @@ impl SessionMuter {
return Ok(file_name.to_os_string());
}
}
unsafe fn add_device(
self: &mut SessionMuter,
device: windows::Win32::Media::Audio::IMMDevice,
) -> Result<(), Box<dyn Error>> {
let sm: IAudioSessionManager2 = device.Activate(CLSCTX_ALL, None)?;
sm.RegisterSessionNotification(&self.session_notification)?;
let session_enumerator = sm.GetSessionEnumerator()?;
let session_count = session_enumerator.GetCount()?;
let device_sessions = (0..session_count)
.map(|idx| {
session_enumerator
.GetSession(idx)
.and_then(|session| session.cast())
})
.collect::<Result<Vec<IAudioSessionControl2>, _>>()?;
for session in device_sessions {
self.add_session(session)?;
}
let prop_store = device.OpenPropertyStore(STGM_READ)?;
// this definition isn't in windows-rs yet apparently :(
let prop_var = prop_store.GetValue(&PKEY_Device_FriendlyName)?;
let str = U16CStr::from_ptr_str(prop_var.Anonymous.Anonymous.Anonymous.pwszVal.0).to_string_lossy();
println!("Device Added: {} Existing Sessions: {}", str, session_count);
self.session_managers.push(sm);
Ok(())
}
}
unsafe impl Send for SessionMuter {}
unsafe impl Sync for SessionMuter {}
static MUTER: Lazy<Arc<Mutex<SessionMuter>>> =
Lazy::new(|| Arc::new(Mutex::new(SessionMuter::new())));
fn main() {
unsafe {
CoInitializeEx(None, COINIT_MULTITHREADED).unwrap();
}
MUTER.lock().unwrap().boot_devices().expect("failed to initialize devices");
win_event_hook_loop();
let muter = Arc::new(Mutex::new(SessionMuter::new()));
let muter2 = muter.clone();
let sn = SMSessionNotifier::new(Box::new(move |session| {
muter2.lock().unwrap().add_session(session).unwrap()
}));
sn.lock()
.unwrap()
.as_mut()
.unwrap()
.boot_devices()
.expect("failed to get initial audio devices and sessions");
let (tx, rx) : (Sender<String>, Receiver<String>) = mpsc::channel();
*WIN_CHANGE_CHANNEL_TX.lock().unwrap().borrow_mut() = Some(tx);
thread::spawn(move || {
win_event_hook_loop();
});
loop {
let path = rx.recv().unwrap();
muter.lock().unwrap().notify_window_changed(&path);
}
}

View File

@@ -0,0 +1,31 @@
use windows::{
core::{Interface},
Win32::{
Media::Audio::{
IAudioSessionControl, IAudioSessionControl2,
IAudioSessionNotification, IAudioSessionNotification_Impl,
},
},
};
#[windows::core::implement(IAudioSessionNotification)]
pub(crate) struct SessionNotification {
pub(crate) callback: Box<dyn Fn(IAudioSessionControl2)>
}
impl IAudioSessionNotification_Impl for SessionNotification {
fn OnSessionCreated(
self: &SessionNotification,
newsession: &core::option::Option<IAudioSessionControl>,
) -> windows::core::Result<()> {
let ses: IAudioSessionControl2 = newsession.as_ref().unwrap().cast().unwrap();
(self.callback)(ses);
Ok(())
}
}
impl Drop for SessionNotification {
fn drop(&mut self) {
println!("SN drop");
}
}

116
src/sm_session_notifier.rs Normal file
View File

@@ -0,0 +1,116 @@
use std::{
error::Error,
sync::{Arc, Mutex},
};
use crate::{
device_notification_client::DeviceNotificationClient, session_notification::SessionNotification,
};
use widestring::U16CStr;
use windows::{
core::{Interface, PCWSTR},
Win32::{
Devices::FunctionDiscovery::PKEY_Device_FriendlyName,
Media::Audio::{
eRender, IAudioSessionControl2, IAudioSessionManager2, IAudioSessionNotification,
IMMDeviceEnumerator, IMMNotificationClient, MMDeviceEnumerator, DEVICE_STATE_ACTIVE,
},
System::Com::{CoCreateInstance, CLSCTX_ALL, STGM_READ},
},
};
pub(crate) struct SMSessionNotifier {
device_enumerator: IMMDeviceEnumerator,
device_notification_client: IMMNotificationClient,
session_notification: IAudioSessionNotification,
session_managers: Vec<IAudioSessionManager2>,
notification_function: Box<dyn Fn(IAudioSessionControl2)>,
}
impl SMSessionNotifier {
pub(crate) fn new(
callback: Box<dyn Fn(IAudioSessionControl2)>,
) -> Arc<Mutex<Option<SMSessionNotifier>>> {
let session_notifier_arc: Arc<Mutex<Option<SMSessionNotifier>>> =
Arc::new(Mutex::new(None));
let sn2 = session_notifier_arc.clone();
let sn3 = session_notifier_arc.clone();
let notifier = SMSessionNotifier {
session_managers: Vec::new(),
device_enumerator: unsafe {
CoCreateInstance(&MMDeviceEnumerator, None, CLSCTX_ALL).unwrap()
},
device_notification_client: IMMNotificationClient::from(DeviceNotificationClient {
callback: Box::new(move |id| {
match sn2.lock().unwrap().as_mut().unwrap().add_device_by_id(id) {
Ok(_) => {},
Err(e) => {
println!("Failed to add new device: {:?}", e);
},
}
}),
}),
session_notification: IAudioSessionNotification::from(SessionNotification {
callback: Box::new(move |session| {
(sn3.lock().unwrap().as_mut().unwrap().notification_function)(session)
}),
}),
notification_function: callback,
};
*session_notifier_arc.lock().unwrap() = Some(notifier);
return session_notifier_arc.clone();
}
pub(crate) fn boot_devices(self: &mut SMSessionNotifier) -> Result<(), Box<dyn Error>> {
unsafe {
self.device_enumerator
.RegisterEndpointNotificationCallback(&self.device_notification_client)?;
let device_collection = self
.device_enumerator
.EnumAudioEndpoints(eRender, DEVICE_STATE_ACTIVE)?;
for device in (0..device_collection.GetCount()?).map(|x| device_collection.Item(x)) {
let mmdevice = device?;
self.add_device(mmdevice)?;
}
println!("All devices initialized.");
return Ok(());
}
}
fn add_device_by_id(self: &mut SMSessionNotifier, id: &PCWSTR) -> Result<(), Box<dyn Error>> {
unsafe {
let device_enumerator: IMMDeviceEnumerator =
CoCreateInstance(&MMDeviceEnumerator, None, CLSCTX_ALL)?;
let device = device_enumerator.GetDevice(id)?;
self.add_device(device)?;
Ok(())
}
}
unsafe fn add_device(
self: &mut SMSessionNotifier,
device: windows::Win32::Media::Audio::IMMDevice,
) -> Result<(), Box<dyn Error>> {
let sm: IAudioSessionManager2 = device.Activate(CLSCTX_ALL, None)?;
sm.RegisterSessionNotification(&self.session_notification)?;
let session_enumerator = sm.GetSessionEnumerator()?;
let session_count = session_enumerator.GetCount()?;
let device_sessions = (0..session_count)
.map(|idx| {
session_enumerator
.GetSession(idx)
.and_then(|session| session.cast())
})
.collect::<Result<Vec<IAudioSessionControl2>, _>>()?;
for session in device_sessions {
(self.notification_function)(session);
}
let prop_store = device.OpenPropertyStore(STGM_READ)?;
// this definition isn't in windows-rs yet apparently :(
let prop_var = prop_store.GetValue(&PKEY_Device_FriendlyName)?;
let str = U16CStr::from_ptr_str(prop_var.Anonymous.Anonymous.Anonymous.pwszVal.0)
.to_string_lossy();
println!("Device Added: {} Existing Sessions: {}", str, session_count);
self.session_managers.push(sm);
Ok(())
}
}

86
src/window_change.rs Normal file
View File

@@ -0,0 +1,86 @@
use std::{
cell::RefCell,
sync::{mpsc::Sender, Arc, Mutex},
};
use once_cell::sync::Lazy;
use windows::Win32::{
Foundation::{HINSTANCE, HWND},
UI::{
Accessibility::{SetWinEventHook, HWINEVENTHOOK},
WindowsAndMessaging::{
DispatchMessageW, GetForegroundWindow, GetMessageW, GetWindowThreadProcessId,
TranslateMessage, EVENT_SYSTEM_FOREGROUND, EVENT_SYSTEM_MINIMIZEEND, MSG,
WINEVENT_OUTOFCONTEXT, WINEVENT_SKIPOWNPROCESS,
},
},
};
use crate::pid_to_exe::pid_to_exe_path;
pub static WIN_CHANGE_CHANNEL_TX: Lazy<Arc<Mutex<RefCell<Option<Sender<String>>>>>> =
Lazy::new(|| Arc::new(Mutex::new(RefCell::new(None))));
unsafe extern "system" fn win_event_proc(
_hook: HWINEVENTHOOK,
event: u32,
_hwnd: HWND,
_id_object: i32,
_id_child: i32,
_dw_event_thread: u32,
_dwms_event_time: u32,
) {
if event == EVENT_SYSTEM_FOREGROUND || event == EVENT_SYSTEM_MINIMIZEEND {
let mut pid: u32 = 0;
// Instead of using the hwnd passed to us in the wineventproc, call GetForegroundWindow again because Hollow Knight does something weird
// and immediately fires again with the hwnd for explorer?
let hwnd = GetForegroundWindow();
GetWindowThreadProcessId(hwnd, Some(&mut pid));
pid_to_exe_path(pid)
.and_then(|path| {
WIN_CHANGE_CHANNEL_TX
.lock()
.unwrap()
.borrow()
.as_ref()
.unwrap()
.send(path)
.map_err(|err| err.into())
})
.unwrap_or_else(|err| {
println!(
"Error finding process with pid {} for hwnd: {:?}: {:?}",
pid, hwnd, err
)
});
}
}
pub fn win_event_hook_loop() {
unsafe {
SetWinEventHook(
EVENT_SYSTEM_FOREGROUND,
EVENT_SYSTEM_FOREGROUND,
HINSTANCE::default(),
Some(win_event_proc),
0,
0,
WINEVENT_OUTOFCONTEXT | WINEVENT_SKIPOWNPROCESS,
);
SetWinEventHook(
EVENT_SYSTEM_MINIMIZEEND,
EVENT_SYSTEM_MINIMIZEEND,
HINSTANCE::default(),
Some(win_event_proc),
0,
0,
WINEVENT_OUTOFCONTEXT | WINEVENT_SKIPOWNPROCESS,
);
let mut msg: MSG = MSG::default();
while GetMessageW(&mut msg, HWND::default(), 0, 0).as_bool() {
TranslateMessage(&msg);
DispatchMessageW(&msg);
}
}
}

1
todo.txt Normal file
View File

@@ -0,0 +1 @@
- create session-focused API - should provide a callback for all existing sessions and any time a new sessions is created