Skip to content

Commit

Permalink
Update SafeHandle codegen to match the approved API. (#570)
Browse files Browse the repository at this point in the history
  • Loading branch information
jkoritzinsky authored Jan 20, 2021
1 parent dd81061 commit 0c64a2a
Show file tree
Hide file tree
Showing 9 changed files with 89 additions and 54 deletions.
22 changes: 1 addition & 21 deletions DllImportGenerator/Ancillary.Interop/MarshalEx.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,32 +10,12 @@ namespace System.Runtime.InteropServices
/// </summary>
public static class MarshalEx
{
/// <summary>
/// Create an instance of the given <typeparamref name="TSafeHandle"/>.
/// </summary>
/// <typeparam name="TSafeHandle">Type of the SafeHandle</typeparam>
/// <returns>New instance of <typeparamref name="TSafeHandle"/></returns>
/// <remarks>
/// The <typeparamref name="TSafeHandle"/> must be non-abstract and have a parameterless constructor.
/// </remarks>
public static TSafeHandle CreateSafeHandle<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor | DynamicallyAccessedMemberTypes.NonPublicConstructors)]TSafeHandle>()
where TSafeHandle : SafeHandle
{
if (typeof(TSafeHandle).IsAbstract || typeof(TSafeHandle).GetConstructor(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.CreateInstance | BindingFlags.Instance, null, Type.EmptyTypes, null) == null)
{
throw new MissingMemberException($"The safe handle type '{typeof(TSafeHandle).FullName}' must be a non-abstract type with a parameterless constructor.");
}

TSafeHandle safeHandle = (TSafeHandle)Activator.CreateInstance(typeof(TSafeHandle), nonPublic: true)!;
return safeHandle;
}

/// <summary>
/// Sets the handle of <paramref name="safeHandle"/> to the specified <paramref name="handle"/>.
/// </summary>
/// <param name="safeHandle"><see cref="SafeHandle"/> instance to update</param>
/// <param name="handle">Pre-existing handle</param>
public static void SetHandle(SafeHandle safeHandle, IntPtr handle)
public static void InitHandle(SafeHandle safeHandle, IntPtr handle)
{
typeof(SafeHandle).GetMethod("SetHandle", BindingFlags.NonPublic | BindingFlags.Instance)!.Invoke(safeHandle, new object[] { handle });
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ namespace DllImportGenerator.IntegrationTests
{
partial class NativeExportsNE
{
public class NativeExportsSafeHandle : SafeHandleZeroOrMinusOneIsInvalid
public partial class NativeExportsSafeHandle : SafeHandleZeroOrMinusOneIsInvalid
{
private NativeExportsSafeHandle() : base(ownsHandle: true)
{ }
Expand All @@ -18,6 +18,12 @@ protected override bool ReleaseHandle()
Assert.True(didRelease);
return didRelease;
}

public static NativeExportsSafeHandle CreateNewHandle() => AllocateHandle();


[GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = "alloc_handle")]
private static partial NativeExportsSafeHandle AllocateHandle();
}

[GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = "alloc_handle")]
Expand Down Expand Up @@ -48,6 +54,14 @@ public void ReturnValue_CreatesSafeHandle()
Assert.False(handle.IsInvalid);
}

[Fact]
public void ReturnValue_CreatesSafeHandle_DirectConstructorCall()
{
using NativeExportsNE.NativeExportsSafeHandle handle = NativeExportsNE.NativeExportsSafeHandle.CreateNewHandle();
Assert.False(handle.IsClosed);
Assert.False(handle.IsInvalid);
}

[Fact]
public void ByValue_CorrectlyUnwrapsHandle()
{
Expand Down
10 changes: 10 additions & 0 deletions DllImportGenerator/DllImportGenerator.UnitTests/CodeSnippets.cs
Original file line number Diff line number Diff line change
Expand Up @@ -919,5 +919,15 @@ public IntStructWrapperNative(IntStructWrapper managed)
public IntStructWrapper ToManaged() => new IntStructWrapper { Value = value };
}
";

public static string SafeHandleWithCustomDefaultConstructorAccessibility(bool privateCtor) => BasicParametersAndModifiers("MySafeHandle") + $@"
class MySafeHandle : SafeHandle
{{
{(privateCtor ? "private" : "public")} MySafeHandle() : base(System.IntPtr.Zero, true) {{ }}
public override bool IsInvalid => handle == System.IntPtr.Zero;
protected override bool ReleaseHandle() => true;
}}";
}
}
2 changes: 2 additions & 0 deletions DllImportGenerator/DllImportGenerator.UnitTests/Compiles.cs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ public static IEnumerable<object[]> CodeSnippetsToCompile()

