Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update UnmanagedCallersOnlyAttribute API surface #37843

Merged
merged 8 commits into from
Jun 23, 2020
2 changes: 1 addition & 1 deletion src/coreclr/src/vm/customattribute.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ class Attribute
CaNamedArgArrayREF* ppCustomAttributeNamedArguments,
AssemblyBaseObject* pAssemblyUNSAFE);

private:
static HRESULT ParseAttributeArgumentValues(
void* pCa,
INT32 cCa,
Expand All @@ -116,6 +115,7 @@ class Attribute
COUNT_T cNamedArgs,
DomainAssembly* pDomainAssembly);

private:
static HRESULT ParseCaValue(
CustomAttributeParser &ca,
CaValue* pCaArg,
Expand Down
94 changes: 78 additions & 16 deletions src/coreclr/src/vm/dllimportcallback.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "appdomain.inl"
#include "callingconvention.h"
#include "customattribute.h"
#include "typeparse.h"

#ifndef CROSSGEN_COMPILE

Expand Down Expand Up @@ -626,32 +627,93 @@ VOID UMThunkMarshInfo::SetUpForUnmanagedCallersOnly()

BYTE* pData = NULL;
LONG cData = 0;
CorPinvokeMap callConv = (CorPinvokeMap)0;

bool nativeCallableInternalData = false;
HRESULT hr = pMD->GetCustomAttribute(WellKnownAttribute::UnmanagedCallersOnly, (const VOID **)(&pData), (ULONG *)&cData);
if (hr == S_FALSE)
{
hr = pMD->GetCustomAttribute(WellKnownAttribute::NativeCallableInternal, (const VOID **)(&pData), (ULONG *)&cData);
nativeCallableInternalData = SUCCEEDED(hr);
}

IfFailThrow(hr);

_ASSERTE(cData > 0);

CustomAttributeParser ca(pData, cData);

// UnmanagedCallersOnly has two optional named arguments CallingConvention and EntryPoint.
CaNamedArg namedArgs[2];
CaTypeCtor caType(SERIALIZATION_TYPE_STRING);
// First, the void constructor.
IfFailThrow(ParseKnownCaArgs(ca, NULL, 0));

// Now the optional named properties
namedArgs[0].InitI4FieldEnum("CallingConvention", "System.Runtime.InteropServices.CallingConvention", (ULONG)callConv);
namedArgs[1].Init("EntryPoint", SERIALIZATION_TYPE_STRING, caType);
IfFailThrow(ParseKnownCaNamedArgs(ca, namedArgs, lengthof(namedArgs)));

callConv = (CorPinvokeMap)(namedArgs[0].val.u4 << 8);
// Let UMThunkMarshalInfo choose the default if calling convension not definied.
if (namedArgs[0].val.type.tag != SERIALIZATION_TYPE_UNDEFINED)
m_callConv = (UINT16)callConv;

// For the UnmanagedCallersOnly scenario.
CaType caCallConvs;

// Define the optional named properties
if (nativeCallableInternalData)
{
namedArgs[0].InitI4FieldEnum("CallingConvention", "System.Runtime.InteropServices.CallingConvention", (ULONG)(CorPinvokeMap)0);
}
else
{
caCallConvs.Init(SERIALIZATION_TYPE_SZARRAY, SERIALIZATION_TYPE_TYPE, SERIALIZATION_TYPE_UNDEFINED, NULL, 0);
namedArgs[0].Init("CallConvs", SERIALIZATION_TYPE_SZARRAY, caCallConvs);
}

CaTypeCtor caEntryPoint(SERIALIZATION_TYPE_STRING);
AaronRobinsonMSFT marked this conversation as resolved.
Show resolved Hide resolved
namedArgs[1].Init("EntryPoint", SERIALIZATION_TYPE_STRING, caEntryPoint);

InlineFactory<SArray<CaValue>, 4> caValueArrayFactory;
DomainAssembly* domainAssembly = pMD->GetLoaderModule()->GetDomainAssembly();
IfFailThrow(Attribute::ParseAttributeArgumentValues(
pData,
cData,
&caValueArrayFactory,
NULL,
0,
namedArgs,
lengthof(namedArgs),
domainAssembly));

// If the value isn't defined, then return without setting anything.
if (namedArgs[0].val.type.tag == SERIALIZATION_TYPE_UNDEFINED)
return;

CorPinvokeMap callConvLocal = (CorPinvokeMap)0;
if (nativeCallableInternalData)
{
callConvLocal = (CorPinvokeMap)(namedArgs[0].val.u4 << 8);
}
else
{
// Set WinAPI as the default
callConvLocal = CorPinvokeMap::pmCallConvWinapi;

CaValue* arrayOfTypes = &namedArgs[0].val;
for (ULONG i = 0; i < arrayOfTypes->arr.length; i++)
{
CaValue& typeNameValue = arrayOfTypes->arr[i];

StackSString typeFQName(SString::Utf8, typeNameValue.str.pStr, typeNameValue.str.cbStr);
if (typeFQName.BeginsWith(W("System.Runtime.CompilerServices.CallConvCdecl")))
AaronRobinsonMSFT marked this conversation as resolved.
Show resolved Hide resolved
{
callConvLocal = CorPinvokeMap::pmCallConvCdecl;
}
else if (typeFQName.BeginsWith(W("System.Runtime.CompilerServices.CallConvStdcall")))
{
callConvLocal = CorPinvokeMap::pmCallConvStdcall;
}
else if (typeFQName.BeginsWith(W("System.Runtime.CompilerServices.CallConvFastcall")))
{
callConvLocal = CorPinvokeMap::pmCallConvFastcall;
}
else if (typeFQName.BeginsWith(W("System.Runtime.CompilerServices.CallConvThiscall")))
{
callConvLocal = CorPinvokeMap::pmCallConvThiscall;
}
AaronRobinsonMSFT marked this conversation as resolved.
Show resolved Hide resolved
}
}

m_callConv = (UINT16)callConvLocal;
}

