diff --git a/README.md b/README.md index 8a32665..9c11a0a 100644 --- a/README.md +++ b/README.md @@ -35,7 +35,6 @@ use winrt::*; // import various helper types use winrt::windows::system::diagnostics::*; // import namespace Windows.System.Diagnostics fn main() { - let rt = RuntimeContext::init(); // initialize the Windows Runtime let infos = ProcessDiagnosticInfo::get_for_processes().unwrap().unwrap(); println!("Currently executed processes ({}):", infos.get_size().unwrap()); for p in &infos { diff --git a/examples/hexdump.rs b/examples/hexdump.rs index dc74251..7ec3058 100644 --- a/examples/hexdump.rs +++ b/examples/hexdump.rs @@ -10,16 +10,10 @@ use winrt::*; use winrt::windows::foundation::*; use winrt::windows::storage::*; -fn main() { - let rt = RuntimeContext::init(); - run(); - rt.uninit(); -} - const BYTES_PER_ROW: usize = 24; const CHUNK_SIZE: usize = 4096; -fn run() { +fn main() { // Use the current executable as source file (because we know that will exist). let exe_path = ::std::env::current_exe().expect("current_exe failed"); let exe_path_str = exe_path.to_str().expect("invalid unicode path"); diff --git a/examples/test.rs b/examples/test.rs index 58f2c02..93af6b8 100644 --- a/examples/test.rs +++ b/examples/test.rs @@ -11,12 +11,6 @@ use winrt::windows::devices::midi::*; use winrt::windows::storage::*; fn main() { - let rt = RuntimeContext::init(); - run(); - rt.uninit(); -} - -fn run() { let base = FastHString::new("https://github.com"); let relative = FastHString::new("contextfree/winrt-rust"); let uri = Uri::create_with_relative_uri(&base, &relative).unwrap(); @@ -44,7 +38,7 @@ fn run() { let res = DeviceInformation::find_all_async_aqs_filter(&wrong_deviceselector); if let Err(e) = res { println!("HRESULT (FindAllAsyncAqsFilter) = {:?}", e); - let mut error_info = { + let error_info = { let mut res = ptr::null_mut(); assert_eq!(GetRestrictedErrorInfo(&mut res), S_OK); ComPtr::wrap(res) diff --git a/examples/toast_notify.rs b/examples/toast_notify.rs index fda28d9..43f74c4 100644 --- a/examples/toast_notify.rs +++ b/examples/toast_notify.rs @@ -10,12 +10,6 @@ use winrt::windows::data::xml::dom::*; use winrt::windows::ui::notifications::*; fn main() { - let rt = RuntimeContext::init(); - run(); - rt.uninit(); -} - -fn run() { // Get a toast XML template let toast_xml = ToastNotificationManager::get_template_content(ToastTemplateType::ToastText02).unwrap().unwrap(); diff --git a/src/comptr.rs b/src/comptr.rs index 7b3e259..5efe449 100644 --- a/src/comptr.rs +++ b/src/comptr.rs @@ -95,7 +95,6 @@ impl ComPtr { /// use winrt::*; /// use winrt::windows::foundation::Uri; /// - /// # let rt = winrt::RuntimeContext::init(); /// let uri = FastHString::new("https://www.rust-lang.org"); /// let uri = Uri::create_uri(&uri).unwrap(); /// assert_eq!("Windows.Foundation.Uri", uri.get_runtime_class_name().to_string()); diff --git a/src/lib.rs b/src/lib.rs index 85adc93..f7d96a4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,7 +10,6 @@ //! use winrt::windows::system::diagnostics::*; // import namespace Windows.System.Diagnostics //! //! fn main() { -//! let rt = RuntimeContext::init(); // initialize the Windows Runtime //! let infos = ProcessDiagnosticInfo::get_for_processes().unwrap().unwrap(); //! println!("Currently executed processes ({}):", infos.get_size().unwrap()); //! for p in &infos { @@ -66,7 +65,8 @@ pub use cominterfaces::{ComInterface, ComIid, IUnknown, IRestrictedErrorInfo, IA mod rt; pub use rt::{RtInterface, RtClassInterface, RtNamedClass, RtValueType, RtType, RtActivatable, RtDefaultConstructible, IInspectable, IInspectableVtbl, IActivationFactory, - IMemoryBufferByteAccess, Char, RuntimeContext, IteratorAdaptor}; + IMemoryBufferByteAccess, Char, IteratorAdaptor, + ApartmentType, init_apartment, uninit_apartment}; pub use rt::async::{RtAsyncAction, RtAsyncOperation}; mod result; diff --git a/src/rt/mod.rs b/src/rt/mod.rs index ed391a6..8a0433c 100644 --- a/src/rt/mod.rs +++ b/src/rt/mod.rs @@ -1,4 +1,3 @@ -use std::marker::PhantomData; use std::ptr; use super::{ComInterface, HString, HStringReference, HStringArg, ComPtr, ComArray, ComIid, Guid}; @@ -9,7 +8,8 @@ use w::shared::winerror::{S_OK, S_FALSE, CO_E_NOTINITIALIZED, REGDB_E_CLASSNOTRE use w::shared::guiddef::IID; use w::um::unknwnbase::IUnknownVtbl; use w::winrt::hstring::HSTRING; -use w::winrt::roapi::{RO_INIT_MULTITHREADED, RoInitialize, RoUninitialize, RoGetActivationFactory}; +use w::winrt::roapi::{RO_INIT_MULTITHREADED, RO_INIT_SINGLETHREADED, RoInitialize, RoUninitialize, RoGetActivationFactory}; +use w::um::combaseapi::CoIncrementMTAUsage; use self::gen::windows::foundation::collections::{ IIterable, @@ -116,11 +116,14 @@ pub trait RtActivatable : RtNamedClass { fn get_activation_factory() -> ComPtr where Interface: RtInterface + ComIid { let mut res = ptr::null_mut(); let class_id = unsafe { HStringReference::from_utf16_unchecked(Self::name()) }; - let hr = unsafe { RoGetActivationFactory(class_id.get(), ::iid().as_ref(), &mut res as *mut *mut _ as *mut *mut VOID) }; + let mut hr = unsafe { RoGetActivationFactory(class_id.get(), ::iid().as_ref(), &mut res as *mut *mut _ as *mut *mut VOID) }; + if hr == CO_E_NOTINITIALIZED { + let mut cookie = ptr::null_mut(); + unsafe { CoIncrementMTAUsage(&mut cookie); } + hr = unsafe { RoGetActivationFactory(class_id.get(), ::iid().as_ref(), &mut res as *mut *mut _ as *mut *mut VOID) }; + } if hr == S_OK { unsafe { ComPtr::wrap(res) } - } else if hr == CO_E_NOTINITIALIZED { - panic!("WinRT is not initialized") } else if hr == REGDB_E_CLASSNOTREG { let name = Self::name(); panic!("WinRT class \"{}\" not registered", String::from_utf16_lossy(&name[0..name.len()-1])) @@ -779,35 +782,31 @@ impl IMemoryBufferByteAccess { } } - -/// Manages initialization and uninitialization of the Windows Runtime. -pub struct RuntimeContext { - token: PhantomData<*mut ()> // only allow construction from inside this module, and make it !Send/!Sync. +/// Determines the concurrency model used for incoming calls to the objects created by a thread +/// that was initialized with a given apartment type (see also `init_apartment`). +#[derive(Copy, Clone, PartialEq, Eq, Debug)] +#[repr(u32)] +pub enum ApartmentType { + /// Initializes the thread in the multi-threaded apartment (MTA). + MTA = RO_INIT_MULTITHREADED, + + /// Initializes the thread as a single-threaded apartment (STA). + STA = RO_INIT_SINGLETHREADED } -impl RuntimeContext { - /// Initializes the Windows Runtime. This must be called before any other operations can use - /// the Windows Runtime. The Windows Runtime will be unitilized when the returned `RuntimeContext` - /// is dropped or `uninit` is called explicitly. You have to make sure that this does not happen - /// as long as any Windows Runtime object is still alive. - #[inline] - pub fn init() -> RuntimeContext { - let hr = unsafe { RoInitialize(RO_INIT_MULTITHREADED) }; - assert!(hr == S_OK || hr == S_FALSE, "failed to call RoInitialize: error {}", hr); - RuntimeContext { token: PhantomData } - } - - /// Unitializes the Windows Runtime. This must not be called as long as any Windows Runtime - /// object is still alive. - #[inline] - pub fn uninit(self) { - drop(self); - } +/// Initializes the current thread for use with the Windows Runtime. This is usually not needed, +/// because winrt-rust ensures that threads are implicitly assigned to the multi-threaded apartment (MTA). +/// However, if you need your thread to be initialized as a single-threaded apartment (STA), you can +/// call `init_apartment(ApartmentType::STA)`. Only call this when you own the thread! +pub fn init_apartment(apartment_type: ApartmentType) { + let hr = unsafe { RoInitialize(apartment_type as u32) }; + assert!(hr == S_OK || hr == S_FALSE, "failed to call RoInitialize: error {}", hr); } -impl Drop for RuntimeContext { - #[inline] - fn drop(&mut self) { - unsafe { RoUninitialize() }; - } +/// Uninitializes the Windows Runtime in the current thread. This is usually not +/// needed, because uninitialization happens automatically on process termination. +/// Make sure that you never call this from a thread that still has references to +/// Windows Runtime objects. +pub fn uninit_apartment() { + unsafe { RoUninitialize() }; }