// SafeHandle
yield return new[] { CodeSnippets.BasicParametersAndModifiers("Microsoft.Win32.SafeHandles.SafeFileHandle") };
yield return new[] { CodeSnippets.SafeHandleWithCustomDefaultConstructorAccessibility(privateCtor: false) };
yield return new[] { CodeSnippets.SafeHandleWithCustomDefaultConstructorAccessibility(privateCtor: true) };

// PreserveSig
yield return new[] { CodeSnippets.PreserveSigFalseVoidReturn };
Expand Down
9 changes: 5 additions & 4 deletions DllImportGenerator/DllImportGenerator/DllImportStub.cs
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,10 @@ public static DllImportStub Create(
// Since we're generating source for the method, we know that the current type
// has to be declared in source.
TypeDeclarationSyntax typeDecl = (TypeDeclarationSyntax)currType.DeclaringSyntaxReferences[0].GetSyntax();
// Remove current members and attributes so we don't double declare them.
// Remove current members, attributes, and base list so we don't double declare them.
typeDecl = typeDecl.WithMembers(List<MemberDeclarationSyntax>())
.WithAttributeLists(List<AttributeListSyntax>());
.WithAttributeLists(List<AttributeListSyntax>())
.WithBaseList(null);

containingTypes.Add(typeDecl);

Expand All @@ -162,7 +163,7 @@ public static DllImportStub Create(
for (int i = 0; i < method.Parameters.Length; i++)
{
var param = method.Parameters[i];
var typeInfo = TypePositionInfo.CreateForParameter(param, defaultInfo, env.Compilation, diagnostics);
var typeInfo = TypePositionInfo.CreateForParameter(param, defaultInfo, env.Compilation, diagnostics, method.ContainingType);
typeInfo = typeInfo with
{
ManagedIndex = i,
Expand All @@ -171,7 +172,7 @@ public static DllImportStub Create(
paramsTypeInfo.Add(typeInfo);
}

TypePositionInfo retTypeInfo = TypePositionInfo.CreateForType(method.ReturnType, method.GetReturnTypeAttributes(), defaultInfo, env.Compilation, diagnostics);
TypePositionInfo retTypeInfo = TypePositionInfo.CreateForType(method.ReturnType, method.GetReturnTypeAttributes(), defaultInfo, env.Compilation, diagnostics, method.ContainingType);
retTypeInfo = retTypeInfo with
{
ManagedIndex = TypePositionInfo.ReturnIndex,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,19 +81,39 @@ public IEnumerable<StatementSyntax> Generate(TypePositionInfo info, StubCodeCont
SingletonSeparatedList(
VariableDeclarator(addRefdIdentifier)
.WithInitializer(EqualsValueClause(LiteralExpression(SyntaxKind.FalseLiteralExpression))))));

}

var safeHandleCreationExpression = ((SafeHandleMarshallingInfo)info.MarshallingAttributeInfo).AccessibleDefaultConstructor
? (ExpressionSyntax)ObjectCreationExpression(info.ManagedType.AsTypeSyntax(), ArgumentList(), initializer: null)
: CastExpression(
info.ManagedType.AsTypeSyntax(),
InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
ParseTypeName(TypeNames.System_Activator),
IdentifierName("CreateInstance")))
.WithArgumentList(
ArgumentList(
SeparatedList(
new []{
Argument(
TypeOfExpression(
info.ManagedType.AsTypeSyntax())),
Argument(
LiteralExpression(
SyntaxKind.TrueLiteralExpression))
.WithNameColon(
NameColon(
IdentifierName("nonPublic")))
}))));

if (info.IsManagedReturnPosition)
{
yield return ExpressionStatement(
AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
IdentifierName(managedIdentifier),
InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
ParseName(TypeNames.MarshalEx(options)),
GenericName(Identifier("CreateSafeHandle"),
TypeArgumentList(SingletonSeparatedList(info.ManagedType.AsTypeSyntax())))),
ArgumentList())));
safeHandleCreationExpression
));
}
else if (info.IsByRef && info.RefKind != RefKind.In)
{
Expand All @@ -105,13 +125,7 @@ public IEnumerable<StatementSyntax> Generate(TypePositionInfo info, StubCodeCont
info.ManagedType.AsTypeSyntax(),
SingletonSeparatedList(
VariableDeclarator(newHandleObjectIdentifier)
.WithInitializer(EqualsValueClause(
InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
ParseName(TypeNames.MarshalEx(options)),
GenericName(Identifier("CreateSafeHandle"),
TypeArgumentList(SingletonSeparatedList(info.ManagedType.AsTypeSyntax())))),
ArgumentList()))))));
.WithInitializer(EqualsValueClause(safeHandleCreationExpression)))));
if (info.RefKind != RefKind.Out)
{
yield return LocalDeclarationStatement(
Expand Down Expand Up @@ -168,7 +182,7 @@ public IEnumerable<StatementSyntax> Generate(TypePositionInfo info, StubCodeCont
InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
ParseTypeName(TypeNames.MarshalEx(options)),
IdentifierName("SetHandle")),
IdentifierName("InitHandle")),
ArgumentList(SeparatedList(
new []
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ internal sealed record GeneratedNativeMarshallingAttributeInfo(
/// <summary>
/// The type of the element is a SafeHandle-derived type with no marshalling attributes.
/// </summary>
internal sealed record SafeHandleMarshallingInfo : MarshallingInfo;
internal sealed record SafeHandleMarshallingInfo(bool AccessibleDefaultConstructor) : MarshallingInfo;


/// <summary>
Expand Down
2 changes: 2 additions & 0 deletions DllImportGenerator/DllImportGenerator/TypeNames.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ static class TypeNames
public const string System_Span_Metadata = "System.Span`1";
public const string System_Span = "System.Span";

public const string System_Activator = "System.Activator";

public const string System_Runtime_InteropServices_StructLayoutAttribute = "System.Runtime.InteropServices.StructLayoutAttribute";

public const string System_Runtime_InteropServices_MarshalAsAttribute = "System.Runtime.InteropServices.MarshalAsAttribute";
Expand Down
36 changes: 24 additions & 12 deletions DllImportGenerator/DllImportGenerator/TypePositionInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,9 @@ private TypePositionInfo()

public MarshallingInfo MarshallingAttributeInfo { get; init; }

public static TypePositionInfo CreateForParameter(IParameterSymbol paramSymbol, DefaultMarshallingInfo defaultInfo, Compilation compilation, GeneratorDiagnostics diagnostics)
public static TypePositionInfo CreateForParameter(IParameterSymbol paramSymbol, DefaultMarshallingInfo defaultInfo, Compilation compilation, GeneratorDiagnostics diagnostics, INamedTypeSymbol scopeSymbol)
{
var marshallingInfo = GetMarshallingInfo(paramSymbol.Type, paramSymbol.GetAttributes(), defaultInfo, compilation, diagnostics);
var marshallingInfo = GetMarshallingInfo(paramSymbol.Type, paramSymbol.GetAttributes(), defaultInfo, compilation, diagnostics, scopeSymbol);
var typeInfo = new TypePositionInfo()
{
ManagedType = paramSymbol.Type,
Expand All @@ -98,9 +98,9 @@ public static TypePositionInfo CreateForParameter(IParameterSymbol paramSymbol,
return typeInfo;
}

public static TypePositionInfo CreateForType(ITypeSymbol type, IEnumerable<AttributeData> attributes, DefaultMarshallingInfo defaultInfo, Compilation compilation, GeneratorDiagnostics diagnostics)
public static TypePositionInfo CreateForType(ITypeSymbol type, IEnumerable<AttributeData> attributes, DefaultMarshallingInfo defaultInfo, Compilation compilation, GeneratorDiagnostics diagnostics, INamedTypeSymbol scopeSymbol)
{
var marshallingInfo = GetMarshallingInfo(type, attributes, defaultInfo, compilation, diagnostics);
var marshallingInfo = GetMarshallingInfo(type, attributes, defaultInfo, compilation, diagnostics, scopeSymbol);
var typeInfo = new TypePositionInfo()
{
ManagedType = type,
Expand All @@ -127,7 +127,7 @@ public static TypePositionInfo CreateForType(ITypeSymbol type, MarshallingInfo m
return typeInfo;
}

private static MarshallingInfo GetMarshallingInfo(ITypeSymbol type, IEnumerable<AttributeData> attributes, DefaultMarshallingInfo defaultInfo, Compilation compilation, GeneratorDiagnostics diagnostics)
private static MarshallingInfo GetMarshallingInfo(ITypeSymbol type, IEnumerable<AttributeData> attributes, DefaultMarshallingInfo defaultInfo, Compilation compilation, GeneratorDiagnostics diagnostics, INamedTypeSymbol scopeSymbol)
{
// Look at attributes passed in - usage specific.
foreach (var attrData in attributes)
Expand All @@ -137,7 +137,7 @@ private static MarshallingInfo GetMarshallingInfo(ITypeSymbol type, IEnumerable<
if (SymbolEqualityComparer.Default.Equals(compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_MarshalAsAttribute), attributeClass))
{
// https://docs.microsoft.com/dotnet/api/system.runtime.interopservices.marshalasattribute
return CreateMarshalAsInfo(type, attrData, defaultInfo, compilation, diagnostics);
return CreateMarshalAsInfo(type, attrData, defaultInfo, compilation, diagnostics, scopeSymbol);
}
else if (SymbolEqualityComparer.Default.Equals(compilation.GetTypeByMetadataName(TypeNames.MarshalUsingAttribute), attributeClass))
{
Expand Down Expand Up @@ -167,7 +167,7 @@ private static MarshallingInfo GetMarshallingInfo(ITypeSymbol type, IEnumerable<

// If the type doesn't have custom attributes that dictate marshalling,
// then consider the type itself.
if (TryCreateTypeBasedMarshallingInfo(type, defaultInfo, compilation, diagnostics, out MarshallingInfo infoMaybe))
if (TryCreateTypeBasedMarshallingInfo(type, defaultInfo, compilation, diagnostics, scopeSymbol, out MarshallingInfo infoMaybe))
{
return infoMaybe;
}
Expand All @@ -183,7 +183,7 @@ private static MarshallingInfo GetMarshallingInfo(ITypeSymbol type, IEnumerable<

return NoMarshallingInfo.Instance;

static MarshalAsInfo CreateMarshalAsInfo(ITypeSymbol type, AttributeData attrData, DefaultMarshallingInfo defaultInfo, Compilation compilation, GeneratorDiagnostics diagnostics)
static MarshalAsInfo CreateMarshalAsInfo(ITypeSymbol type, AttributeData attrData, DefaultMarshallingInfo defaultInfo, Compilation compilation, GeneratorDiagnostics diagnostics, INamedTypeSymbol scopeSymbol)
{
object unmanagedTypeObj = attrData.ConstructorArguments[0].Value!;
UnmanagedType unmanagedType = unmanagedTypeObj is short
Expand Down Expand Up @@ -252,7 +252,7 @@ static MarshalAsInfo CreateMarshalAsInfo(ITypeSymbol type, AttributeData attrDat
}
else if (type is IArrayTypeSymbol { ElementType: ITypeSymbol elementType })
{
elementMarshallingInfo = GetMarshallingInfo(elementType, Array.Empty<AttributeData>(), defaultInfo, compilation, diagnostics);
elementMarshallingInfo = GetMarshallingInfo(elementType, Array.Empty<AttributeData>(), defaultInfo, compilation, diagnostics, scopeSymbol);
}

return new ArrayMarshalAsInfo(
Expand Down Expand Up @@ -307,21 +307,33 @@ static NativeMarshallingAttributeInfo CreateNativeMarshallingInfo(ITypeSymbol ty
NativeTypePinnable: ManualTypeMarshallingHelper.FindGetPinnableReference(nativeType) is not null);
}

static bool TryCreateTypeBasedMarshallingInfo(ITypeSymbol type, DefaultMarshallingInfo defaultInfo, Compilation compilation, GeneratorDiagnostics diagnostics, out MarshallingInfo marshallingInfo)
static bool TryCreateTypeBasedMarshallingInfo(ITypeSymbol type, DefaultMarshallingInfo defaultInfo, Compilation compilation, GeneratorDiagnostics diagnostics, INamedTypeSymbol scopeSymbol, out MarshallingInfo marshallingInfo)
{
var conversion = compilation.ClassifyCommonConversion(type, compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_SafeHandle)!);
if (conversion.Exists
&& conversion.IsImplicit
&& conversion.IsReference
&& !type.IsAbstract)
{
marshallingInfo = new SafeHandleMarshallingInfo();
bool hasAccessibleDefaultConstructor = false;
if (type is INamedTypeSymbol named && named.InstanceConstructors.Length > 0)
{
foreach (var ctor in named.InstanceConstructors)
{
if (ctor.Parameters.Length == 0)
{
hasAccessibleDefaultConstructor = compilation.IsSymbolAccessibleWithin(ctor, scopeSymbol);
break;
}
}
}
marshallingInfo = new SafeHandleMarshallingInfo(hasAccessibleDefaultConstructor);
return true;
}

if (type is IArrayTypeSymbol { ElementType: ITypeSymbol elementType })
{
marshallingInfo = new ArrayMarshallingInfo(GetMarshallingInfo(elementType, Array.Empty<AttributeData>(), defaultInfo, compilation, diagnostics));
marshallingInfo = new ArrayMarshallingInfo(GetMarshallingInfo(elementType, Array.Empty<AttributeData>(), defaultInfo, compilation, diagnostics, scopeSymbol));
return true;
}
marshallingInfo = NoMarshallingInfo.Instance;
Expand Down

0 comments on commit 0c64a2a

Please sign in to comment.