Skip to content

Commit

Permalink
Keep target DllImport info in structured data before converting to sy…
Browse files Browse the repository at this point in the history
…ntax. (#1075)
  • Loading branch information
jkoritzinsky authored May 7, 2021
1 parent b7688d5 commit 2f89e87
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 85 deletions.
213 changes: 128 additions & 85 deletions DllImportGenerator/DllImportGenerator/DllImportGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -93,16 +93,16 @@ public void Execute(GeneratorExecutionContext context)
continue;

// Process the GeneratedDllImport attribute
DllImportStub.GeneratedDllImportData dllImportData;
AttributeSyntax dllImportAttr = this.ProcessGeneratedDllImportAttribute(methodSymbolInfo, generatedDllImportAttr, context.AnalyzerConfigOptions.GlobalOptions.GenerateForwarders(), out dllImportData);
Debug.Assert((dllImportAttr is not null) && (dllImportData is not null));
DllImportStub.GeneratedDllImportData stubDllImportData = this.ProcessGeneratedDllImportAttribute(generatedDllImportAttr);
Debug.Assert(stubDllImportData is not null);
AttributeSyntax dllImportAttr = this.CreateDllImportAttributeForTarget(stubDllImportData!, env.Options.GenerateForwarders(), methodSymbolInfo.Name);

if (dllImportData!.IsUserDefined.HasFlag(DllImportStub.DllImportMember.BestFitMapping))
if (stubDllImportData!.IsUserDefined.HasFlag(DllImportStub.DllImportMember.BestFitMapping))
{
generatorDiagnostics.ReportConfigurationNotSupported(generatedDllImportAttr, nameof(DllImportStub.GeneratedDllImportData.BestFitMapping));
}

if (dllImportData!.IsUserDefined.HasFlag(DllImportStub.DllImportMember.ThrowOnUnmappableChar))
if (stubDllImportData!.IsUserDefined.HasFlag(DllImportStub.DllImportMember.ThrowOnUnmappableChar))
{
generatorDiagnostics.ReportConfigurationNotSupported(generatedDllImportAttr, nameof(DllImportStub.GeneratedDllImportData.ThrowOnUnmappableChar));
}
Expand All @@ -114,7 +114,7 @@ public void Execute(GeneratorExecutionContext context)
}

// Create the stub.
var dllImportStub = DllImportStub.Create(methodSymbolInfo, dllImportData!, env, generatorDiagnostics, context.CancellationToken);
var dllImportStub = DllImportStub.Create(methodSymbolInfo, stubDllImportData!, env, generatorDiagnostics, context.CancellationToken);

PrintGeneratedSource(generatedDllImports, methodSyntax, dllImportStub, dllImportAttr!);
}
Expand Down Expand Up @@ -229,13 +229,9 @@ private static bool IsGeneratedDllImportAttribute(AttributeSyntax attrSyntaxMayb
|| attrName.EndsWith(PrefixedGeneratedDllImportAttribute);
}

