Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Special casing System.Guid for COM VARIANT marshalling #100377

Merged
merged 11 commits into from
Apr 5, 2024
33 changes: 25 additions & 8 deletions src/coreclr/vm/olevariant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2567,17 +2567,34 @@ void OleVariant::MarshalRecordVariantOleToCom(VARIANT *pOleVariant,
if (!pRecInfo)
COMPlusThrow(kArgumentException, IDS_EE_INVALID_OLE_VARIANT);

LPVOID pvRecord = V_RECORD(pOleVariant);
if (pvRecord == NULL)
{
pComVariant->SetObjRef(NULL);
return;
}

MethodTable* pValueClass = NULL;
{
GCX_PREEMP();
pValueClass = GetMethodTableForRecordInfo(pRecInfo);
}

if (pValueClass == NULL)
{
// This value type should have been registered through
// a TLB. CoreCLR doesn't support dynamic type mapping.
COMPlusThrow(kArgumentException, IDS_EE_CANNOT_MAP_TO_MANAGED_VC);
}
_ASSERTE(pValueClass->IsBlittable());

OBJECTREF BoxedValueClass = NULL;
GCPROTECT_BEGIN(BoxedValueClass)
{
LPVOID pvRecord = V_RECORD(pOleVariant);
if (pvRecord)
{
// This value type should have been registered through
// a TLB. CoreCLR doesn't support dynamic type mapping.
COMPlusThrow(kArgumentException, IDS_EE_CANNOT_MAP_TO_MANAGED_VC);
}

// Now that we have a blittable value class, allocate an instance of the
// boxed value class and copy the contents of the record into it.
BoxedValueClass = AllocateObject(pValueClass);
memcpyNoGCRefs(BoxedValueClass->GetData(), (BYTE*)pvRecord, pValueClass->GetNativeSize());
pComVariant->SetObjRef(BoxedValueClass);
}
GCPROTECT_END();
Expand Down
98 changes: 98 additions & 0 deletions src/coreclr/vm/stdinterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,44 @@ HRESULT GetITypeLibForAssembly(_In_ Assembly *pAssembly, _Outptr_ ITypeLib **ppT
return S_OK;
} // HRESULT GetITypeLibForAssembly()

// .NET Frameworks' mscorlib TLB GUID.
static const GUID s_MscorlibGuid = { 0xBED7F4EA, 0x1A96, 0x11D2, { 0x8F, 0x08, 0x00, 0xA0, 0xC9, 0xA6, 0x18, 0x6D } };

// Hard-coded GUID for System.Guid.
static const GUID s_GuidForSystemGuid = { 0x9C5923E9, 0xDE52, 0x33EA, { 0x88, 0xDE, 0x7E, 0xBC, 0x86, 0x33, 0xB9, 0xCC } };

// There are types that are helpful to provide that facilitate porting from
// .NET Framework to .NET 8+. This function is used to acquire their ITypeInfo.
// This should be used narrowly. Types at a minimum should be blittable.
static bool TryDeferToMscorlib(MethodTable* pClass, ITypeInfo** ppTI)
{
CONTRACTL
{
THROWS;
GC_TRIGGERS;
MODE_PREEMPTIVE;
PRECONDITION(pClass != NULL);
PRECONDITION(pClass->IsBlittable());
PRECONDITION(ppTI != NULL);
}
CONTRACTL_END;

// Marshalling of a System.Guid is such a common scenario, let's see if we can load
// the .NET Framework's TLB. This is a niche scenario, but one that impacts many teams
// porting code to .NET 8+.
if (pClass == CoreLibBinder::GetClass(CLASS__GUID))
{
SafeComHolder<ITypeLib> pMscorlibTypeLib = NULL;
if (SUCCEEDED(::LoadRegTypeLib(s_MscorlibGuid, 2, 4, 0, &pMscorlibTypeLib)))
{
if (SUCCEEDED(pMscorlibTypeLib->GetTypeInfoOfGuid(s_GuidForSystemGuid, ppTI)))
return true;
}
}

return false;
}

HRESULT GetITypeInfoForEEClass(MethodTable *pClass, ITypeInfo **ppTI, bool bClassInfo)
{
CONTRACTL
Expand All @@ -625,6 +663,7 @@ HRESULT GetITypeInfoForEEClass(MethodTable *pClass, ITypeInfo **ppTI, bool bClas
GUID clsid;
GUID ciid;
ComMethodTable *pComMT = NULL;
MethodTable* pOriginalClass = pClass;
HRESULT hr = S_OK;
SafeComHolder<ITypeLib> pITLB = NULL;
SafeComHolder<ITypeInfo> pTI = NULL;
Expand Down Expand Up @@ -770,12 +809,71 @@ HRESULT GetITypeInfoForEEClass(MethodTable *pClass, ITypeInfo **ppTI, bool bClas
{
if (!FAILED(hr))
hr = E_FAIL;

if (pOriginalClass->IsValueType() && pOriginalClass->IsBlittable())
{
if (TryDeferToMscorlib(pOriginalClass, ppTI))
hr = S_OK;
}
}

ReturnHR:
return hr;
} // HRESULT GetITypeInfoForEEClass()

MethodTable* GetMethodTableForRecordInfo(IRecordInfo* recInfo)
{
CONTRACTL
{
THROWS;
GC_TRIGGERS;
MODE_PREEMPTIVE;
PRECONDITION(recInfo != NULL);
}
CONTRACTL_END;

HRESULT hr;

//
// Only a narrow set of types are supported.
// See TryDeferToMscorlib() above.
//

// Verify the associated TypeLib attribute
SafeComHolder<ITypeInfo> typeInfo;
hr = recInfo->GetTypeInfo(&typeInfo);
if (FAILED(hr))
return NULL;

SafeComHolder<ITypeLib> typeLib;
UINT index;
hr = typeInfo->GetContainingTypeLib(&typeLib, &index);
if (FAILED(hr))
return NULL;

TLIBATTR* attrs;
hr = typeLib->GetLibAttr(&attrs);
if (FAILED(hr))
return NULL;

GUID libGuid = attrs->guid;
typeLib->ReleaseTLibAttr(attrs);
if (s_MscorlibGuid != libGuid)
return NULL;

// Verify the Guid of the associated type
GUID typeGuid;
hr = recInfo->GetGuid(&typeGuid);
if (FAILED(hr))
return NULL;

// Check for supported types.
if (s_GuidForSystemGuid == typeGuid)
return CoreLibBinder::GetClass(CLASS__GUID);

return NULL;
}

// Returns a NON-ADDREF'd ITypeInfo.
HRESULT GetITypeInfoForMT(ComMethodTable *pMT, ITypeInfo **ppTI)
{
Expand Down
3 changes: 3 additions & 0 deletions src/coreclr/vm/stdinterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,4 +183,7 @@ IErrorInfo *GetSupportedErrorInfo(IUnknown *iface, REFIID riid);
// Helpers to get the ITypeInfo* for a type.
HRESULT GetITypeInfoForEEClass(MethodTable *pMT, ITypeInfo **ppTI, bool bClassInfo = false);

// Gets the MethodTable for the associated IRecordInfo.
MethodTable* GetMethodTableForRecordInfo(IRecordInfo* recInfo);

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -246,14 +246,38 @@ public void GetObjectForNativeVariant_InvalidDate_ThrowsArgumentException(double
}

[ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsBuiltInComEnabled))]
public void GetObjectForNativeVariant_NoDataForRecord_ThrowsArgumentException()
public void GetObjectForNativeVariant_NoRecordInfo_ThrowsArgumentException()
{
Variant variant = CreateVariant(VT_RECORD, new UnionTypes { _record = new Record { _recordInfo = IntPtr.Zero } });
AssertExtensions.Throws<ArgumentException>(null, () => GetObjectForNativeVariant(variant));
}

[ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsBuiltInComEnabled))]
public void GetObjectForNativeVariant_NoRecordData_ReturnsNull()
{
var recordInfo = new RecordInfo();
IntPtr pRecordInfo = Marshal.GetComInterfaceForObject<RecordInfo, IRecordInfo>(recordInfo);
try
{
Variant variant = CreateVariant(VT_RECORD, new UnionTypes
{
_record = new Record
{
_record = IntPtr.Zero,
_recordInfo = pRecordInfo
}
});
Assert.Null(GetObjectForNativeVariant(variant));
}
finally
{
Marshal.Release(pRecordInfo);
}
}

