diff --git a/src/coreclr/src/interop/comwrappers.cpp b/src/coreclr/src/interop/comwrappers.cpp index a7d4db7769802..f39472778cdb9 100644 --- a/src/coreclr/src/interop/comwrappers.cpp +++ b/src/coreclr/src/interop/comwrappers.cpp @@ -223,12 +223,10 @@ namespace namespace { const int32_t TrackerRefShift = 32; - const ULONGLONG TrackerRefCounter = ULONGLONG{ 1 } << TrackerRefShift; - const ULONGLONG ComRefCounter = ULONGLONG{ 1 }; - const ULONGLONG TrackerRefZero = 0x0000000080000000; + const ULONGLONG TrackerRefCounter = ULONGLONG{ 1 } << TrackerRefShift; + const ULONGLONG DestroySentinel = 0x0000000080000000; const ULONGLONG TrackerRefCountMask = 0xffffffff00000000; const ULONGLONG ComRefCountMask = 0x000000007fffffff; - const ULONGLONG RefCountMask = 0xffffffff7fffffff; constexpr ULONG GetTrackerCount(_In_ ULONGLONG c) { @@ -419,11 +417,29 @@ HRESULT ManagedObjectWrapper::Create( void ManagedObjectWrapper::Destroy(_In_ ManagedObjectWrapper* wrapper) { _ASSERTE(wrapper != nullptr); + _ASSERTE(GetComCount(wrapper->_refCount) == 0); - // Manually trigger the destructor since placement - // new was used to allocate the object. - wrapper->~ManagedObjectWrapper(); - InteropLibImports::MemFree(wrapper, AllocScenario::ManagedObjectWrapper); + // Attempt to set the destroyed bit. + LONGLONG refCount; + LONGLONG prev; + do + { + prev = wrapper->_refCount; + refCount = prev | DestroySentinel; + } while (::InterlockedCompareExchange64(&wrapper->_refCount, refCount, prev) != prev); + + // The destroy sentinel represents the bit that indicates the wrapper + // should be destroyed. Since the reference count field (64-bit) holds + // two counters we rely on the singular sentinal value - no other bits + // in the 64-bit counter are set. If there are outstanding bits set it + // indicates there are still outstanding references. + if (refCount == DestroySentinel) + { + // Manually trigger the destructor since placement + // new was used to allocate the object. + wrapper->~ManagedObjectWrapper(); + InteropLibImports::MemFree(wrapper, AllocScenario::ManagedObjectWrapper); + } } ManagedObjectWrapper::ManagedObjectWrapper( @@ -449,48 +465,9 @@ ManagedObjectWrapper::ManagedObjectWrapper( ManagedObjectWrapper::~ManagedObjectWrapper() { - // If the target isn't null, then a managed object - // is going to leak. - _ASSERTE(Target == nullptr); -} - -ULONGLONG ManagedObjectWrapper::UniversalRelease(_In_ ULONGLONG dec) -{ - OBJECTHANDLE local = Target; - - LONGLONG refCount; - if (dec == ComRefCounter) - { - _ASSERTE(dec == 1); - refCount = ::InterlockedDecrement64(&_refCount); - } - else - { - _ASSERTE(dec == TrackerRefCounter); - LONGLONG prev; - do - { - prev = _refCount; - refCount = prev - dec; - } while (::InterlockedCompareExchange64(&_refCount, refCount, prev) != prev); - } - - // It is possible that a target wasn't set during an - // attempt to reactive the wrapper. - if ((RefCountMask & refCount) == 0 && local != nullptr) - { - _ASSERTE(!IsSet(CreateComInterfaceFlagsEx::IsPegged)); - _ASSERTE(refCount == TrackerRefZero || refCount == 0); - - // Attempt to reset the target if its current value is the same. - // It is possible the wrapper is in the middle of being reactivated. - (void)TrySetObjectHandle(nullptr, local); - - // Tell the runtime to delete the managed object instance handle. - InteropLibImports::DeleteObjectInstanceHandle(local); - } - - return refCount; + // If the target isn't null, then release it. + if (Target != nullptr) + InteropLibImports::DeleteObjectInstanceHandle(Target); } void* ManagedObjectWrapper::AsRuntimeDefined(_In_ REFIID riid) @@ -551,16 +528,18 @@ void ManagedObjectWrapper::ResetFlag(_In_ CreateComInterfaceFlagsEx flag) ::InterlockedAnd((LONG*)&_flags, resetMask); } -ULONG ManagedObjectWrapper::IsActiveAddRef() +bool ManagedObjectWrapper::IsRooted() const { - ULONG count = GetComCount(::InterlockedIncrement64(&_refCount)); - if (count == 1) + bool rooted = GetComCount(_refCount) > 0; + if (!rooted) { - // Ensure the current target is null. - ::InterlockedExchangePointer(&Target, nullptr); + // Only consider tracker ref count to be a "strong" ref count if it is pegged and alive. + rooted = (GetTrackerCount(_refCount) > 0) + && (IsSet(CreateComInterfaceFlagsEx::IsPegged) + || InteropLibImports::GetGlobalPeggingState()); } - return count; + return rooted; } ULONG ManagedObjectWrapper::AddRefFromReferenceTracker() @@ -578,7 +557,29 @@ ULONG ManagedObjectWrapper::AddRefFromReferenceTracker() ULONG ManagedObjectWrapper::ReleaseFromReferenceTracker() { - return GetTrackerCount(UniversalRelease(TrackerRefCounter)); + if (GetTrackerCount(_refCount) == 0) + { + _ASSERTE(!"Over release of MOW - ReferenceTracker"); + return (ULONG)-1; + } + + LONGLONG refCount; + LONGLONG prev; + do + { + prev = _refCount; + refCount = prev - TrackerRefCounter; + } while (::InterlockedCompareExchange64(&_refCount, refCount, prev) != prev); + + // If we observe the destroy sentinel, then this release + // must destroy the wrapper. + if (refCount == DestroySentinel) + { + _ASSERTE(!IsSet(CreateComInterfaceFlagsEx::IsPegged)); + Destroy(this); + } + + return GetTrackerCount(refCount); } HRESULT ManagedObjectWrapper::Peg() @@ -652,12 +653,20 @@ HRESULT ManagedObjectWrapper::QueryInterface( ULONG ManagedObjectWrapper::AddRef(void) { + _ASSERTE((_refCount & DestroySentinel) == 0); return GetComCount(::InterlockedIncrement64(&_refCount)); } ULONG ManagedObjectWrapper::Release(void) { - return GetComCount(UniversalRelease(ComRefCounter)); + _ASSERTE((_refCount & DestroySentinel) == 0); + if (GetComCount(_refCount) == 0) + { + _ASSERTE(!"Over release of MOW - COM"); + return (ULONG)-1; + } + + return GetComCount(::InterlockedDecrement64(&_refCount)); } namespace @@ -684,12 +693,19 @@ NativeObjectWrapperContext* NativeObjectWrapperContext::MapFromRuntimeContext(_I HRESULT NativeObjectWrapperContext::Create( _In_ IUnknown* external, + _In_opt_ IUnknown* inner, _In_ InteropLib::Com::CreateObjectFlags flags, _In_ size_t runtimeContextSize, _Outptr_ NativeObjectWrapperContext** context) { _ASSERTE(external != nullptr && context != nullptr); + // Aggregated inners are only currently supported for Aggregated + // scenarios involving IReferenceTracker. + _ASSERTE(inner == nullptr + || ((flags & InteropLib::Com::CreateObjectFlags_TrackerObject) + && (flags & InteropLib::Com::CreateObjectFlags_Aggregated))); + HRESULT hr; ComHolder trackerObject; @@ -710,7 +726,7 @@ HRESULT NativeObjectWrapperContext::Create( // Contract specifically requires zeroing out runtime context. ::memset(runtimeContext, 0, runtimeContextSize); - NativeObjectWrapperContext* contextLocal = new (cxtMem) NativeObjectWrapperContext{ runtimeContext, trackerObject }; + NativeObjectWrapperContext* contextLocal = new (cxtMem) NativeObjectWrapperContext{ runtimeContext, trackerObject, inner }; if (trackerObject != nullptr) { @@ -722,6 +738,13 @@ HRESULT NativeObjectWrapperContext::Create( Destroy(contextLocal); return hr; } + + // Aggregation with a tracker object must be "cleaned up". + if (flags & InteropLib::Com::CreateObjectFlags_Aggregated) + { + _ASSERTE(inner != nullptr); + contextLocal->HandleReferenceTrackerAggregation(); + } } *context = contextLocal; @@ -732,21 +755,48 @@ void NativeObjectWrapperContext::Destroy(_In_ NativeObjectWrapperContext* wrappe { _ASSERTE(wrapper != nullptr); + // Check if the tracker object manager should be informed prior to being destroyed. + IReferenceTracker* trackerMaybe = wrapper->GetReferenceTracker(); + if (trackerMaybe != nullptr) + { + // We only call this during a GC so ignore the failure as + // there is no way we can handle it at this point. + HRESULT hr = TrackerObjectManager::BeforeWrapperDestroyed(trackerMaybe); + _ASSERTE(SUCCEEDED(hr)); + (void)hr; + } + // Manually trigger the destructor since placement // new was used to allocate the object. wrapper->~NativeObjectWrapperContext(); InteropLibImports::MemFree(wrapper, AllocScenario::NativeObjectWrapper); } -NativeObjectWrapperContext::NativeObjectWrapperContext(_In_ void* runtimeContext, _In_opt_ IReferenceTracker* trackerObject) +namespace +{ + // State ownership mechanism. + enum : int + { + TrackerObjectState_NotSet = 0, + TrackerObjectState_SetNoRelease = 1, + TrackerObjectState_SetForRelease = 2, + }; +} + +NativeObjectWrapperContext::NativeObjectWrapperContext( + _In_ void* runtimeContext, + _In_opt_ IReferenceTracker* trackerObject, + _In_opt_ IUnknown* nativeObjectAsInner) : _trackerObject{ trackerObject } , _runtimeContext{ runtimeContext } - , _isValidTracker{ (trackerObject != nullptr ? TRUE : FALSE) } + , _trackerObjectDisconnected{ FALSE } + , _trackerObjectState{ (trackerObject == nullptr ? TrackerObjectState_NotSet : TrackerObjectState_SetForRelease) } + , _nativeObjectAsInner{ nativeObjectAsInner } #ifdef _DEBUG , _sentinel{ LiveContextSentinel } #endif { - if (_isValidTracker == TRUE) + if (_trackerObjectState == TrackerObjectState_SetForRelease) (void)_trackerObject->AddRef(); } @@ -754,6 +804,10 @@ NativeObjectWrapperContext::~NativeObjectWrapperContext() { DisconnectTracker(); + // If the inner was supplied, we need to release our reference. + if (_nativeObjectAsInner != nullptr) + (void)_nativeObjectAsInner->Release(); + #ifdef _DEBUG _sentinel = DeadContextSentinel; #endif @@ -766,12 +820,43 @@ void* NativeObjectWrapperContext::GetRuntimeContext() const noexcept IReferenceTracker* NativeObjectWrapperContext::GetReferenceTracker() const noexcept { - return ((_isValidTracker == TRUE) ? _trackerObject : nullptr); + return ((_trackerObjectState == TrackerObjectState_NotSet) ? nullptr : _trackerObject); } +// See TrackerObjectManager::AfterWrapperCreated() for AddRefFromTrackerSource() usage. +// See NativeObjectWrapperContext::HandleReferenceTrackerAggregation() for additional +// cleanup logistics. void NativeObjectWrapperContext::DisconnectTracker() noexcept { - // Attempt to disconnect from the tracker. - if (TRUE == ::InterlockedCompareExchange((LONG*)&_isValidTracker, FALSE, TRUE)) + // Return if already disconnected or the tracker isn't set. + if (FALSE != ::InterlockedCompareExchange((LONG*)&_trackerObjectDisconnected, TRUE, FALSE) + || _trackerObjectState == TrackerObjectState_NotSet) + { + return; + } + + _ASSERTE(_trackerObject != nullptr); + + // Always release the tracker source during a disconnect. + // This to account for the implied IUnknown ownership by the runtime. + (void)_trackerObject->ReleaseFromTrackerSource(); // IUnknown + + // Disconnect from the tracker. + if (_trackerObjectState == TrackerObjectState_SetForRelease) + { + (void)_trackerObject->ReleaseFromTrackerSource(); // IReferenceTracker (void)_trackerObject->Release(); + } +} + +void NativeObjectWrapperContext::HandleReferenceTrackerAggregation() noexcept +{ + _ASSERTE(_trackerObjectState == TrackerObjectState_SetForRelease && _trackerObject != nullptr); + + // Aggregation with an IReferenceTracker instance creates an extra AddRef() + // on the outer (e.g. MOW) so we clean up that issue here. + _trackerObjectState = TrackerObjectState_SetNoRelease; + + (void)_trackerObject->ReleaseFromTrackerSource(); // IReferenceTracker + (void)_trackerObject->Release(); } diff --git a/src/coreclr/src/interop/comwrappers.hpp b/src/coreclr/src/interop/comwrappers.hpp index 3ae91d8a88c7d..e4d849a562574 100644 --- a/src/coreclr/src/interop/comwrappers.hpp +++ b/src/coreclr/src/interop/comwrappers.hpp @@ -82,10 +82,6 @@ class ManagedObjectWrapper ~ManagedObjectWrapper(); - // Represents a single implementation of how to release - // the wrapper. Supplied with a decrementing value. - ULONGLONG UniversalRelease(_In_ ULONGLONG dec); - // Query the runtime defined tables. void* AsRuntimeDefined(_In_ REFIID riid); @@ -102,8 +98,8 @@ class ManagedObjectWrapper void SetFlag(_In_ CreateComInterfaceFlagsEx flag); void ResetFlag(_In_ CreateComInterfaceFlagsEx flag); - // Used while validating wrapper is active. - ULONG IsActiveAddRef(); + // Indicate if the wrapper should be considered a GC root. + bool IsRooted() const; public: // IReferenceTrackerTarget ULONG AddRefFromReferenceTracker(); @@ -139,7 +135,9 @@ class NativeObjectWrapperContext { IReferenceTracker* _trackerObject; void* _runtimeContext; - Volatile _isValidTracker; + Volatile _trackerObjectDisconnected; + int _trackerObjectState; + IUnknown* _nativeObjectAsInner; #ifdef _DEBUG size_t _sentinel; @@ -151,6 +149,7 @@ class NativeObjectWrapperContext // Create a NativeObjectWrapperContext instance static HRESULT NativeObjectWrapperContext::Create( _In_ IUnknown* external, + _In_opt_ IUnknown* nativeObjectAsInner, _In_ InteropLib::Com::CreateObjectFlags flags, _In_ size_t runtimeContextSize, _Outptr_ NativeObjectWrapperContext** context); @@ -159,7 +158,7 @@ class NativeObjectWrapperContext static void Destroy(_In_ NativeObjectWrapperContext* wrapper); private: - NativeObjectWrapperContext(_In_ void* runtimeContext, _In_opt_ IReferenceTracker* trackerObject); + NativeObjectWrapperContext(_In_ void* runtimeContext, _In_opt_ IReferenceTracker* trackerObject, _In_opt_ IUnknown* nativeObjectAsInner); ~NativeObjectWrapperContext(); public: @@ -171,6 +170,9 @@ class NativeObjectWrapperContext // Disconnect reference tracker instance. void DisconnectTracker() noexcept; + +private: + void HandleReferenceTrackerAggregation() noexcept; }; // Manage native object wrappers that support IReferenceTracker. diff --git a/src/coreclr/src/interop/inc/interoplib.h b/src/coreclr/src/interop/inc/interoplib.h index a1f32b99ecdb2..39ceadfbb809c 100644 --- a/src/coreclr/src/interop/inc/interoplib.h +++ b/src/coreclr/src/interop/inc/interoplib.h @@ -38,11 +38,8 @@ namespace InteropLib // Destroy the supplied wrapper void DestroyWrapperForObject(_In_ void* wrapper) noexcept; - // Check if a wrapper is active. - HRESULT IsActiveWrapper(_In_ IUnknown* wrapper) noexcept; - - // Reactivate the supplied wrapper. - HRESULT ReactivateWrapper(_In_ IUnknown* wrapper, _In_ InteropLib::OBJECTHANDLE handle) noexcept; + // Check if a wrapper is considered a GC root. + HRESULT IsWrapperRooted(_In_ IUnknown* wrapper) noexcept; // Get the object for the supplied wrapper HRESULT GetObjectForWrapper(_In_ IUnknown* wrapper, _Outptr_result_maybenull_ OBJECTHANDLE* object) noexcept; @@ -58,6 +55,9 @@ namespace InteropLib // See https://docs.microsoft.com/windows/win32/api/windows.ui.xaml.hosting.referencetracker/ // for details. bool FromTrackerRuntime; + + // The supplied external object is wrapping a managed object. + bool ManagedObjectWrapper; }; // See CreateObjectFlags in ComWrappers.cs @@ -66,13 +66,21 @@ namespace InteropLib CreateObjectFlags_None = 0, CreateObjectFlags_TrackerObject = 1, CreateObjectFlags_UniqueInstance = 2, + CreateObjectFlags_Aggregated = 4, }; + // Get the true identity for the supplied IUnknown. + HRESULT GetIdentityForCreateWrapperForExternal( + _In_ IUnknown* external, + _In_ enum CreateObjectFlags flags, + _Outptr_ IUnknown** identity) noexcept; + // Allocate a wrapper context for an external object. // The runtime supplies the external object, flags, and a memory // request in order to bring the object into the runtime. HRESULT CreateWrapperForExternal( _In_ IUnknown* external, + _In_opt_ IUnknown* inner, _In_ enum CreateObjectFlags flags, _In_ size_t contextSize, _Out_ ExternalWrapperResult* result) noexcept; diff --git a/src/coreclr/src/interop/interoplib.cpp b/src/coreclr/src/interop/interoplib.cpp index f730817dea317..9aff8c2bb335c 100644 --- a/src/coreclr/src/interop/interoplib.cpp +++ b/src/coreclr/src/interop/interoplib.cpp @@ -54,53 +54,26 @@ namespace InteropLib ManagedObjectWrapper::Destroy(wrapper); } - HRESULT IsActiveWrapper(_In_ IUnknown* wrapperMaybe) noexcept + HRESULT IsWrapperRooted(_In_ IUnknown* wrapperMaybe) noexcept { ManagedObjectWrapper* wrapper = ManagedObjectWrapper::MapFromIUnknown(wrapperMaybe); if (wrapper == nullptr) return E_INVALIDARG; - ULONG count = wrapper->IsActiveAddRef(); - if (count == 1 || wrapper->Target == nullptr) - { - // The wrapper isn't active. - (void)wrapper->Release(); - return S_FALSE; - } - - return S_OK; - } - - HRESULT ReactivateWrapper(_In_ IUnknown* wrapperMaybe, _In_ OBJECTHANDLE handle) noexcept - { - ManagedObjectWrapper* wrapper = ManagedObjectWrapper::MapFromIUnknown(wrapperMaybe); - if (wrapper == nullptr || handle == nullptr) - return E_INVALIDARG; - - // Take an AddRef() as an indication of ownership. - (void)wrapper->AddRef(); - - // If setting this object handle fails, then the race - // was lost and we will cleanup the handle. - if (!wrapper->TrySetObjectHandle(handle)) - InteropLibImports::DeleteObjectInstanceHandle(handle); - - return S_OK; + return wrapper->IsRooted() ? S_OK : S_FALSE; } HRESULT GetObjectForWrapper(_In_ IUnknown* wrapper, _Outptr_result_maybenull_ OBJECTHANDLE* object) noexcept { - if (object == nullptr) - return E_POINTER; - + _ASSERTE(wrapper != nullptr && object != nullptr); *object = nullptr; - HRESULT hr = IsActiveWrapper(wrapper); - if (hr != S_OK) - return hr; - + // Attempt to get the managed object wrapper. ManagedObjectWrapper *mow = ManagedObjectWrapper::MapFromIUnknown(wrapper); - _ASSERTE(mow != nullptr); + if (mow == nullptr) + return E_INVALIDARG; + + (void)mow->AddRef(); *object = mow->Target; return S_OK; @@ -125,8 +98,43 @@ namespace InteropLib return wrapper->IsSet(CreateComInterfaceFlagsEx::IsComActivated) ? S_OK : S_FALSE; } + HRESULT GetIdentityForCreateWrapperForExternal( + _In_ IUnknown* external, + _In_ enum CreateObjectFlags flags, + _Outptr_ IUnknown** identity) noexcept + { + _ASSERTE(external != nullptr && identity != nullptr); + + IUnknown* checkForIdentity = external; + + // Check if the flags indicate we are creating + // an object for an external IReferenceTracker instance + // that we are aggregating with. + bool refTrackerInnerScenario = (flags & CreateObjectFlags_TrackerObject) + && (flags & CreateObjectFlags_Aggregated); + + ComHolder trackerObject; + if (refTrackerInnerScenario) + { + // We are checking the supplied external value + // for IReferenceTracker since in .NET 5 this could + // actually be the inner and we want the true identity + // not the inner . This is a trick since the only way + // to get identity from an inner is through a non-IUnknown + // interface QI. Once we have the IReferenceTracker + // instance we can be sure the QI for IUnknown will really + // be the true identity. + HRESULT hr = external->QueryInterface(&trackerObject); + if (SUCCEEDED(hr)) + checkForIdentity = trackerObject.p; + } + + return checkForIdentity->QueryInterface(identity); + } + HRESULT CreateWrapperForExternal( _In_ IUnknown* external, + _In_opt_ IUnknown* inner, _In_ enum CreateObjectFlags flags, _In_ size_t contextSize, _Out_ ExternalWrapperResult* result) noexcept @@ -136,10 +144,11 @@ namespace InteropLib HRESULT hr; NativeObjectWrapperContext* wrapperContext; - RETURN_IF_FAILED(NativeObjectWrapperContext::Create(external, flags, contextSize, &wrapperContext)); + RETURN_IF_FAILED(NativeObjectWrapperContext::Create(external, inner, flags, contextSize, &wrapperContext)); result->Context = wrapperContext->GetRuntimeContext(); result->FromTrackerRuntime = (wrapperContext->GetReferenceTracker() != nullptr); + result->ManagedObjectWrapper = (ManagedObjectWrapper::MapFromIUnknown(external) != nullptr); return S_OK; } @@ -150,17 +159,6 @@ namespace InteropLib // A caller should not be destroying a context without knowing if the context is valid. _ASSERTE(context != nullptr); - // Check if the tracker object manager should be informed prior to being destroyed. - IReferenceTracker* trackerMaybe = context->GetReferenceTracker(); - if (trackerMaybe != nullptr) - { - // We only call this during a GC so ignore the failure as - // there is no way we can handle it at this point. - HRESULT hr = TrackerObjectManager::BeforeWrapperDestroyed(trackerMaybe); - _ASSERTE(SUCCEEDED(hr)); - (void)hr; - } - NativeObjectWrapperContext::Destroy(context); } diff --git a/src/coreclr/src/interop/trackerobjectmanager.cpp b/src/coreclr/src/interop/trackerobjectmanager.cpp index 91cc894c0eef0..f205484d3b0af 100644 --- a/src/coreclr/src/interop/trackerobjectmanager.cpp +++ b/src/coreclr/src/interop/trackerobjectmanager.cpp @@ -296,7 +296,8 @@ HRESULT TrackerObjectManager::AfterWrapperCreated(_In_ IReferenceTracker* obj) // Send out AddRefFromTrackerSource callbacks to notify tracker runtime we've done AddRef() // for certain interfaces. We should do this *after* we made a AddRef() because we should never // be in a state where report refs > actual refs - RETURN_IF_FAILED(obj->AddRefFromTrackerSource()); + RETURN_IF_FAILED(obj->AddRefFromTrackerSource()); // IUnknown + RETURN_IF_FAILED(obj->AddRefFromTrackerSource()); // IReferenceTracker return S_OK; } diff --git a/src/coreclr/src/vm/gcenv.ee.cpp b/src/coreclr/src/vm/gcenv.ee.cpp index d58be8634d28f..5966e56eccdd7 100644 --- a/src/coreclr/src/vm/gcenv.ee.cpp +++ b/src/coreclr/src/vm/gcenv.ee.cpp @@ -309,10 +309,16 @@ bool GCToEEInterface::RefCountedHandleCallbacks(Object * pObject) //@todo optimize the access to the ref-count ComCallWrapper* pWrap = ComCallWrapper::GetWrapperForObject((OBJECTREF)pObject); - return pWrap != NULL && pWrap->IsWrapperActive(); -#else - return false; + if (pWrap != NULL && pWrap->IsWrapperActive()) + return true; #endif +#ifdef FEATURE_COMWRAPPERS + bool isRooted = false; + if (ComWrappersNative::HasManagedObjectComWrapper((OBJECTREF)pObject, &isRooted)) + return isRooted; +#endif + + return false; } void GCToEEInterface::GcBeforeBGCSweepWork() diff --git a/src/coreclr/src/vm/interoplibinterface.cpp b/src/coreclr/src/vm/interoplibinterface.cpp index f77bdc428c3d7..ce181ce9ccd10 100644 --- a/src/coreclr/src/vm/interoplibinterface.cpp +++ b/src/coreclr/src/vm/interoplibinterface.cpp @@ -442,7 +442,7 @@ namespace INT64 g_trackerSupportGlobalInstanceId = ComWrappersNative::InvalidWrapperId; // Defined handle types for the specific object uses. - const HandleType InstanceHandleType{ HNDTYPE_STRONG }; + const HandleType InstanceHandleType{ HNDTYPE_REFCOUNTED }; // Scenarios for ComWrappers usage. // These values should match the managed definition in ComWrappers. @@ -655,19 +655,9 @@ namespace } else if (wrapperRawMaybe != NULL) { - // It is possible the supplied wrapper is no longer valid. If so, reactivate the - // wrapper using the protected OBJECTREF. + // AddRef() the existing wrapper. IUnknown* wrapper = static_cast(wrapperRawMaybe); - hr = InteropLib::Com::IsActiveWrapper(wrapper); - if (hr == S_FALSE) - { - STRESS_LOG1(LF_INTEROP, LL_INFO100, "Reactivating MOW: 0x%p\n", wrapperRawMaybe); - OBJECTHANDLE h = GetAppDomain()->CreateTypedHandle(gc.instRef, InstanceHandleType); - hr = InteropLib::Com::ReactivateWrapper(wrapper, static_cast(h)); - } - - if (FAILED(hr)) - COMPlusThrowHR(hr); + (void)wrapper->AddRef(); } GCPROTECT_END(); @@ -680,6 +670,7 @@ namespace _In_opt_ OBJECTREF impl, _In_ INT64 wrapperId, _In_ IUnknown* identity, + _In_opt_ IUnknown* inner, _In_ CreateObjectFlags flags, _In_ ComWrappersScenario scenario, _In_opt_ OBJECTREF wrapperMaybe, @@ -760,6 +751,7 @@ namespace GCX_PREEMP(); hr = InteropLib::Com::CreateWrapperForExternal( identity, + inner, flags, sizeof(ExternalObjectContext), &resultHolder); @@ -783,7 +775,7 @@ namespace if (gc.objRefMaybe != NULL) { // Construct the new context with the object details. - DWORD flags = (resultHolder.Result.FromTrackerRuntime + DWORD eocFlags = (resultHolder.Result.FromTrackerRuntime ? ExternalObjectContext::Flags_ReferenceTracker : ExternalObjectContext::Flags_None) | (uniqueInstance @@ -795,7 +787,7 @@ namespace GetCurrentCtxCookie(), gc.objRefMaybe->GetSyncBlockIndex(), wrapperId, - flags); + eocFlags); if (uniqueInstance) { @@ -833,6 +825,18 @@ namespace // Detach from the holder to avoid cleanup. (void)resultHolder.DetachContext(); STRESS_LOG2(LF_INTEROP, LL_INFO100, "Created EOC (Unique Instance: %d): 0x%p\n", (int)uniqueInstance, extObjCxt); + + // If this is an aggregation scenario and the identity object + // is a managed object wrapper, we need to call Release() to + // indicate this external object isn't rooted. In the event the + // object is passed out to native code an AddRef() must be called + // based on COM convention and will "fix" the count. + if (flags & CreateObjectFlags::CreateObjectFlags_Aggregated + && resultHolder.Result.ManagedObjectWrapper) + { + (void)identity->Release(); + STRESS_LOG1(LF_INTEROP, LL_INFO100, "EOC aggregated with MOW: 0x%p\n", identity); + } } _ASSERTE(extObjCxt->IsActive()); @@ -1086,6 +1090,7 @@ namespace InteropLibImports gc.implRef, g_trackerSupportGlobalInstanceId, externalComObject, + NULL, externalObjectFlags, ComWrappersScenario::TrackerSupportGlobalInstance, gc.wrapperMaybeRef, @@ -1252,9 +1257,12 @@ namespace InteropLibImports ::OBJECTHANDLE objectHandle = static_cast<::OBJECTHANDLE>(handle); OBJECTREF target = ObjectFromHandle(objectHandle); - // If these point at the same object don't create a reference. - if (source->PassiveGetSyncBlock() == target->PassiveGetSyncBlock()) + // Return if the target has been collected or these are the same object. + if (target == NULL + || source->PassiveGetSyncBlock() == target->PassiveGetSyncBlock()) + { return S_FALSE; + } STRESS_LOG2(LF_INTEROP, LL_INFO1000, "Found reference path: 0x%p => 0x%p\n", OBJECTREFToObject(source), @@ -1317,9 +1325,22 @@ BOOL QCALLTYPE ComWrappersNative::TryGetOrCreateObjectForComInstance( // Determine the true identity of the object SafeComHolder identity; - hr = externalComObject->QueryInterface(IID_IUnknown, &identity); + hr = InteropLib::Com::GetIdentityForCreateWrapperForExternal( + externalComObject, + (CreateObjectFlags)flags, + &identity); _ASSERTE(hr == S_OK); + // Customized inners are only supported in aggregation with + // IReferenceTracker scenarios (e.g. WinRT). + IUnknown* inner = NULL; + if ((externalComObject != identity) + && (flags & CreateObjectFlags::CreateObjectFlags_TrackerObject) + && (flags & CreateObjectFlags::CreateObjectFlags_Aggregated)) + { + inner = externalComObject; + } + // Switch to Cooperative mode since object references // are being manipulated. { @@ -1330,6 +1351,7 @@ BOOL QCALLTYPE ComWrappersNative::TryGetOrCreateObjectForComInstance( ObjectToOBJECTREF(*comWrappersImpl.m_ppObject), wrapperId, identity, + inner, (CreateObjectFlags)flags, ComWrappersScenario::Instance, ObjectToOBJECTREF(*wrapperMaybe.m_ppObject), @@ -1368,6 +1390,7 @@ void ComWrappersNative::DestroyManagedObjectComWrapper(_In_ void* wrapper) CONTRACTL { NOTHROW; + GC_NOTRIGGER; MODE_ANY; PRECONDITION(wrapper != NULL); } @@ -1382,6 +1405,7 @@ void ComWrappersNative::DestroyExternalComObjectContext(_In_ void* contextRaw) CONTRACTL { NOTHROW; + GC_NOTRIGGER; MODE_ANY; PRECONDITION(contextRaw != NULL); } @@ -1508,6 +1532,7 @@ bool GlobalComWrappersForMarshalling::TryGetOrCreateObjectForComInstance( NULL /*comWrappersImpl*/, g_marshallingGlobalInstanceId, identity, + NULL, (CreateObjectFlags)flags, ComWrappersScenario::MarshallingGlobalInstance, NULL /*wrapperMaybe*/, @@ -1581,6 +1606,7 @@ bool GlobalComWrappersForTrackerSupport::TryGetOrCreateObjectForComInstance( NULL /*comWrappersImpl*/, g_trackerSupportGlobalInstanceId, identity, + NULL, CreateObjectFlags::CreateObjectFlags_TrackerObject, ComWrappersScenario::TrackerSupportGlobalInstance, NULL /*wrapperMaybe*/, @@ -1632,6 +1658,66 @@ IUnknown* ComWrappersNative::GetIdentityForObject(_In_ OBJECTREF* objectPROTECTE return nullptr; } +namespace +{ + struct CallbackContext + { + bool HasWrapper; + bool IsRooted; + }; + bool IsWrapperRootedCallback(_In_ void* mocw, _In_ void* cxtRaw) + { + CONTRACTL + { + NOTHROW; + GC_NOTRIGGER; + MODE_ANY; + PRECONDITION(mocw != NULL); + PRECONDITION(cxtRaw != NULL); + } + CONTRACTL_END; + + auto cxt = static_cast(cxtRaw); + cxt->HasWrapper = true; + + IUnknown* wrapper = static_cast(mocw); + cxt->IsRooted = (InteropLib::Com::IsWrapperRooted(wrapper) == S_OK); + + // If we find a single rooted wrapper then the managed object + // is considered rooted and we can stop enumerating. + if (cxt->IsRooted) + return false; + + return true; + } +} + +bool ComWrappersNative::HasManagedObjectComWrapper(_In_ OBJECTREF object, _Out_ bool* isRooted) +{ + CONTRACTL + { + NOTHROW; + GC_NOTRIGGER; + PRECONDITION(CheckPointer(isRooted)); + } + CONTRACTL_END; + + *isRooted = false; + SyncBlock* syncBlock = object->PassiveGetSyncBlock(); + if (syncBlock == nullptr) + return false; + + InteropSyncBlockInfo* interopInfo = syncBlock->GetInteropInfoNoCreate(); + if (interopInfo == nullptr) + return false; + + CallbackContext cxt{}; + interopInfo->EnumManagedObjectComWrappers(&IsWrapperRootedCallback, &cxt); + + *isRooted = cxt.IsRooted; + return cxt.HasWrapper; +} + #endif // FEATURE_COMWRAPPERS void Interop::OnGCStarted(_In_ int nCondemnedGeneration) diff --git a/src/coreclr/src/vm/interoplibinterface.h b/src/coreclr/src/vm/interoplibinterface.h index a1f34e6557531..3d132a18ad6ef 100644 --- a/src/coreclr/src/vm/interoplibinterface.h +++ b/src/coreclr/src/vm/interoplibinterface.h @@ -44,6 +44,7 @@ class ComWrappersNative public: // Unwrapping support static IUnknown* GetIdentityForObject(_In_ OBJECTREF* objectPROTECTED, _In_ REFIID riid, _Out_ INT64* wrapperId); + static bool HasManagedObjectComWrapper(_In_ OBJECTREF object, _Out_ bool* isActive); }; class GlobalComWrappersForMarshalling diff --git a/src/coreclr/src/vm/syncblk.h b/src/coreclr/src/vm/syncblk.h index 48d5b85b4f4ce..79cb2c7523156 100644 --- a/src/coreclr/src/vm/syncblk.h +++ b/src/coreclr/src/vm/syncblk.h @@ -817,7 +817,8 @@ class InteropSyncBlockInfo if (FastInterlockCompareExchangePointer((ManagedObjectComWrapperByIdMap**)&m_managedObjectComWrapperMap, (ManagedObjectComWrapperByIdMap *)map, NULL) == NULL) { map.SuppressRelease(); - m_managedObjectComWrapperLock.Init(CrstLeafLock); + // The GC thread does enumerate these objects so add CRST_UNSAFE_COOPGC. + m_managedObjectComWrapperLock.Init(CrstInteropData, CRST_UNSAFE_COOPGC); } _ASSERTE(m_managedObjectComWrapperMap != NULL); @@ -832,8 +833,8 @@ class InteropSyncBlockInfo return true; } - using EnumWrappersCallback = void(void* mocw); - void ClearManagedObjectComWrappers(EnumWrappersCallback* callback) + using ClearWrappersCallback = void(void* mocw); + void ClearManagedObjectComWrappers(ClearWrappersCallback* callback) { LIMITED_METHOD_CONTRACT; @@ -854,6 +855,27 @@ class InteropSyncBlockInfo m_managedObjectComWrapperMap->RemoveAll(); } + + using EnumWrappersCallback = bool(void* mocw, void* cxt); + void EnumManagedObjectComWrappers(EnumWrappersCallback* callback, void* cxt) + { + LIMITED_METHOD_CONTRACT; + + _ASSERTE(callback != NULL); + + if (m_managedObjectComWrapperMap == NULL) + return; + + CrstHolder lock(&m_managedObjectComWrapperLock); + + ManagedObjectComWrapperByIdMap::Iterator iter = m_managedObjectComWrapperMap->Begin(); + while (iter != m_managedObjectComWrapperMap->End()) + { + if (!callback(iter->Value(), cxt)) + break; + ++iter; + } + } #endif // !DACCESS_COMPILE bool TryGetExternalComObjectContext(_Out_ void** eoc) diff --git a/src/tests/Interop/COM/ComWrappers/API/Program.cs b/src/tests/Interop/COM/ComWrappers/API/Program.cs index ad6e3b73c0e9b..29a3b7b5b9f5a 100644 --- a/src/tests/Interop/COM/ComWrappers/API/Program.cs +++ b/src/tests/Interop/COM/ComWrappers/API/Program.cs @@ -18,31 +18,34 @@ class TestComWrappers : ComWrappers { protected unsafe override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) { - Assert.IsTrue(obj is Test); - IntPtr fpQueryInteface = default; IntPtr fpAddRef = default; IntPtr fpRelease = default; ComWrappers.GetIUnknownImpl(out fpQueryInteface, out fpAddRef, out fpRelease); - var vtbl = new ITestVtbl() + ComInterfaceEntry* entryRaw = null; + count = 0; + if (obj is Test) { - IUnknownImpl = new IUnknownVtbl() + var vtbl = new ITestVtbl() { - QueryInterface = fpQueryInteface, - AddRef = fpAddRef, - Release = fpRelease - }, - SetValue = Marshal.GetFunctionPointerForDelegate(ITestVtbl.pSetValue) - }; - var vtblRaw = RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(ITestVtbl), sizeof(ITestVtbl)); - Marshal.StructureToPtr(vtbl, vtblRaw, false); - - var entryRaw = (ComInterfaceEntry*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(ITestVtbl), sizeof(ComInterfaceEntry)); - entryRaw->IID = typeof(ITest).GUID; - entryRaw->Vtable = vtblRaw; + IUnknownImpl = new IUnknownVtbl() + { + QueryInterface = fpQueryInteface, + AddRef = fpAddRef, + Release = fpRelease + }, + SetValue = Marshal.GetFunctionPointerForDelegate(ITestVtbl.pSetValue) + }; + var vtblRaw = RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(ITestVtbl), sizeof(ITestVtbl)); + Marshal.StructureToPtr(vtbl, vtblRaw, false); + + entryRaw = (ComInterfaceEntry*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(ITestVtbl), sizeof(ComInterfaceEntry)); + entryRaw->IID = typeof(ITest).GUID; + entryRaw->Vtable = vtblRaw; + count = 1; + } - count = 1; return entryRaw; } @@ -75,6 +78,19 @@ public static void ValidateIUnknownImpls() } } + static void ForceGC() + { + // Trigger the GC multiple times and then + // wait for all finalizers since that is where + // most of the cleanup occurs. + GC.Collect(); + GC.Collect(); + GC.Collect(); + GC.Collect(); + GC.Collect(); + GC.WaitForPendingFinalizers(); + } + static void ValidateComInterfaceCreation() { Console.WriteLine($"Running {nameof(ValidateComInterfaceCreation)}..."); @@ -375,11 +391,7 @@ static void ValidateRuntimeTrackerScenario() Assert.IsTrue(testWrapperIds.Count <= Test.InstanceCount); - GC.Collect(); - GC.Collect(); - GC.Collect(); - GC.Collect(); - GC.Collect(); + ForceGC(); Assert.IsTrue(testWrapperIds.Count <= Test.InstanceCount); @@ -391,11 +403,69 @@ static void ValidateRuntimeTrackerScenario() testWrapperIds.Clear(); - GC.Collect(); - GC.Collect(); - GC.Collect(); - GC.Collect(); - GC.Collect(); + ForceGC(); + } + + unsafe class Derived : ITrackerObjectWrapper + { + public Derived(ComWrappers cw, bool aggregateRefTracker) + : base(cw, aggregateRefTracker) + { } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static WeakReference AllocateAndUseBaseType(ComWrappers cw, bool aggregateRefTracker) + { + var derived = new Derived(cw, aggregateRefTracker); + + // Use the base type + IntPtr testWrapper = cw.GetOrCreateComInterfaceForObject(new Test(), CreateComInterfaceFlags.TrackerSupport); + int id = derived.AddObjectRef(testWrapper); + + // Tell the tracker runtime to release its hold on the base instance. + MockReferenceTrackerRuntime.ReleaseAllTrackerObjects(); + + // Validate the GC is tracking the entire Derived type. + ForceGC(); + + derived.DropObjectRef(id); + + return new WeakReference(derived); + } + } + + static void ValidateAggregationWithComObject() + { + Console.WriteLine($"Running {nameof(ValidateAggregationWithComObject)}..."); + + using var allocTracker = MockReferenceTrackerRuntime.CountTrackerObjectAllocations(); + var cw = new TestComWrappers(); + WeakReference weakRef = Derived.AllocateAndUseBaseType(cw, aggregateRefTracker: false); + + ForceGC(); + + // Validate all instances were cleaned up + Assert.IsFalse(weakRef.TryGetTarget(out _)); + Assert.AreEqual(0, allocTracker.GetCount()); + } + + static void ValidateAggregationWithReferenceTrackerObject() + { + Console.WriteLine($"Running {nameof(ValidateAggregationWithReferenceTrackerObject)}..."); + + using var allocTracker = MockReferenceTrackerRuntime.CountTrackerObjectAllocations(); + var cw = new TestComWrappers(); + WeakReference weakRef = Derived.AllocateAndUseBaseType(cw, aggregateRefTracker: true); + + ForceGC(); + + // Validate all instances were cleaned up. + Assert.IsFalse(weakRef.TryGetTarget(out _)); + + // Reference counter cleanup requires additional GCs since the Finalizer is used + // to clean up the Reference Tracker runtime references. + ForceGC(); + + Assert.AreEqual(0, allocTracker.GetCount()); } static int Main(string[] doNotUse) @@ -410,6 +480,8 @@ static int Main(string[] doNotUse) ValidateIUnknownImpls(); ValidateBadComWrapperImpl(); ValidateRuntimeTrackerScenario(); + ValidateAggregationWithComObject(); + ValidateAggregationWithReferenceTrackerObject(); } catch (Exception e) { diff --git a/src/tests/Interop/COM/ComWrappers/Common.cs b/src/tests/Interop/COM/ComWrappers/Common.cs index 8b2aa4020b2bd..083b4faa55d55 100644 --- a/src/tests/Interop/COM/ComWrappers/Common.cs +++ b/src/tests/Interop/COM/ComWrappers/Common.cs @@ -4,6 +4,7 @@ namespace ComWrappersTests.Common { using System; + using System.Threading; using System.Runtime.InteropServices; // @@ -84,10 +85,64 @@ public static int SetValueInternal(IntPtr dispatchPtr, int i) // // Native interface defintion with managed wrapper for tracker object // - struct MockReferenceTrackerRuntime + sealed class MockReferenceTrackerRuntime { + private static readonly ReaderWriterLockSlim AllocLock = new ReaderWriterLockSlim(LockRecursionPolicy.SupportsRecursion); + + public static IntPtr CreateTrackerObject() + { + return CreateTrackerObject(IntPtr.Zero, out IntPtr _); + } + + public static IntPtr CreateTrackerObject(IntPtr outer, out IntPtr inner) + { + AllocLock.EnterReadLock(); + try + { + return CreateTrackerObject_Unsafe(outer, out inner); + } + finally + { + AllocLock.ExitReadLock(); + } + } + [DllImport(nameof(MockReferenceTrackerRuntime))] - extern public static IntPtr CreateTrackerObject(); + extern private static IntPtr CreateTrackerObject_Unsafe(IntPtr outer, out IntPtr inner); + + public class AllocationCountResult : IDisposable + { + private bool isDisposed = false; + private ReaderWriterLockSlim allocLock; + public AllocationCountResult(ReaderWriterLockSlim allocLock) + { + this.allocLock = allocLock; + this.allocLock.EnterWriteLock(); + StartTrackerObjectAllocationCount_Unsafe(); + } + + public int GetCount() => StopTrackerObjectAllocationCount_Unsafe(); + + void IDisposable.Dispose() + { + if (this.isDisposed) + return; + + this.allocLock.ExitWriteLock(); + this.isDisposed = true; + } + } + + public static AllocationCountResult CountTrackerObjectAllocations() + { + return new AllocationCountResult(AllocLock); + } + + [DllImport(nameof(MockReferenceTrackerRuntime))] + extern private static void StartTrackerObjectAllocationCount_Unsafe(); + + [DllImport(nameof(MockReferenceTrackerRuntime))] + extern private static int StopTrackerObjectAllocationCount_Unsafe(); [DllImport(nameof(MockReferenceTrackerRuntime))] extern public static void ReleaseAllTrackerObjects(); @@ -108,7 +163,7 @@ public struct VtblPtr public IntPtr Vtbl; } - public class ITrackerObjectWrapper : ITrackerObject + public class ITrackerObjectWrapper : ITrackerObject, ICustomQueryInterface { private struct ITrackerObjectWrapperVtbl { @@ -124,28 +179,40 @@ private struct ITrackerObjectWrapperVtbl private delegate int _AddObjectRef(IntPtr This, IntPtr obj, out int id); private delegate int _DropObjectRef(IntPtr This, int id); - private readonly IntPtr instance; + private ComWrappersHelper.ClassNative classNative; + private readonly ITrackerObjectWrapperVtbl vtable; - public ITrackerObjectWrapper(IntPtr instance) + public ITrackerObjectWrapper(IntPtr instancePtr) { - var inst = Marshal.PtrToStructure(instance); + var inst = Marshal.PtrToStructure(instancePtr); this.vtable = Marshal.PtrToStructure(inst.Vtbl); - this.instance = instance; + this.classNative.Instance = instancePtr; + this.classNative.Release = ComWrappersHelper.ReleaseFlags.Instance; } - ~ITrackerObjectWrapper() + protected unsafe ITrackerObjectWrapper(ComWrappers cw, bool aggregateRefTracker) { - if (this.instance != IntPtr.Zero) + ComWrappersHelper.Init(ref this.classNative, this, aggregateRefTracker, cw, &CreateInstance); + + var inst = Marshal.PtrToStructure(this.classNative.Instance); + this.vtable = Marshal.PtrToStructure(inst.Vtbl); + + static IntPtr CreateInstance(IntPtr outer, out IntPtr inner) { - this.vtable.Release(this.instance); + return MockReferenceTrackerRuntime.CreateTrackerObject(outer, out inner); } } + ~ITrackerObjectWrapper() + { + ComWrappersHelper.Cleanup(ref this.classNative); + } + public int AddObjectRef(IntPtr obj) { int id; - int hr = this.vtable.AddObjectRef(this.instance, obj, out id); + int hr = this.vtable.AddObjectRef(this.classNative.Instance, obj, out id); if (hr != 0) { throw new COMException($"{nameof(AddObjectRef)}", hr); @@ -156,12 +223,255 @@ public int AddObjectRef(IntPtr obj) public void DropObjectRef(int id) { - int hr = this.vtable.DropObjectRef(this.instance, id); + int hr = this.vtable.DropObjectRef(this.classNative.Instance, id); if (hr != 0) { throw new COMException($"{nameof(DropObjectRef)}", hr); } } + + CustomQueryInterfaceResult ICustomQueryInterface.GetInterface(ref Guid iid, out IntPtr ppv) + { + if (this.classNative.Inner == IntPtr.Zero) + { + ppv = IntPtr.Zero; + return CustomQueryInterfaceResult.NotHandled; + } + + const int S_OK = 0; + const int E_NOINTERFACE = unchecked((int)0x80004002); + + int hr = Marshal.QueryInterface(this.classNative.Inner, ref iid, out ppv); + if (hr == S_OK) + { + return CustomQueryInterfaceResult.Handled; + } + + return hr == E_NOINTERFACE + ? CustomQueryInterfaceResult.NotHandled + : CustomQueryInterfaceResult.Failed; + } + } + + class ComWrappersHelper + { + private static Guid IID_IReferenceTracker = new Guid("11d3b13a-180e-4789-a8be-7712882893e6"); + + [Flags] + public enum ReleaseFlags + { + None = 0, + Instance = 1, + Inner = 2, + ReferenceTracker = 4 + } + + public struct ClassNative + { + public ReleaseFlags Release; + public IntPtr Instance; + public IntPtr Inner; + public IntPtr ReferenceTracker; + } + + public unsafe static void Init( + ref ClassNative classNative, + object thisInstance, + bool aggregateRefTracker, + ComWrappers cw, + delegate* CreateInstance) + { + bool isAggregation = typeof(T) != thisInstance.GetType(); + + { + IntPtr outer = default; + if (isAggregation) + { + // Create a managed object wrapper (i.e. CCW) to act as the outer. + // Passing the CreateComInterfaceFlags.TrackerSupport can be done if + // IReferenceTracker support is possible. + // + // The outer is now owned in this context. + outer = cw.GetOrCreateComInterfaceForObject(thisInstance, CreateComInterfaceFlags.TrackerSupport); + } + + // Create an instance of the COM/WinRT type. + // This is typically accomplished through a call to CoCreateInstance() or RoActivateInstance(). + // + // Ownership of the outer has been transferred to the new instance. + // Some APIs do return a non-null inner even with a null outer. This + // means ownership may now be owned in this context in either aggregation state. + classNative.Instance = CreateInstance(outer, out classNative.Inner); + } + + // TEST: Indicate if we should attempt aggregation with ReferenceTracker. + if (aggregateRefTracker) + { + // Determine if the instance supports IReferenceTracker (e.g. WinUI). + // Acquiring this interface is useful for: + // 1) Providing an indication of what value to pass during RCW creation. + // 2) Informing the Reference Tracker runtime during non-aggregation + // scenarios about new references. + // + // If aggregation, query the inner since that will have the implementation + // otherwise the new instance will be used. Since the inner was composed + // it should answer immediately without going through the outer. Either way + // the reference count will go to the new instance. + IntPtr queryForTracker = isAggregation ? classNative.Inner : classNative.Instance; + int hr = Marshal.QueryInterface(queryForTracker, ref IID_IReferenceTracker, out classNative.ReferenceTracker); + if (hr != 0) + { + classNative.ReferenceTracker = default; + } + } + + { + // Determine flags needed for native object wrapper (i.e. RCW) creation. + var createObjectFlags = CreateObjectFlags.None; + IntPtr instanceToWrap = classNative.Instance; + + // Update flags if the native instance is being used in an aggregation scenario. + if (isAggregation) + { + // Indicate the scenario is aggregation + createObjectFlags |= (CreateObjectFlags)4; + + // The instance supports IReferenceTracker. + if (classNative.ReferenceTracker != default(IntPtr)) + { + createObjectFlags |= CreateObjectFlags.TrackerObject; + + // IReferenceTracker is not needed in aggregation scenarios. + // It is not needed because all QueryInterface() calls on an + // object are followed by an immediately release of the returned + // pointer - see below for details. + Marshal.Release(classNative.ReferenceTracker); + + // .NET 5 limitation + // + // For aggregated scenarios involving IReferenceTracker + // the API handles object cleanup. In .NET 5 the API + // didn't expose an option to handle this so we pass the inner + // in order to handle its lifetime. + // + // The API doesn't handle inner lifetime in any other scenario + // in the .NET 5 timeframe. + instanceToWrap = classNative.Inner; + } + } + + // Create a native object wrapper (i.e. RCW). + // + // Note this function will call QueryInterface() on the supplied instance, + // therefore it is important that the enclosing CCW forwards to its inner + // if aggregation is involved. This is typically accomplished through an + // implementation of ICustomQueryInterface. + cw.GetOrRegisterObjectForComInstance(instanceToWrap, createObjectFlags, thisInstance); + } + + if (isAggregation) + { + // We release the instance here, but continue to use it since + // ownership was transferred to the API and it will guarantee + // the appropriate lifetime. + Marshal.Release(classNative.Instance); + } + else + { + // In non-aggregation scenarios where an inner exists and + // reference tracker is involved, we release the inner. + // + // .NET 5 limitation - see logic above. + if (classNative.Inner != default(IntPtr) && classNative.ReferenceTracker != default(IntPtr)) + { + Marshal.Release(classNative.Inner); + } + } + + // The following describes the valid local values to consider and details + // on their usage during the object's lifetime. + classNative.Release = ReleaseFlags.None; + if (isAggregation) + { + // Aggregation scenarios should avoid calling AddRef() on the + // newInstance value. This is due to the semantics of COM Aggregation + // and the fact that calling an AddRef() on the instance will increment + // the CCW which in turn will ensure it cannot be cleaned up. Calling + // AddRef() on the instance when passed to unmanagec code is correct + // since unmanaged code is required to call Release() at some point. + if (classNative.ReferenceTracker == default(IntPtr)) + { + // COM scenario + // The pointer to dispatch on for the instance. + // ** Never release. + classNative.Release |= ReleaseFlags.None; // Instance + + // A pointer to the inner that should be queried for + // additional interfaces. Immediately after a QueryInterface() + // a Release() should be called on the returned pointer but the + // pointer can be retained and used. + // ** Release in this class's Finalizer. + classNative.Release |= ReleaseFlags.Inner; // Inner + } + else + { + // WinUI scenario + // The pointer to dispatch on for the instance. + // ** Never release. + classNative.Release |= ReleaseFlags.None; // Instance + + // A pointer to the inner that should be queried for + // additional interfaces. Immediately after a QueryInterface() + // a Release() should be called on the returned pointer but the + // pointer can be retained and used. + // ** Never release. + classNative.Release |= ReleaseFlags.None; // Inner + + // No longer needed. + // ** Never release. + classNative.Release |= ReleaseFlags.None; // ReferenceTracker + } + } + else + { + if (classNative.ReferenceTracker == default(IntPtr)) + { + // COM scenario + // The pointer to dispatch on for the instance. + // ** Release in this class's Finalizer. + classNative.Release |= ReleaseFlags.Instance; // Instance + } + else + { + // WinUI scenario + // The pointer to dispatch on for the instance. + // ** Release in this class's Finalizer. + classNative.Release |= ReleaseFlags.Instance; // Instance + + // This instance should be used to tell the + // Reference Tracker runtime whenever an AddRef()/Release() + // is performed on newInstance. + // ** Release in this class's Finalizer. + classNative.Release |= ReleaseFlags.ReferenceTracker; // ReferenceTracker + } + } + } + + public static void Cleanup(ref ClassNative classNative) + { + if (classNative.Release.HasFlag(ReleaseFlags.Inner)) + { + Marshal.Release(classNative.Inner); + } + if (classNative.Release.HasFlag(ReleaseFlags.Instance)) + { + Marshal.Release(classNative.Instance); + } + if (classNative.Release.HasFlag(ReleaseFlags.ReferenceTracker)) + { + Marshal.Release(classNative.ReferenceTracker); + } + } } } diff --git a/src/tests/Interop/COM/ComWrappers/GlobalInstance/GlobalInstance.cs b/src/tests/Interop/COM/ComWrappers/GlobalInstance/GlobalInstance.cs index d6f60d4176117..899c4f9c74a8d 100644 --- a/src/tests/Interop/COM/ComWrappers/GlobalInstance/GlobalInstance.cs +++ b/src/tests/Interop/COM/ComWrappers/GlobalInstance/GlobalInstance.cs @@ -15,15 +15,15 @@ partial class Program { struct MarshalInterface { - [DllImport(nameof(MockReferenceTrackerRuntime), EntryPoint=nameof(MockReferenceTrackerRuntime.CreateTrackerObject))] + [DllImport(nameof(MockReferenceTrackerRuntime), EntryPoint="CreateTrackerObject_SkipTrackerRuntime")] [return: MarshalAs(UnmanagedType.IUnknown)] extern public static object CreateTrackerObjectAsIUnknown(); - [DllImport(nameof(MockReferenceTrackerRuntime), EntryPoint=nameof(MockReferenceTrackerRuntime.CreateTrackerObject))] + [DllImport(nameof(MockReferenceTrackerRuntime), EntryPoint="CreateTrackerObject_SkipTrackerRuntime")] [return: MarshalAs(UnmanagedType.Interface)] extern public static FakeWrapper CreateTrackerObjectAsInterface(); - [DllImport(nameof(MockReferenceTrackerRuntime), EntryPoint = nameof(MockReferenceTrackerRuntime.CreateTrackerObject))] + [DllImport(nameof(MockReferenceTrackerRuntime), EntryPoint="CreateTrackerObject_SkipTrackerRuntime")] [return: MarshalAs(UnmanagedType.Interface)] extern public static Test CreateTrackerObjectWrongType(); diff --git a/src/tests/Interop/COM/ComWrappers/MockReferenceTrackerRuntime/ReferenceTrackerRuntime.cpp b/src/tests/Interop/COM/ComWrappers/MockReferenceTrackerRuntime/ReferenceTrackerRuntime.cpp index efec0855d05b1..fd4ee2906effa 100644 --- a/src/tests/Interop/COM/ComWrappers/MockReferenceTrackerRuntime/ReferenceTrackerRuntime.cpp +++ b/src/tests/Interop/COM/ComWrappers/MockReferenceTrackerRuntime/ReferenceTrackerRuntime.cpp @@ -72,23 +72,45 @@ namespace STDMETHOD(DropObjectRef)(_In_ int id) = 0; }; - struct TrackerObject : public ITrackerObject, public API::IReferenceTracker, public UnknownImpl + struct TrackerObject : public IUnknown, public UnknownImpl { - const size_t _id; - std::atomic _trackerSourceCount; - bool _connected; - std::atomic _elementId; - std::unordered_map> _elements; + static std::atomic AllocationCount; - TrackerObject(size_t id) : _id{ id }, _trackerSourceCount{ 0 }, _connected{ false }, _elementId{ 1 } - { } + static const int32_t DisableTrackedCount = -1; + static const int32_t EnableTrackedCount = 0; + static std::atomic TrackedAllocationCount; - HRESULT ToggleTargets(_In_ bool shouldPeg) + TrackerObject(_In_ size_t id, _In_opt_ IUnknown* pUnkOuter) + : _outer{ pUnkOuter == nullptr ? static_cast(this) : pUnkOuter } + , _impl{ id, _outer } + { + ++AllocationCount; + + if (TrackedAllocationCount != DisableTrackedCount) + ++TrackedAllocationCount; + } + + ~TrackerObject() + { + // There is a cleanup race when tracking is enabled. + // It is possible previously allocated objects could be + // cleaned up during alloc tracking scenarios - these can be + // ignored. + // + // See the locking around the tracking scenarios in the + // managed P/Invoke usage. + if (TrackedAllocationCount > 0) + --TrackedAllocationCount; + + --AllocationCount; + } + + HRESULT TogglePeg(_In_ bool shouldPeg) { HRESULT hr; - auto curr = std::begin(_elements); - while (curr != std::end(_elements)) + auto curr = std::begin(_impl._elements); + while (curr != std::end(_impl._elements)) { ComSmartPtr mowMaybe; if (S_OK == curr->second->QueryInterface(&mowMaybe)) @@ -105,69 +127,189 @@ namespace ++curr; } + // Handle the case for aggregation + // + // Pegging occurs during a GC. We can't QI for this during + // a GC because the COM scenario would fallback to + // ICustomQueryInterface (i.e. managed code). + if (_impl._outerRefTrackerTarget) + { + ComSmartPtr thisTgtMaybe; + if (S_OK == _outer->QueryInterface(&thisTgtMaybe)) + { + if (shouldPeg) + { + RETURN_IF_FAILED(thisTgtMaybe->Peg()); + } + else + { + RETURN_IF_FAILED(thisTgtMaybe->Unpeg()); + } + } + } + return S_OK; } - STDMETHOD(AddObjectRef)(_In_ IUnknown* c, _Out_ int* id) + HRESULT DisconnectFromReferenceTrackerRuntime() { - assert(c != nullptr && id != nullptr); + HRESULT hr; - try - { - *id = _elementId; - if (!_elements.insert(std::make_pair(*id, ComSmartPtr{ c })).second) - return S_FALSE; + RETURN_IF_FAILED(TogglePeg(/* should peg */ false)); - _elementId++; - } - catch (const std::bad_alloc&) + // Handle the case for aggregation in the release case. + if (_impl._outerRefTrackerTarget) { - return E_OUTOFMEMORY; + ComSmartPtr thisTgtMaybe; + if (S_OK == _outer->QueryInterface(&thisTgtMaybe)) + RETURN_IF_FAILED(thisTgtMaybe->ReleaseFromReferenceTracker()); } - ComSmartPtr mowMaybe; - if (S_OK == c->QueryInterface(&mowMaybe)) - (void)mowMaybe->AddRefFromReferenceTracker(); - return S_OK; } - STDMETHOD(DropObjectRef)(_In_ int id) + struct TrackerObjectImpl : public ITrackerObject, public API::IReferenceTracker { - auto iter = _elements.find(id); - if (iter == std::end(_elements)) - return S_FALSE; + IUnknown* _implOuter; + bool _outerRefTrackerTarget; + const size_t _id; + std::atomic _trackerSourceCount; + bool _connected; + std::atomic _elementId; + std::unordered_map> _elements; + + TrackerObjectImpl(_In_ size_t id, _In_ IUnknown* pUnkOuter) + : _implOuter{ pUnkOuter } + , _outerRefTrackerTarget{ false } + , _id{ id } + , _trackerSourceCount{ 0 } + , _connected{ false } + , _elementId{ 1 } + { + // Check if we are aggregating with a tracker target + ComSmartPtr tgt; + if (SUCCEEDED(_implOuter->QueryInterface(&tgt))) + { + _outerRefTrackerTarget = true; + (void)tgt->AddRefFromReferenceTracker(); + if (FAILED(tgt->Peg())) + { + throw std::exception{ "Peg failure" }; + } + } + } - ComSmartPtr mowMaybe; - if (S_OK == iter->second->QueryInterface(&mowMaybe)) + STDMETHOD(AddObjectRef)(_In_ IUnknown* c, _Out_ int* id) { - (void)mowMaybe->ReleaseFromReferenceTracker(); - (void)mowMaybe->Unpeg(); + assert(c != nullptr && id != nullptr); + + try + { + *id = _elementId; + if (!_elements.insert(std::make_pair(*id, ComSmartPtr{ c })).second) + return S_FALSE; + + _elementId++; + } + catch (const std::bad_alloc&) + { + return E_OUTOFMEMORY; + } + + ComSmartPtr mowMaybe; + if (S_OK == c->QueryInterface(&mowMaybe)) + (void)mowMaybe->AddRefFromReferenceTracker(); + + return S_OK; } - _elements.erase(iter); + STDMETHOD(DropObjectRef)(_In_ int id) + { + auto iter = _elements.find(id); + if (iter == std::end(_elements)) + return S_FALSE; - return S_OK; - } + ComSmartPtr mowMaybe; + if (S_OK == iter->second->QueryInterface(&mowMaybe)) + { + (void)mowMaybe->ReleaseFromReferenceTracker(); + (void)mowMaybe->Unpeg(); + } + + _elements.erase(iter); - STDMETHOD(ConnectFromTrackerSource)(); - STDMETHOD(DisconnectFromTrackerSource)(); - STDMETHOD(FindTrackerTargets)(_In_ API::IFindReferenceTargetsCallback* pCallback); - STDMETHOD(GetReferenceTrackerManager)(_Outptr_ API::IReferenceTrackerManager** ppTrackerManager); - STDMETHOD(AddRefFromTrackerSource)(); - STDMETHOD(ReleaseFromTrackerSource)(); - STDMETHOD(PegFromTrackerSource)(); + return S_OK; + } + + STDMETHOD(ConnectFromTrackerSource)(); + STDMETHOD(DisconnectFromTrackerSource)(); + STDMETHOD(FindTrackerTargets)(_In_ API::IFindReferenceTargetsCallback* pCallback); + STDMETHOD(GetReferenceTrackerManager)(_Outptr_ API::IReferenceTrackerManager** ppTrackerManager); + STDMETHOD(AddRefFromTrackerSource)(); + STDMETHOD(ReleaseFromTrackerSource)(); + STDMETHOD(PegFromTrackerSource)(); + + STDMETHOD(QueryInterface)( + /* [in] */ REFIID riid, + /* [iid_is][out] */ _COM_Outptr_ void __RPC_FAR* __RPC_FAR* ppvObject) + { + return _implOuter->QueryInterface(riid, ppvObject); + } + STDMETHOD_(ULONG, AddRef)(void) + { + return _implOuter->AddRef(); + } + STDMETHOD_(ULONG, Release)(void) + { + return _implOuter->Release(); + } + }; STDMETHOD(QueryInterface)( /* [in] */ REFIID riid, /* [iid_is][out] */ _COM_Outptr_ void __RPC_FAR* __RPC_FAR* ppvObject) { - return DoQueryInterface(riid, ppvObject, static_cast(this), static_cast(this)); + if (ppvObject == nullptr) + return E_POINTER; + + IUnknown* tgt; + + // Aggregation implementation. + if (riid == IID_IUnknown) + { + tgt = static_cast(this); + } + else + { + // Send non-IUnknown queries to the implementation. + if (riid == __uuidof(API::IReferenceTracker)) + { + tgt = static_cast(&_impl); + } + else if (riid == __uuidof(ITrackerObject)) + { + tgt = static_cast(&_impl); + } + else + { + *ppvObject = nullptr; + return E_NOINTERFACE; + } + } + + (void)tgt->AddRef(); + *ppvObject = tgt; + return S_OK; } - DEFINE_REF_COUNTING() + DEFINE_REF_COUNTING(); + + IUnknown* _outer; + TrackerObjectImpl _impl; }; + std::atomic TrackerObject::AllocationCount{}; + std::atomic TrackerObject::TrackedAllocationCount{ TrackerObject::DisableTrackedCount }; std::atomic CurrentObjectId{}; class TrackerRuntimeManagerImpl : public API::IReferenceTrackerManager @@ -176,16 +318,29 @@ namespace std::list> _objects; public: - void RecordObject(_In_ TrackerObject* obj) + ITrackerObject* RecordObject(_In_ TrackerObject* obj, _Outptr_ IUnknown** inner) { _objects.push_back(ComSmartPtr{ obj }); if (_runtimeServices != nullptr) _runtimeServices->AddMemoryPressure(sizeof(TrackerObject)); + + // Perform a QI to get the proper identity. + (void)obj->QueryInterface(IID_IUnknown, (void**)inner); + + // Get the default interface. + ITrackerObject* type; + (void)obj->QueryInterface(__uuidof(ITrackerObject), (void**)&type); + + return type; } void ReleaseObjects() { + // Unpeg all instances + for (auto& i : _objects) + (void)i->DisconnectFromReferenceTrackerRuntime(); + size_t count = _objects.size(); _objects.clear(); if (_runtimeServices != nullptr) @@ -205,7 +360,7 @@ namespace { // Unpeg all instances for (auto& i : _objects) - i->ToggleTargets(/* should peg */ false); + i->TogglePeg(/* should peg */ false); return S_OK; } @@ -214,7 +369,7 @@ namespace { // Verify and ensure all connected types are pegged for (auto& i : _objects) - i->ToggleTargets(/* should peg */ true); + i->TogglePeg(/* should peg */ true); return S_OK; } @@ -262,19 +417,19 @@ namespace TrackerRuntimeManagerImpl TrackerRuntimeManager; - HRESULT STDMETHODCALLTYPE TrackerObject::ConnectFromTrackerSource() + HRESULT STDMETHODCALLTYPE TrackerObject::TrackerObjectImpl::ConnectFromTrackerSource() { _connected = true; return S_OK; } - HRESULT STDMETHODCALLTYPE TrackerObject::DisconnectFromTrackerSource() + HRESULT STDMETHODCALLTYPE TrackerObject::TrackerObjectImpl::DisconnectFromTrackerSource() { _connected = false; return S_OK; } - HRESULT STDMETHODCALLTYPE TrackerObject::FindTrackerTargets(_In_ API::IFindReferenceTargetsCallback* pCallback) + HRESULT STDMETHODCALLTYPE TrackerObject::TrackerObjectImpl::FindTrackerTargets(_In_ API::IFindReferenceTargetsCallback* pCallback) { assert(pCallback != nullptr); @@ -291,27 +446,27 @@ namespace return S_OK; } - HRESULT STDMETHODCALLTYPE TrackerObject::GetReferenceTrackerManager(_Outptr_ API::IReferenceTrackerManager** ppTrackerManager) + HRESULT STDMETHODCALLTYPE TrackerObject::TrackerObjectImpl::GetReferenceTrackerManager(_Outptr_ API::IReferenceTrackerManager** ppTrackerManager) { assert(ppTrackerManager != nullptr); return TrackerRuntimeManager.QueryInterface(__uuidof(API::IReferenceTrackerManager), (void**)ppTrackerManager); } - HRESULT STDMETHODCALLTYPE TrackerObject::AddRefFromTrackerSource() + HRESULT STDMETHODCALLTYPE TrackerObject::TrackerObjectImpl::AddRefFromTrackerSource() { assert(0 <= _trackerSourceCount); ++_trackerSourceCount; return S_OK; } - HRESULT STDMETHODCALLTYPE TrackerObject::ReleaseFromTrackerSource() + HRESULT STDMETHODCALLTYPE TrackerObject::TrackerObjectImpl::ReleaseFromTrackerSource() { assert(0 < _trackerSourceCount); --_trackerSourceCount; return S_OK; } - HRESULT STDMETHODCALLTYPE TrackerObject::PegFromTrackerSource() + HRESULT STDMETHODCALLTYPE TrackerObject::TrackerObjectImpl::PegFromTrackerSource() { /* Not used by runtime */ return E_NOTIMPL; @@ -319,13 +474,30 @@ namespace } // Create external object -extern "C" DLL_EXPORT ITrackerObject* STDMETHODCALLTYPE CreateTrackerObject() +extern "C" DLL_EXPORT ITrackerObject * STDMETHODCALLTYPE CreateTrackerObject_SkipTrackerRuntime() +{ + auto obj = new TrackerObject{ static_cast(-1), nullptr }; + return &obj->_impl; +} + +extern "C" DLL_EXPORT ITrackerObject* STDMETHODCALLTYPE CreateTrackerObject_Unsafe(_In_opt_ IUnknown* outer, _Outptr_ IUnknown** inner) { - auto obj = new TrackerObject{ CurrentObjectId++ }; + ComSmartPtr obj; + obj.Attach(new TrackerObject{ CurrentObjectId++, outer }); - TrackerRuntimeManager.RecordObject(obj); + return TrackerRuntimeManager.RecordObject(obj, inner); +} - return obj; +extern "C" DLL_EXPORT void STDMETHODCALLTYPE StartTrackerObjectAllocationCount_Unsafe() +{ + TrackerObject::TrackedAllocationCount = TrackerObject::EnableTrackedCount; +} + +extern "C" DLL_EXPORT int32_t STDMETHODCALLTYPE StopTrackerObjectAllocationCount_Unsafe() +{ + int32_t count = TrackerObject::TrackedAllocationCount; + TrackerObject::TrackedAllocationCount = TrackerObject::DisableTrackedCount; + return count; } // Release the reference on all internally held tracker objects @@ -346,7 +518,7 @@ extern "C" DLL_EXPORT int STDMETHODCALLTYPE UpdateTestObjectAsIUnknown(IUnknown HRESULT hr; ComSmartPtr testObj; - RETURN_IF_FAILED(obj->QueryInterface(&testObj)) + RETURN_IF_FAILED(obj->QueryInterface(&testObj)); RETURN_IF_FAILED(testObj->SetValue(i)); *out = testObj.Detach();