Update to current winapi, split project

This commit is contained in:
Your Name
2024-08-25 22:24:06 -04:00
parent a647fc5186
commit 82153d31c9
18 changed files with 593 additions and 145 deletions

View File

@@ -0,0 +1,28 @@
[package]
name = "auto_mute_lib"
version = "0.2.0"
edition = "2021"
[dependencies]
windows-core = "0.58.0"
[dependencies.windows]
version = "0.58.0"
features = [
"implement",
"Win32_Media_Audio",
"Win32_UI_Shell_PropertiesSystem",
"Data_Xml_Dom",
"Win32_UI",
"Win32_UI_Accessibility",
"Win32_Foundation",
"Win32_Security",
"Win32_System_Threading",
"Win32_UI_WindowsAndMessaging",
"Win32_System_Com",
"Win32_System_Com_StructuredStorage",
"Win32_Devices_FunctionDiscovery",
"Win32_Storage_FileSystem",
"Win32_System_WindowsProgramming",
"Win32_System_Variant"
]

View File

@@ -0,0 +1,45 @@
use windows::Win32::Media::Audio::{IMMNotificationClient, IMMNotificationClient_Impl, DEVICE_STATE};
pub trait DeviceNotificationObserver {
fn add_device(&self, device_id: &windows::core::PCWSTR);
}
#[windows::core::implement(IMMNotificationClient)]
pub(crate) struct DeviceNotificationClient {
pub observer: Box<dyn DeviceNotificationObserver>,
}
impl IMMNotificationClient_Impl for DeviceNotificationClient_Impl {
fn OnDeviceStateChanged(
&self,
_pwstrdeviceid: &windows::core::PCWSTR,
_dwnewstate: DEVICE_STATE,
) -> windows::core::Result<()> {
Ok(())
}
fn OnDeviceAdded(&self, pwstrdeviceid: &windows::core::PCWSTR) -> windows::core::Result<()> {
self.observer.add_device(pwstrdeviceid);
Ok(())
}
fn OnDeviceRemoved(&self, _pwstrdeviceid: &windows::core::PCWSTR) -> windows::core::Result<()> {
Ok(())
}
fn OnDefaultDeviceChanged(
&self,
_flow: windows::Win32::Media::Audio::EDataFlow,
_role: windows::Win32::Media::Audio::ERole,
_pwstrdefaultdeviceid: &windows::core::PCWSTR,
) -> windows::core::Result<()> {
Ok(())
}
fn OnPropertyValueChanged(
&self,
_pwstrdeviceid: &windows::core::PCWSTR,
_key: &windows::Win32::UI::Shell::PropertiesSystem::PROPERTYKEY,
) -> windows::core::Result<()> {
Ok(())
}
}

View File

@@ -0,0 +1,6 @@
mod device_notification_client;
mod pid_to_exe;
mod session_notification;
mod sm_session_notifier;
mod window_change;
pub mod muter;

View File