public static IEnumerable<object[]> GetObjectForNativeVariant_NoSuchGuid_TestData()
{
yield return new object[] { typeof(object).GUID };
yield return new object[] { typeof(string).GUID };
yield return new object[] { Guid.Empty };
}
Expand Down
1 change: 1 addition & 0 deletions src/tests/Interop/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ if(CLR_CMAKE_TARGET_WIN32)
add_subdirectory(COM/NativeClients/DefaultInterfaces)
add_subdirectory(COM/NativeClients/Dispatch)
add_subdirectory(COM/NativeClients/Events)
add_subdirectory(COM/NativeClients/MiscTypes)
add_subdirectory(COM/ComWrappers/MockReferenceTrackerRuntime)
add_subdirectory(COM/ComWrappers/WeakReference)

Expand Down
18 changes: 18 additions & 0 deletions src/tests/Interop/COM/NETClients/MiscTypes/App.manifest
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
<?xml version="1.0" encoding="utf-8"?>
<assembly manifestVersion="1.0" xmlns="urn:schemas-microsoft-com:asm.v1">
<assemblyIdentity
type="win32"
name="NetClientMiscTypes"
version="1.0.0.0" />

<dependency>
<dependentAssembly>
<!-- RegFree COM -->
<assemblyIdentity
type="win32"
name="COMNativeServer.X"
version="1.0.0.0"/>
</dependentAssembly>
</dependency>