// Compiles an unmanaged to managed thunk for the given signature.
Expand Down Expand Up @@ -776,7 +838,7 @@ Stub *UMThunkMarshInfo::CompileNExportThunk(LoaderHeap *pLoaderHeap, PInvokeStat
stubInfo.m_cbSrcStack = static_cast<UINT16>(m_cbActualArgSize);
stubInfo.m_cbDstStack = nStackBytes;

if (pSigInfo->GetCallConv() == pmCallConvCdecl)
if (m_callConv == pmCallConvCdecl)
{
// caller pop
m_cbRetPop = 0;
Expand All @@ -786,7 +848,7 @@ Stub *UMThunkMarshInfo::CompileNExportThunk(LoaderHeap *pLoaderHeap, PInvokeStat
// callee pop
m_cbRetPop = static_cast<UINT16>(m_cbActualArgSize);

if (pSigInfo->GetCallConv() == pmCallConvThiscall)
if (m_callConv == pmCallConvThiscall)
{
stubInfo.m_wFlags |= umtmlThisCall;
if (argit.HasRetBuffArg())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ .locals init (native int V_0)
testNativeMethod();
}

[UnmanagedCallersOnly(CallingConvention = CallingConvention.StdCall)]
[UnmanagedCallersOnly(CallConvs = new[] { typeof(CallConvStdcall) })]
public static int CallbackViaUnmanagedCalli(int val)
{
return DoubleImpl(val);
Expand Down Expand Up @@ -587,7 +587,7 @@ .locals init (native int V_0)
Assert.AreEqual(expected, testNativeMethod());
}

[UnmanagedCallersOnly(CallingConvention = CallingConvention.StdCall)]
[UnmanagedCallersOnly(CallConvs = new[] { typeof(CallConvStdcall) })]
public static int CallbackViaUnmanagedCalliThrows(int val)
{
throw new Exception() { HResult = CallbackThrowsErrorCode };
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ private struct EnumData
}

// EnumCalendarInfoExEx callback itself.
// [UnmanagedCallersOnly(CallingConvention = CallingConvention.StdCall)]
// [UnmanagedCallersOnly(CallConvs = new[] { typeof(CallConvStdcall) })]
private static unsafe Interop.BOOL EnumCalendarInfoCallback(char* lpCalendarInfoString, uint calendar, IntPtr pReserved, void* lParam)
{
ref EnumData context = ref Unsafe.As<byte, EnumData>(ref *(byte*)lParam);
Expand Down Expand Up @@ -425,7 +425,7 @@ public struct NlsEnumCalendarsData
public List<int> calendars; // list of calendars found so far
}

// [UnmanagedCallersOnly(CallingConvention = CallingConvention.StdCall)]
// [UnmanagedCallersOnly(CallConvs = new[] { typeof(CallConvStdcall) })]
private static unsafe Interop.BOOL EnumCalendarsCallback(char* lpCalendarInfoString, uint calendar, IntPtr reserved, void* lParam)
{
ref NlsEnumCalendarsData context = ref Unsafe.As<byte, NlsEnumCalendarsData>(ref *(byte*)lParam);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ private struct EnumLocaleData
}

// EnumSystemLocaleEx callback.
// [UnmanagedCallersOnly(CallingConvention = CallingConvention.StdCall)]
// [UnmanagedCallersOnly(CallConvs = new[] { typeof(CallConvStdcall) })]
private static unsafe Interop.BOOL EnumSystemLocalesProc(char* lpLocaleString, uint flags, void* contextHandle)
{
ref EnumLocaleData context = ref Unsafe.As<byte, EnumLocaleData>(ref *(byte*)contextHandle);
Expand All @@ -382,7 +382,7 @@ private static unsafe Interop.BOOL EnumSystemLocalesProc(char* lpLocaleString, u
}

// EnumSystemLocaleEx callback.
// [UnmanagedCallersOnly(CallingConvention = CallingConvention.StdCall)]
// [UnmanagedCallersOnly(CallConvs = new[] { typeof(CallConvStdcall) })]
private static unsafe Interop.BOOL EnumAllSystemLocalesProc(char* lpLocaleString, uint flags, void* contextHandle)
{
ref EnumData context = ref Unsafe.As<byte, EnumData>(ref *(byte*)contextHandle);
Expand All @@ -404,7 +404,7 @@ private struct EnumData
}

// EnumTimeFormatsEx callback itself.
// [UnmanagedCallersOnly(CallingConvention = CallingConvention.StdCall)]
// [UnmanagedCallersOnly(CallConvs = new[] { typeof(CallConvStdcall) })]
private static unsafe Interop.BOOL EnumTimeCallback(char* lpTimeFormatString, void* lParam)
{
ref EnumData context = ref Unsafe.As<byte, EnumData>(ref *(byte*)lParam);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

namespace System.Runtime.InteropServices
{
// Used for the CallingConvention named argument to the DllImport and UnmanagedCallersOnly attribute
// Used for the CallingConvention named argument to the DllImport attribute
public enum CallingConvention
{
Winapi = 1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@ public UnmanagedCallersOnlyAttribute()
/// <summary>
/// Optional. If omitted, the runtime will use the default platform calling convention.
/// </summary>
public CallingConvention CallingConvention;
/// <remarks>
/// Supplied types must be from the official "System.Runtime.CompilerServices" namespace and
/// be of the form "CallConvXXX".
/// </remarks>
public Type[]? CallConvs;

/// <summary>
/// Optional. If omitted, no named export is emitted during compilation.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1010,7 +1010,7 @@ public static void RegisterForMarshalling(ComWrappers instance) { }
public sealed class UnmanagedCallersOnlyAttribute : System.Attribute
{
public UnmanagedCallersOnlyAttribute() { }
public System.Runtime.InteropServices.CallingConvention CallingConvention;
public System.Type[]? CallConvs;
public string? EntryPoint;
}
}
Expand Down