@@ -0,0 +1,163 @@
use std::{
collections::HashSet,
error::Error,
ffi::OsString,
path::Path,
ptr::null_mut,
sync::mpsc::{self, Receiver, Sender},
thread::{self, JoinHandle},
};
use crate::sm_session_notifier::SMSessionNotifierThread;
use crate::window_change::WindowChangeMonitor;
use windows::{
core::Interface,
Win32::Media::Audio::{IAudioSessionControl2, ISimpleAudioVolume},
};
use crate::pid_to_exe::pid_to_exe_path;
enum MuterMessage {
WindowChange(String),
AddSession(IAudioSessionControl2),
Exit(),
}
unsafe impl Send for MuterMessage {}
struct SessionMuter {
sessions: Vec<IAudioSessionControl2>,
mute_executables: HashSet<String>,
mute_flag: bool,
_session_notifier: SMSessionNotifierThread,
_win_change_mon: WindowChangeMonitor,
rx: Receiver<MuterMessage>,
}
impl SessionMuter {
fn new(
mute_executables: HashSet<String>,
rx: Receiver<MuterMessage>,
tx: Sender<MuterMessage>,
) -> SessionMuter {
SessionMuter {
sessions: Vec::new(),
mute_executables,
mute_flag: true,
_session_notifier: {
let tx = tx.clone();
SMSessionNotifierThread::new(Box::new(move |session| {
tx.send(MuterMessage::AddSession(session)).unwrap();
}))
},
_win_change_mon: {
WindowChangeMonitor::start(Box::new(move |s| {
tx.send(MuterMessage::WindowChange(s.to_owned())).unwrap();
}))
},
rx,
}
}
fn run(&mut self) {
loop {
let msg = self.rx.recv().unwrap();
match msg {
MuterMessage::WindowChange(win) => self.notify_window_changed(&win),
MuterMessage::AddSession(session) => self.add_session(session).unwrap(),
MuterMessage::Exit() => break,
}
}
}
fn add_session(
self: &mut SessionMuter,
session: IAudioSessionControl2,
) -> Result<(), Box<dyn Error>> {
if let Ok(file_name) = self.session_to_filename(&session) {
let fn_str = file_name.to_string_lossy().to_string();
if self.mute_executables.contains(&fn_str) {
println!("Adding session from: {:?}", fn_str);
unsafe {
let volume: ISimpleAudioVolume = session.cast()?;
volume.SetMute(self.mute_flag, null_mut())?;
}
self.sessions.push(session);
}
}
Ok(())
}
fn set_mute_all(self: &mut SessionMuter, mute: bool) {
unsafe {
let results = self
.sessions
.iter()
.map(|session_control2| session_control2.cast::<ISimpleAudioVolume>())
.map(|vol_result| vol_result?.SetMute(mute, null_mut()));
for err in results.filter_map(|x| x.err()) {
println!("Error muting a session: {:?}", err);
}
}
}
fn notify_window_changed(self: &mut SessionMuter, path: &str) {
let binding = Path::new(path)
.file_name()
.expect("failed to extract filename from path");
let file_name = binding.to_os_string().to_string_lossy().to_string();
let mute_flag = !self.mute_executables.contains(&file_name);
if mute_flag != self.mute_flag {
self.mute_flag = mute_flag;
self.set_mute_all(self.mute_flag);
println!(
"Mute set to {} due to foreground window: {}",
self.mute_flag, file_name
);
}
}
fn session_to_filename(
self: &mut SessionMuter,
session: &IAudioSessionControl2,
) -> Result<OsString, Box<dyn Error>> {
unsafe {
let pid = session.GetProcessId()?;
let path = pid_to_exe_path(pid)?;
let file_name = Path::new(&path)
.file_name()
.ok_or("Failed to extract filename from path")?;
Ok(file_name.to_os_string())
}
}
}
unsafe impl Send for SessionMuter {}
pub struct MuterThread {
handle: Option<JoinHandle<()>>,
sender: Sender<MuterMessage>,
}
impl MuterThread {
pub fn new(s: HashSet<String>) -> MuterThread {
let (sender, receiver) = mpsc::channel::<MuterMessage>();
MuterThread {
sender: sender.clone(),
handle: Some(thread::spawn(move || {
let mut muter = SessionMuter::new(s, receiver, sender);
muter.run();
})),
}
}
}
impl Drop for MuterThread {
fn drop(&mut self) {
if let Some(handle) = self.handle.take() {
self.sender.send(MuterMessage::Exit()).unwrap();
handle.join().unwrap()
}
}
}

View File

@@ -0,0 +1,29 @@
use std::error::Error;
use windows::Win32::{
Foundation::{CloseHandle, MAX_PATH},
System::Threading::{
OpenProcess, QueryFullProcessImageNameW, PROCESS_QUERY_INFORMATION, PROCESS_VM_READ,
},
};
pub fn pid_to_exe_path(pid: u32) -> Result<String, Box<dyn Error>> {
let mut exe_name: Vec<u16> = Vec::with_capacity(MAX_PATH as usize);
let mut size: u32 = exe_name.capacity().try_into().unwrap();
unsafe {
let process = OpenProcess(PROCESS_QUERY_INFORMATION | PROCESS_VM_READ, false, pid)?;
QueryFullProcessImageNameW(
process,
Default::default(),
windows::core::PWSTR(exe_name.as_mut_ptr()),
&mut size,
)
.ok();
CloseHandle(process)?;
exe_name.set_len(size.try_into().unwrap());
}
let process_name = String::from_utf16_lossy(&exe_name);
Ok(process_name)
}