</assembly>
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<!-- Needed for CMakeProjectReference, GC.WaitForPendingFinalizers -->
<RequiresProcessIsolation>true</RequiresProcessIsolation>
<ApplicationManifest>App.manifest</ApplicationManifest>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
</PropertyGroup>
<ItemGroup>
<Compile Include="Program.cs" />
<Compile Include="../../ServerContracts/Server.CoClasses.cs" />
<Compile Include="../../ServerContracts/Server.Contracts.cs" />
<Compile Include="../../ServerContracts/ServerGuids.cs" />
</ItemGroup>
<ItemGroup>
<CMakeProjectReference Include="../../NativeServer/CMakeLists.txt" />
<ProjectReference Include="$(TestLibraryProjectPath)" />
</ItemGroup>
</Project>
90 changes: 90 additions & 0 deletions src/tests/Interop/COM/NETClients/MiscTypes/Program.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using Xunit;
namespace NetClient
{
using System;
using System.Runtime.InteropServices;

using TestLibrary;
using Xunit;
using Server.Contract;
using Server.Contract.Servers;

public unsafe class Program
{
[Fact]
public static int TestEntryPoint()
{
// RegFree COM is not supported on Windows Nano
if (TestLibrary.Utilities.IsWindowsNanoServer)
{
return 100;
}

try
{
RunTests();
}
catch (Exception e)
{
Console.WriteLine($"Test object interop failure: {e}");
return 101;
}

return 100;
}

private static void RunTests()
{
var miscTypeTesting = (Server.Contract.Servers.MiscTypesTesting)new Server.Contract.Servers.MiscTypesTestingClass();

Console.WriteLine("Validate Primitives <=> VARIANT...");
{
object expected = null;
Assert.Equal(expected, miscTypeTesting.Marshal_Variant(expected));
}
{
var expected = DBNull.Value;
Assert.Equal(expected, miscTypeTesting.Marshal_Variant(expected));
}
{
var expected = (sbyte)0x0f;
Assert.Equal(expected, miscTypeTesting.Marshal_Variant(expected));
}
{
var expected = (short)0x07ff;
Assert.Equal(expected, miscTypeTesting.Marshal_Variant(expected));
}
{
var expected = (int)0x07ffffff;
Assert.Equal(expected, miscTypeTesting.Marshal_Variant(expected));
}
{
var expected = (long)0x07ffffffffffffff;
Assert.Equal(expected, miscTypeTesting.Marshal_Variant(expected));
}
{
var expected = true;
Assert.Equal(expected, miscTypeTesting.Marshal_Variant(expected));
}
{
var expected = false;
Assert.Equal(expected, miscTypeTesting.Marshal_Variant(expected));
}

Console.WriteLine("Validate BSTR <=> VARIANT...");
{
var expected = "The quick Fox jumped over the lazy Dog.";
Assert.Equal(expected, miscTypeTesting.Marshal_Variant(expected));
}

Console.WriteLine("Validate System.Guid <=> VARIANT...");
{
var expected = new Guid("{8EFAD956-B33D-46CB-90F4-45F55BA68A96}");
Assert.Equal(expected, miscTypeTesting.Marshal_Variant(expected));
}
}
}
}
Loading
Loading