Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update UnmanagedCallersOnlyAttribute API surface #37843

Merged
merged 8 commits into from
Jun 23, 2020
2 changes: 1 addition & 1 deletion src/coreclr/src/vm/customattribute.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ class Attribute
CaNamedArgArrayREF* ppCustomAttributeNamedArguments,
AssemblyBaseObject* pAssemblyUNSAFE);

private:
static HRESULT ParseAttributeArgumentValues(
void* pCa,
INT32 cCa,
Expand All @@ -116,6 +115,7 @@ class Attribute
COUNT_T cNamedArgs,
DomainAssembly* pDomainAssembly);

private:
static HRESULT ParseCaValue(
CustomAttributeParser &ca,
CaValue* pCaArg,
Expand Down
117 changes: 99 additions & 18 deletions src/coreclr/src/vm/dllimportcallback.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "appdomain.inl"
#include "callingconvention.h"
#include "customattribute.h"
#include "typeparse.h"

#ifndef CROSSGEN_COMPILE

Expand Down Expand Up @@ -614,44 +615,124 @@ VOID UMEntryThunk::CompileUMThunkWorker(UMThunkStubInfo *pInfo,
pcpusl->X86EmitNearJump(pEnableRejoin);
}

namespace
{
// Templated function to compute if a char string begins with a constant string.
template<size_t S2LEN>
bool BeginsWith(ULONG s1Len, const char* s1, const char (&s2)[S2LEN])
{
WRAPPER_NO_CONTRACT;

ULONG s2Len = (ULONG)S2LEN - 1; // Remove null
ULONG minLen = s1Len < s2Len ? s1Len : s2Len;
AaronRobinsonMSFT marked this conversation as resolved.
Show resolved Hide resolved
return (0 == strncmp(s1, s2, minLen));
}
}

VOID UMThunkMarshInfo::SetUpForUnmanagedCallersOnly()
{
STANDARD_VM_CONTRACT;

MethodDesc* pMD = GetMethod();
_ASSERTE(pMD != NULL && pMD->HasUnmanagedCallersOnlyAttribute());

// Validate UnmanagedCallersOnlyAttribute usage
// Validate usage
COMDelegate::ThrowIfInvalidUnmanagedCallersOnlyUsage(pMD);

BYTE* pData = NULL;
LONG cData = 0;
CorPinvokeMap callConv = (CorPinvokeMap)0;

bool nativeCallableInternalData = false;
HRESULT hr = pMD->GetCustomAttribute(WellKnownAttribute::UnmanagedCallersOnly, (const VOID **)(&pData), (ULONG *)&cData);
if (hr == S_FALSE)
{
hr = pMD->GetCustomAttribute(WellKnownAttribute::NativeCallableInternal, (const VOID **)(&pData), (ULONG *)&cData);
nativeCallableInternalData = SUCCEEDED(hr);
}

IfFailThrow(hr);

_ASSERTE(cData > 0);

CustomAttributeParser ca(pData, cData);
// UnmanagedCallersOnly has two optional named arguments CallingConvention and EntryPoint.

// UnmanagedCallersOnly and NativeCallableInternal each
// have optional named arguments.
CaNamedArg namedArgs[2];
CaTypeCtor caType(SERIALIZATION_TYPE_STRING);
// First, the void constructor.
IfFailThrow(ParseKnownCaArgs(ca, NULL, 0));

// Now the optional named properties
namedArgs[0].InitI4FieldEnum("CallingConvention", "System.Runtime.InteropServices.CallingConvention", (ULONG)callConv);
namedArgs[1].Init("EntryPoint", SERIALIZATION_TYPE_STRING, caType);
IfFailThrow(ParseKnownCaNamedArgs(ca, namedArgs, lengthof(namedArgs)));

callConv = (CorPinvokeMap)(namedArgs[0].val.u4 << 8);
// Let UMThunkMarshalInfo choose the default if calling convension not definied.
if (namedArgs[0].val.type.tag != SERIALIZATION_TYPE_UNDEFINED)
m_callConv = (UINT16)callConv;

// For the UnmanagedCallersOnly scenario.
CaType caCallConvs;

// Define attribute specific optional named properties
if (nativeCallableInternalData)
{
namedArgs[0].InitI4FieldEnum("CallingConvention", "System.Runtime.InteropServices.CallingConvention", (ULONG)(CorPinvokeMap)0);
}
else
{
caCallConvs.Init(SERIALIZATION_TYPE_SZARRAY, SERIALIZATION_TYPE_TYPE, SERIALIZATION_TYPE_UNDEFINED, NULL, 0);
namedArgs[0].Init("CallConvs", SERIALIZATION_TYPE_SZARRAY, caCallConvs);
}

// Define common optional named properties
CaTypeCtor caEntryPoint(SERIALIZATION_TYPE_STRING);
AaronRobinsonMSFT marked this conversation as resolved.
Show resolved Hide resolved
namedArgs[1].Init("EntryPoint", SERIALIZATION_TYPE_STRING, caEntryPoint);

InlineFactory<SArray<CaValue>, 4> caValueArrayFactory;
DomainAssembly* domainAssembly = pMD->GetLoaderModule()->GetDomainAssembly();
IfFailThrow(Attribute::ParseAttributeArgumentValues(
pData,
cData,
&caValueArrayFactory,
NULL,
0,
namedArgs,
lengthof(namedArgs),
domainAssembly));

// If the value isn't defined, then return without setting anything.
if (namedArgs[0].val.type.tag == SERIALIZATION_TYPE_UNDEFINED)
return;

CorPinvokeMap callConvLocal = (CorPinvokeMap)0;
if (nativeCallableInternalData)
{
callConvLocal = (CorPinvokeMap)(namedArgs[0].val.u4 << 8);
}
else
{
// Set WinAPI as the default
callConvLocal = CorPinvokeMap::pmCallConvWinapi;

CaValue* arrayOfTypes = &namedArgs[0].val;
for (ULONG i = 0; i < arrayOfTypes->arr.length; i++)
{
CaValue& typeNameValue = arrayOfTypes->arr[i];

// According to ECMA-335, type name strings are UTF-8. Since we are
// looking for type names that are equivalent in ASCII and UTF-8,
// using a const char constant is acceptable. Type name strings are
// in Fully Qualified form, so we include the ',' delimiter.
if (BeginsWith(typeNameValue.str.cbStr, typeNameValue.str.pStr, "System.Runtime.CompilerServices.CallConvCdecl,"))
{
callConvLocal = CorPinvokeMap::pmCallConvCdecl;
}
else if (BeginsWith(typeNameValue.str.cbStr, typeNameValue.str.pStr, "System.Runtime.CompilerServices.CallConvStdcall,"))
{
callConvLocal = CorPinvokeMap::pmCallConvStdcall;
}
else if (BeginsWith(typeNameValue.str.cbStr, typeNameValue.str.pStr, "System.Runtime.CompilerServices.CallConvFastcall,"))
{
callConvLocal = CorPinvokeMap::pmCallConvFastcall;
}
else if (BeginsWith(typeNameValue.str.cbStr, typeNameValue.str.pStr, "System.Runtime.CompilerServices.CallConvThiscall,"))
{
callConvLocal = CorPinvokeMap::pmCallConvThiscall;
}
AaronRobinsonMSFT marked this conversation as resolved.
Show resolved Hide resolved
}
}

m_callConv = (UINT16)callConvLocal;
}

// Compiles an unmanaged to managed thunk for the given signature.
Expand Down Expand Up @@ -776,7 +857,7 @@ Stub *UMThunkMarshInfo::CompileNExportThunk(LoaderHeap *pLoaderHeap, PInvokeStat
stubInfo.m_cbSrcStack = static_cast<UINT16>(m_cbActualArgSize);
stubInfo.m_cbDstStack = nStackBytes;

if (pSigInfo->GetCallConv() == pmCallConvCdecl)
if (m_callConv == pmCallConvCdecl)
{
// caller pop
m_cbRetPop = 0;
Expand All @@ -786,7 +867,7 @@ Stub *UMThunkMarshInfo::CompileNExportThunk(LoaderHeap *pLoaderHeap, PInvokeStat
// callee pop
m_cbRetPop = static_cast<UINT16>(m_cbActualArgSize);

if (pSigInfo->GetCallConv() == pmCallConvThiscall)
if (m_callConv == pmCallConvThiscall)
{
stubInfo.m_wFlags |= umtmlThisCall;
if (argit.HasRetBuffArg())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ .locals init (native int V_0)
testNativeMethod();
}

[UnmanagedCallersOnly(CallingConvention = CallingConvention.StdCall)]
[UnmanagedCallersOnly(CallConvs = new[] { typeof(CallConvStdcall) })]
public static int CallbackViaUnmanagedCalli(int val)
{
return DoubleImpl(val);
Expand Down Expand Up @@ -587,7 +587,7 @@ .locals init (native int V_0)
Assert.AreEqual(expected, testNativeMethod());
}

[UnmanagedCallersOnly(CallingConvention = CallingConvention.StdCall)]
[UnmanagedCallersOnly(CallConvs = new[] { typeof(CallConvStdcall) })]
public static int CallbackViaUnmanagedCalliThrows(int val)
{
throw new Exception() { HResult = CallbackThrowsErrorCode };
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ private struct EnumData
}

// EnumCalendarInfoExEx callback itself.
// [UnmanagedCallersOnly(CallingConvention = CallingConvention.StdCall)]
// [UnmanagedCallersOnly(CallConvs = new[] { typeof(CallConvStdcall) })]
private static unsafe Interop.BOOL EnumCalendarInfoCallback(char* lpCalendarInfoString, uint calendar, IntPtr pReserved, void* lParam)
{
ref EnumData context = ref Unsafe.As<byte, EnumData>(ref *(byte*)lParam);
Expand Down Expand Up @@ -425,7 +425,7 @@ public struct NlsEnumCalendarsData
public List<int> calendars; // list of calendars found so far
}

// [UnmanagedCallersOnly(CallingConvention = CallingConvention.StdCall)]
// [UnmanagedCallersOnly(CallConvs = new[] { typeof(CallConvStdcall) })]
private static unsafe Interop.BOOL EnumCalendarsCallback(char* lpCalendarInfoString, uint calendar, IntPtr reserved, void* lParam)
{
ref NlsEnumCalendarsData context = ref Unsafe.As<byte, NlsEnumCalendarsData>(ref *(byte*)lParam);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ private struct EnumLocaleData
}

// EnumSystemLocaleEx callback.
// [UnmanagedCallersOnly(CallingConvention = CallingConvention.StdCall)]
// [UnmanagedCallersOnly(CallConvs = new[] { typeof(CallConvStdcall) })]
private static unsafe Interop.BOOL EnumSystemLocalesProc(char* lpLocaleString, uint flags, void* contextHandle)
{
ref EnumLocaleData context = ref Unsafe.As<byte, EnumLocaleData>(ref *(byte*)contextHandle);
Expand All @@ -382,7 +382,7 @@ private static unsafe Interop.BOOL EnumSystemLocalesProc(char* lpLocaleString, u
}

// EnumSystemLocaleEx callback.
// [UnmanagedCallersOnly(CallingConvention = CallingConvention.StdCall)]
// [UnmanagedCallersOnly(CallConvs = new[] { typeof(CallConvStdcall) })]
private static unsafe Interop.BOOL EnumAllSystemLocalesProc(char* lpLocaleString, uint flags, void* contextHandle)
{
ref EnumData context = ref Unsafe.As<byte, EnumData>(ref *(byte*)contextHandle);
Expand All @@ -404,7 +404,7 @@ private struct EnumData
}

// EnumTimeFormatsEx callback itself.
// [UnmanagedCallersOnly(CallingConvention = CallingConvention.StdCall)]
// [UnmanagedCallersOnly(CallConvs = new[] { typeof(CallConvStdcall) })]
private static unsafe Interop.BOOL EnumTimeCallback(char* lpTimeFormatString, void* lParam)
{
ref EnumData context = ref Unsafe.As<byte, EnumData>(ref *(byte*)lParam);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

namespace System.Runtime.InteropServices
{
// Used for the CallingConvention named argument to the DllImport and UnmanagedCallersOnly attribute
// Used for the CallingConvention named argument to the DllImport attribute
public enum CallingConvention
{
Winapi = 1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ namespace System.Runtime.InteropServices
/// * Must not be called from managed code.
/// * Must only have <see href="https://docs.microsoft.com/dotnet/framework/interop/blittable-and-non-blittable-types">blittable</see> arguments.
/// </remarks>
[AttributeUsage(AttributeTargets.Method)]
[AttributeUsage(AttributeTargets.Method, Inherited = false)]
public sealed class UnmanagedCallersOnlyAttribute : Attribute
{
public UnmanagedCallersOnlyAttribute()
Expand All @@ -25,7 +25,11 @@ public UnmanagedCallersOnlyAttribute()
/// <summary>
/// Optional. If omitted, the runtime will use the default platform calling convention.
/// </summary>
public CallingConvention CallingConvention;
/// <remarks>
/// Supplied types must be from the official "System.Runtime.CompilerServices" namespace and
/// be of the form "CallConvXXX".
/// </remarks>
public Type[]? CallConvs;

/// <summary>
/// Optional. If omitted, no named export is emitted during compilation.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1006,11 +1006,11 @@ public static void RegisterForTrackerSupport(ComWrappers instance) { }
public static void RegisterForMarshalling(ComWrappers instance) { }
protected static void GetIUnknownImpl(out System.IntPtr fpQueryInterface, out System.IntPtr fpAddRef, out System.IntPtr fpRelease) { throw null; }
}
[System.AttributeUsageAttribute(System.AttributeTargets.Method)]
[System.AttributeUsageAttribute(System.AttributeTargets.Method, Inherited = false)]
public sealed class UnmanagedCallersOnlyAttribute : System.Attribute
{
public UnmanagedCallersOnlyAttribute() { }
public System.Runtime.InteropServices.CallingConvention CallingConvention;
public System.Type[]? CallConvs;
public string? EntryPoint;
}
}
Expand Down