View File

@@ -0,0 +1,23 @@
use windows::{
core::{implement, Interface},
Win32::Media::Audio::{
IAudioSessionControl, IAudioSessionControl2, IAudioSessionNotification, IAudioSessionNotification_Impl,
},
};
#[implement(IAudioSessionNotification)]
pub(crate) struct SessionNotification {
pub(crate) observer: Box<dyn SessionObserver>,
}
pub trait SessionObserver {
fn add_session(&self, session: IAudioSessionControl2);
}
impl IAudioSessionNotification_Impl for SessionNotification_Impl {
fn OnSessionCreated(&self,newsession:Option<&IAudioSessionControl>) -> windows::core::Result<()> {
let ses: IAudioSessionControl2 = newsession.as_ref().unwrap().cast().unwrap();
self.observer.add_session(ses);
Ok(())
}
}

View File

@@ -0,0 +1,205 @@
use std::{
error::Error,
sync::mpsc::{self, Receiver, Sender},
thread::{self, JoinHandle},
};
use crate::{
device_notification_client::{DeviceNotificationClient, DeviceNotificationObserver},
session_notification::{SessionNotification, SessionObserver},
};
use windows::{
core::{Interface, PCWSTR},
Win32::{
Devices::FunctionDiscovery::PKEY_Device_FriendlyName,
Media::Audio::{
eRender, IAudioSessionControl2, IAudioSessionManager2, IAudioSessionNotification,
IMMDeviceEnumerator, IMMNotificationClient, MMDeviceEnumerator, DEVICE_STATE_ACTIVE, IMMDevice,
},
System::Com::{
CoCreateInstance, StructuredStorage::PropVariantClear, CLSCTX_ALL, STGM_READ,
},
},
};
pub enum SMMessage {
Session(IAudioSessionControl2),
Device(PCWSTR),
Exit(),
}
unsafe impl Send for SMMessage {}
struct SessionToMessage {
sender: Sender<SMMessage>,
}
impl SessionObserver for SessionToMessage {
fn add_session(&self, session: IAudioSessionControl2) {
self.sender
.send(SMMessage::Session(session))
.unwrap_or_else(|_| println!("Failed to add new session"));
}
}
struct DeviceToMessage {
sender: Sender<SMMessage>,
}
impl DeviceNotificationObserver for DeviceToMessage {
fn add_device(&self, device_id: &windows::core::PCWSTR) {
self.sender
.send(SMMessage::Device(*device_id))
.unwrap_or_else(|_| println!("Failed to add new device"));
}
}
pub(crate) struct SMSessionNotifier {
device_enumerator: IMMDeviceEnumerator,
device_notification_client: IMMNotificationClient,
session_notification: IAudioSessionNotification,
session_managers: Vec<IAudioSessionManager2>,
notification_function: Box<dyn Fn(IAudioSessionControl2)>,
receiver: Receiver<SMMessage>,
}
impl SMSessionNotifier {
pub(crate) fn new(
callback: Box<dyn Fn(IAudioSessionControl2)>,
sender: mpsc::Sender<SMMessage>,
receiver: mpsc::Receiver<SMMessage>,
) -> SMSessionNotifier {
SMSessionNotifier {
session_managers: Vec::new(),
device_enumerator: unsafe {
CoCreateInstance(&MMDeviceEnumerator, None, CLSCTX_ALL).unwrap()
},
device_notification_client: IMMNotificationClient::from(DeviceNotificationClient {
observer: Box::new(DeviceToMessage {
sender: sender.clone(),
}),
}),
session_notification: IAudioSessionNotification::from(SessionNotification {
observer: Box::new(SessionToMessage { sender }),
}),
notification_function: callback,
receiver,
}
}
pub(crate) fn boot_devices(self: &mut SMSessionNotifier) -> Result<(), Box<dyn Error>> {
unsafe {
self.device_enumerator
.RegisterEndpointNotificationCallback(&self.device_notification_client)?;
let device_collection = self
.device_enumerator
.EnumAudioEndpoints(eRender, DEVICE_STATE_ACTIVE)?;
for device in (0..device_collection.GetCount()?).map(|x| device_collection.Item(x)) {
let mmdevice = device?;
self.add_device(mmdevice)?;
}
Ok(())
}
}
fn add_device_by_id(self: &mut SMSessionNotifier, id: &PCWSTR) -> Result<(), Box<dyn Error>> {
unsafe {
let device_enumerator: IMMDeviceEnumerator =
CoCreateInstance(&MMDeviceEnumerator, None, CLSCTX_ALL)?;
let device = device_enumerator.GetDevice(*id)?;
self.add_device(device)?;
Ok(())
}
}
unsafe fn add_device(
self: &mut SMSessionNotifier,
device: IMMDevice,
) -> Result<(), Box<dyn Error>> {
let session_manager: IAudioSessionManager2 = device.Activate(CLSCTX_ALL, None)?;
session_manager.RegisterSessionNotification(&self.session_notification)?;
let session_enumerator = session_manager.GetSessionEnumerator()?;
let session_count = session_enumerator.GetCount()?;
let device_sessions = (0..session_count)
.map(|idx| {
session_enumerator
.GetSession(idx)
.and_then(|session| session.cast())
})
.collect::<Result<Vec<IAudioSessionControl2>, _>>()?;
for session in device_sessions {
(self.notification_function)(session);
}
let prop_store = device.OpenPropertyStore(STGM_READ)?;
let mut prop_var = prop_store.GetValue(&PKEY_Device_FriendlyName)?;
println!(
"Device Added: {} Existing Sessions: {}",
prop_var.to_string(),
session_count
);
PropVariantClear(&mut prop_var)?;
self.session_managers.push(session_manager);
Ok(())
}
pub fn run(&mut self) -> Result<(), Box<dyn Error>> {
self.boot_devices()?;
loop {
let msg = self.receiver.recv()?;
match msg {
SMMessage::Session(session) => {
(self.notification_function)(session);
}
SMMessage::Device(id) => {
self.add_device_by_id(&id)?;
}
SMMessage::Exit() => break,
};
}
Ok(())
}
}
impl Drop for SMSessionNotifier {
fn drop(&mut self) {
unsafe {
self.session_managers
.drain(..)
.map(|x| x.UnregisterSessionNotification(&self.session_notification))
.collect::<Result<(), _>>()
.unwrap();
self.device_enumerator
.UnregisterEndpointNotificationCallback(&self.device_notification_client)
.unwrap();
}
}
}
pub struct SMSessionNotifierThread {
handle: Option<JoinHandle<()>>,
sender: Sender<SMMessage>,
}
impl SMSessionNotifierThread {
pub fn new(s: Box<dyn Fn(IAudioSessionControl2) + Send>) -> SMSessionNotifierThread {
let (sender, receiver) = mpsc::channel::<SMMessage>();
SMSessionNotifierThread {
sender: sender.clone(),
handle: Some(thread::spawn(move || {
let mut session_notifier = SMSessionNotifier::new(s, sender, receiver);
session_notifier.run().unwrap();
})),
}
}
}
impl Drop for SMSessionNotifierThread {
fn drop(&mut self) {
if let Some(handle) = self.handle.take() {
self.sender.send(SMMessage::Exit()).unwrap();
handle.join().unwrap()
}
}
}