private AttributeSyntax ProcessGeneratedDllImportAttribute(
IMethodSymbol method,
AttributeData attrData,
bool generateForwarders,
out DllImportStub.GeneratedDllImportData dllImportData)
private DllImportStub.GeneratedDllImportData ProcessGeneratedDllImportAttribute(AttributeData attrData)
{
dllImportData = new DllImportStub.GeneratedDllImportData();
var stubDllImportData = new DllImportStub.GeneratedDllImportData();

// Found the GeneratedDllImport, but it has an error so report the error.
// This is most likely an issue with targeting an incorrect TFM.
Expand All @@ -245,138 +241,185 @@ private AttributeSyntax ProcessGeneratedDllImportAttribute(
throw new InvalidProgramException();
}

var newAttributeArgs = new List<AttributeArgumentSyntax>();

// Populate the DllImport data from the GeneratedDllImportAttribute attribute.
dllImportData.ModuleName = attrData.ConstructorArguments[0].Value!.ToString();

newAttributeArgs.Add(SyntaxFactory.AttributeArgument(SyntaxFactory.LiteralExpression(
SyntaxKind.StringLiteralExpression,
SyntaxFactory.Literal(dllImportData.ModuleName))));
stubDllImportData.ModuleName = attrData.ConstructorArguments[0].Value!.ToString();

// All other data on attribute is defined as NamedArguments.
foreach (var namedArg in attrData.NamedArguments)
{
ExpressionSyntax? expSyntaxMaybe = null;
switch (namedArg.Key)
{
default:
Debug.Fail($"An unknown member was found on {GeneratedDllImport}");
continue;
case nameof(DllImportStub.GeneratedDllImportData.BestFitMapping):
dllImportData.BestFitMapping = (bool)namedArg.Value.Value!;
expSyntaxMaybe = CreateBoolExpressionSyntax(dllImportData.BestFitMapping);
dllImportData.IsUserDefined |= DllImportStub.DllImportMember.BestFitMapping;
stubDllImportData.BestFitMapping = (bool)namedArg.Value.Value!;
stubDllImportData.IsUserDefined |= DllImportStub.DllImportMember.BestFitMapping;
break;
case nameof(DllImportStub.GeneratedDllImportData.CallingConvention):
dllImportData.CallingConvention = (CallingConvention)namedArg.Value.Value!;
expSyntaxMaybe = CreateEnumExpressionSyntax(dllImportData.CallingConvention);
dllImportData.IsUserDefined |= DllImportStub.DllImportMember.CallingConvention;
stubDllImportData.CallingConvention = (CallingConvention)namedArg.Value.Value!;
stubDllImportData.IsUserDefined |= DllImportStub.DllImportMember.CallingConvention;
break;
case nameof(DllImportStub.GeneratedDllImportData.CharSet):
dllImportData.CharSet = (CharSet)namedArg.Value.Value!;
expSyntaxMaybe = CreateEnumExpressionSyntax(dllImportData.CharSet);
dllImportData.IsUserDefined |= DllImportStub.DllImportMember.CharSet;
stubDllImportData.CharSet = (CharSet)namedArg.Value.Value!;
stubDllImportData.IsUserDefined |= DllImportStub.DllImportMember.CharSet;
break;
case nameof(DllImportStub.GeneratedDllImportData.EntryPoint):
dllImportData.EntryPoint = (string)namedArg.Value.Value!;
expSyntaxMaybe = CreateStringExpressionSyntax(dllImportData.EntryPoint!);
dllImportData.IsUserDefined |= DllImportStub.DllImportMember.EntryPoint;
stubDllImportData.EntryPoint = (string)namedArg.Value.Value!;
stubDllImportData.IsUserDefined |= DllImportStub.DllImportMember.EntryPoint;
break;
case nameof(DllImportStub.GeneratedDllImportData.ExactSpelling):
dllImportData.ExactSpelling = (bool)namedArg.Value.Value!;
expSyntaxMaybe = CreateBoolExpressionSyntax(dllImportData.ExactSpelling);
dllImportData.IsUserDefined |= DllImportStub.DllImportMember.ExactSpelling;
stubDllImportData.ExactSpelling = (bool)namedArg.Value.Value!;
stubDllImportData.IsUserDefined |= DllImportStub.DllImportMember.ExactSpelling;
break;
case nameof(DllImportStub.GeneratedDllImportData.PreserveSig):
dllImportData.PreserveSig = (bool)namedArg.Value.Value!;
expSyntaxMaybe = CreateBoolExpressionSyntax(dllImportData.PreserveSig);
dllImportData.IsUserDefined |= DllImportStub.DllImportMember.PreserveSig;
stubDllImportData.PreserveSig = (bool)namedArg.Value.Value!;
stubDllImportData.IsUserDefined |= DllImportStub.DllImportMember.PreserveSig;
break;
case nameof(DllImportStub.GeneratedDllImportData.SetLastError):
dllImportData.SetLastError = (bool)namedArg.Value.Value!;
expSyntaxMaybe = CreateBoolExpressionSyntax(dllImportData.SetLastError);
dllImportData.IsUserDefined |= DllImportStub.DllImportMember.SetLastError;
stubDllImportData.SetLastError = (bool)namedArg.Value.Value!;
stubDllImportData.IsUserDefined |= DllImportStub.DllImportMember.SetLastError;
break;
case nameof(DllImportStub.GeneratedDllImportData.ThrowOnUnmappableChar):
dllImportData.ThrowOnUnmappableChar = (bool)namedArg.Value.Value!;
expSyntaxMaybe = CreateBoolExpressionSyntax(dllImportData.ThrowOnUnmappableChar);
dllImportData.IsUserDefined |= DllImportStub.DllImportMember.ThrowOnUnmappableChar;
stubDllImportData.ThrowOnUnmappableChar = (bool)namedArg.Value.Value!;
stubDllImportData.IsUserDefined |= DllImportStub.DllImportMember.ThrowOnUnmappableChar;
break;
}
}

Debug.Assert(expSyntaxMaybe is not null);
return stubDllImportData;
}

// If we're generating a forwarder stub, then all parameters on the GenerateDllImport attribute
// must also be added to the generated DllImport attribute.
if (generateForwarders || PassThroughToDllImportAttribute(namedArg.Key))
{
// Defer the name equals syntax till we know the value means something. If we created
// an expression we know the key value was valid.
NameEqualsSyntax nameSyntax = SyntaxFactory.NameEquals(namedArg.Key);
newAttributeArgs.Add(SyntaxFactory.AttributeArgument(nameSyntax, null, expSyntaxMaybe!));
}
}
private AttributeSyntax CreateDllImportAttributeForTarget(DllImportStub.GeneratedDllImportData stubDllImportData, bool generateForwarders, string originalMethodName)
{
DllImportStub.GeneratedDllImportData targetDllImportData =
GetTargetDllImportDataFromStubData(stubDllImportData, generateForwarders, originalMethodName);

// If the EntryPoint property is not set, we will compute and
// add it based on existing semantics (i.e. method name).
//
// N.B. The export discovery logic is identical regardless of where
// the name is defined (i.e. method name vs EntryPoint property).
if (!dllImportData.IsUserDefined.HasFlag(DllImportStub.DllImportMember.EntryPoint))
var newAttributeArgs = new List<AttributeArgumentSyntax>
{
var entryPointName = SyntaxFactory.NameEquals(nameof(DllImportAttribute.EntryPoint));
AttributeArgument(LiteralExpression(
SyntaxKind.StringLiteralExpression,
Literal(targetDllImportData.ModuleName))),
AttributeArgument(
NameEquals(nameof(DllImportAttribute.EntryPoint)),
null,
CreateStringExpressionSyntax(targetDllImportData.EntryPoint))
};

// The name of the method is the entry point name to use.
var entryPointValue = CreateStringExpressionSyntax(method.Name);
newAttributeArgs.Add(SyntaxFactory.AttributeArgument(entryPointName, null, entryPointValue));
if (targetDllImportData.IsUserDefined.HasFlag(DllImportStub.DllImportMember.BestFitMapping))
{
var name = NameEquals(nameof(DllImportAttribute.BestFitMapping));
var value = CreateBoolExpressionSyntax(targetDllImportData.BestFitMapping);
newAttributeArgs.Add(AttributeArgument(name, null, value));
}
if (targetDllImportData.IsUserDefined.HasFlag(DllImportStub.DllImportMember.CallingConvention))
{
var name = NameEquals(nameof(DllImportAttribute.CallingConvention));
var value = CreateEnumExpressionSyntax(targetDllImportData.CallingConvention);
newAttributeArgs.Add(AttributeArgument(name, null, value));
}
if (targetDllImportData.IsUserDefined.HasFlag(DllImportStub.DllImportMember.CharSet))
{
var name = NameEquals(nameof(DllImportAttribute.CharSet));
var value = CreateEnumExpressionSyntax(targetDllImportData.CharSet);
newAttributeArgs.Add(AttributeArgument(name, null, value));
}
if (targetDllImportData.IsUserDefined.HasFlag(DllImportStub.DllImportMember.ExactSpelling))
{
var name = NameEquals(nameof(DllImportAttribute.ExactSpelling));
var value = CreateBoolExpressionSyntax(targetDllImportData.ExactSpelling);
newAttributeArgs.Add(AttributeArgument(name, null, value));
}
if (targetDllImportData.IsUserDefined.HasFlag(DllImportStub.DllImportMember.PreserveSig))
{
var name = NameEquals(nameof(DllImportAttribute.PreserveSig));
var value = CreateBoolExpressionSyntax(targetDllImportData.PreserveSig);
newAttributeArgs.Add(AttributeArgument(name, null, value));
}
if (targetDllImportData.IsUserDefined.HasFlag(DllImportStub.DllImportMember.SetLastError))
{
var name = NameEquals(nameof(DllImportAttribute.SetLastError));
var value = CreateBoolExpressionSyntax(targetDllImportData.SetLastError);
newAttributeArgs.Add(AttributeArgument(name, null, value));
}
if (targetDllImportData.IsUserDefined.HasFlag(DllImportStub.DllImportMember.ThrowOnUnmappableChar))
{
var name = NameEquals(nameof(DllImportAttribute.ThrowOnUnmappableChar));
var value = CreateBoolExpressionSyntax(targetDllImportData.ThrowOnUnmappableChar);
newAttributeArgs.Add(AttributeArgument(name, null, value));
}

// Create new attribute
return SyntaxFactory.Attribute(
SyntaxFactory.ParseName(typeof(DllImportAttribute).FullName),
SyntaxFactory.AttributeArgumentList(SyntaxFactory.SeparatedList(newAttributeArgs)));
return Attribute(
ParseName(typeof(DllImportAttribute).FullName),
AttributeArgumentList(SeparatedList(newAttributeArgs)));

static ExpressionSyntax CreateBoolExpressionSyntax(bool trueOrFalse)
{
return SyntaxFactory.LiteralExpression(
return LiteralExpression(
trueOrFalse
? SyntaxKind.TrueLiteralExpression
: SyntaxKind.FalseLiteralExpression);
}

static ExpressionSyntax CreateStringExpressionSyntax(string str)
{
return SyntaxFactory.LiteralExpression(
return LiteralExpression(
SyntaxKind.StringLiteralExpression,
SyntaxFactory.Literal(str));
Literal(str));
}

static ExpressionSyntax CreateEnumExpressionSyntax<T>(T value) where T : Enum
{
return SyntaxFactory.MemberAccessExpression(
return MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
SyntaxFactory.IdentifierName(typeof(T).FullName),
SyntaxFactory.IdentifierName(value.ToString()));
IdentifierName(typeof(T).FullName),
IdentifierName(value.ToString()));
}

static bool PassThroughToDllImportAttribute(string argName)
static DllImportStub.GeneratedDllImportData GetTargetDllImportDataFromStubData(DllImportStub.GeneratedDllImportData stubDllImportData, bool generateForwarders, string originalMethodName)
{
// Certain fields on DllImport will prevent inlining. Their functionality should be handled by the
// generated source, so the generated DllImport declaration should not include these fields.
return argName switch
DllImportStub.DllImportMember membersToForward = DllImportStub.DllImportMember.All
// https://docs.microsoft.com/dotnet/api/system.runtime.interopservices.dllimportattribute.preservesig
// If PreserveSig=false (default is true), the P/Invoke stub checks/converts a returned HRESULT to an exception.
& ~DllImportStub.DllImportMember.PreserveSig
// https://docs.microsoft.com/dotnet/api/system.runtime.interopservices.dllimportattribute.setlasterror
// If SetLastError=true (default is false), the P/Invoke stub gets/caches the last error after invoking the native function.
& ~DllImportStub.DllImportMember.SetLastError;
if (generateForwarders)
{
// https://docs.microsoft.com/dotnet/api/system.runtime.interopservices.dllimportattribute.preservesig
// If PreserveSig=false (default is true), the P/Invoke stub checks/converts a returned HRESULT to an exception.
nameof(DllImportStub.GeneratedDllImportData.PreserveSig) => false,
// https://docs.microsoft.com/dotnet/api/system.runtime.interopservices.dllimportattribute.setlasterror
// If SetLastError=true (default is false), the P/Invoke stub gets/caches the last error after invoking the native function.
nameof(DllImportStub.GeneratedDllImportData.SetLastError) => false,
_ => true
membersToForward = DllImportStub.DllImportMember.All;
}

var targetDllImportData = new DllImportStub.GeneratedDllImportData
{
CharSet = stubDllImportData.CharSet,
BestFitMapping = stubDllImportData.BestFitMapping,
CallingConvention = stubDllImportData.CallingConvention,
EntryPoint = stubDllImportData.EntryPoint,
ModuleName = stubDllImportData.ModuleName,
ExactSpelling = stubDllImportData.ExactSpelling,
SetLastError = stubDllImportData.SetLastError,
PreserveSig = stubDllImportData.PreserveSig,
ThrowOnUnmappableChar = stubDllImportData.ThrowOnUnmappableChar,
IsUserDefined = stubDllImportData.IsUserDefined & membersToForward
};

// If the EntryPoint property is not set, we will compute and
// add it based on existing semantics (i.e. method name).
//
// N.B. The export discovery logic is identical regardless of where
// the name is defined (i.e. method name vs EntryPoint property).
if (!targetDllImportData.IsUserDefined.HasFlag(DllImportStub.DllImportMember.EntryPoint))
{
targetDllImportData.EntryPoint = originalMethodName;
}

return targetDllImportData;
}
}


private class SyntaxReceiver : ISyntaxReceiver
{
public ICollection<SyntaxReference> Methods { get; } = new List<SyntaxReference>();
Expand Down
1 change: 1 addition & 0 deletions DllImportGenerator/DllImportGenerator/DllImportStub.cs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ public enum DllImportMember
PreserveSig = 1 << 5,
SetLastError = 1 << 6,
ThrowOnUnmappableChar = 1 << 7,
All = ~None
}

/// <summary>
Expand Down

0 comments on commit 2f89e87

Please sign in to comment.