View File

@@ -0,0 +1,127 @@
use std::{
sync::{atomic::AtomicU32, Arc, Mutex},
thread::{self, JoinHandle},
};
use windows::Win32::{
Foundation::{HINSTANCE, HWND, LPARAM, WPARAM},
System::Threading::GetCurrentThreadId,
UI::{
Accessibility::{SetWinEventHook, UnhookWinEvent, HWINEVENTHOOK},
WindowsAndMessaging::{
DispatchMessageW, GetForegroundWindow, GetMessageW, GetWindowThreadProcessId,
PostThreadMessageW, TranslateMessage, EVENT_SYSTEM_FOREGROUND,
EVENT_SYSTEM_MINIMIZEEND, MSG, WINEVENT_OUTOFCONTEXT, WINEVENT_SKIPOWNPROCESS, WM_QUIT,
},
},
};
use crate::pid_to_exe::pid_to_exe_path;
type WinCallback = Box<dyn Fn(&str) + Send>;
static WIN_CHANGE_CALLBACK: Mutex<Option<WinCallback>> = Mutex::new(None);
unsafe extern "system" fn win_event_proc(
_hook: HWINEVENTHOOK,
event: u32,
_hwnd: HWND,
_id_object: i32,
_id_child: i32,
_dw_event_thread: u32,
_dwms_event_time: u32,
) {
if event == EVENT_SYSTEM_FOREGROUND || event == EVENT_SYSTEM_MINIMIZEEND {
let mut pid: u32 = 0;
// Instead of using the hwnd passed to us in the wineventproc, call GetForegroundWindow again because Hollow Knight does something weird
// and immediately fires again with the hwnd for explorer?
let hwnd = GetForegroundWindow();
GetWindowThreadProcessId(hwnd, Some(&mut pid));
pid_to_exe_path(pid)
.map(|path| WIN_CHANGE_CALLBACK.lock().unwrap().as_ref().unwrap()(&path))
.unwrap_or_else(|err| {
println!(
"Error finding process with pid {} for hwnd: {:?}: {:?}",
pid, hwnd, err
)
});
}
}
pub fn await_win_change_events(callback: WinCallback) {
*WIN_CHANGE_CALLBACK.lock().unwrap() = Some(callback);
unsafe {
let fg_event = SetWinEventHook(
EVENT_SYSTEM_FOREGROUND,
EVENT_SYSTEM_FOREGROUND,
HINSTANCE::default(),
Some(win_event_proc),
0,
0,
WINEVENT_OUTOFCONTEXT | WINEVENT_SKIPOWNPROCESS,
);
let min_event = SetWinEventHook(
EVENT_SYSTEM_MINIMIZEEND,
EVENT_SYSTEM_MINIMIZEEND,
HINSTANCE::default(),
Some(win_event_proc),
0,
0,
WINEVENT_OUTOFCONTEXT | WINEVENT_SKIPOWNPROCESS,
);
let mut msg: MSG = MSG::default();
while GetMessageW(&mut msg, HWND::default(), 0, 0).0 > 0 {
TranslateMessage(&msg).unwrap();
DispatchMessageW(&msg);
}
UnhookWinEvent(fg_event).unwrap();
UnhookWinEvent(min_event).unwrap();
}
}
pub struct WindowChangeMonitor {
join_handle: Option<JoinHandle<()>>,
win_thread_id: Arc<AtomicU32>,
}
impl Drop for WindowChangeMonitor {
fn drop(&mut self) {
if let Some(join_handle) = self.join_handle.take() {
let tid: u32 = self
.win_thread_id
.load(std::sync::atomic::Ordering::Relaxed);
if tid != 0 {
unsafe {
PostThreadMessageW(tid, WM_QUIT, WPARAM(0), LPARAM(0))
.ok()
.unwrap();
}
}
join_handle
.join()
.expect("Unable to terminate window change thread");
}
}
}
impl WindowChangeMonitor {
pub(crate) fn start(f: WinCallback) -> WindowChangeMonitor {
let win_thread_id = Arc::new(AtomicU32::new(0));
let join_handle = {
let win_thread_id = win_thread_id.clone();
thread::spawn(move || {
win_thread_id.store(
unsafe { GetCurrentThreadId() },
std::sync::atomic::Ordering::Relaxed,
);
await_win_change_events(f);
})
};
WindowChangeMonitor {
join_handle: Some(join_handle),
win_thread_id,
}
}
}