Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement abstraction for marshalling direction in the generator APIs #78196

Merged
merged 2 commits into from
Nov 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace Microsoft.Interop
{
internal static class ComInterfaceGeneratorHelpers
{
public static MarshallingGeneratorFactoryKey<(TargetFramework, Version)> CreateGeneratorFactory(StubEnvironment env)
public static MarshallingGeneratorFactoryKey<(TargetFramework, Version)> CreateGeneratorFactory(StubEnvironment env, MarshalDirection direction)
{
IMarshallingGeneratorFactory generatorFactory;

Expand Down Expand Up @@ -44,7 +44,17 @@ internal static class ComInterfaceGeneratorHelpers
generatorFactory = new AttributedMarshallingModelGeneratorFactory(
generatorFactory,
elementFactory,
new AttributedMarshallingModelOptions(runtimeMarshallingDisabled, MarshalMode.ManagedToUnmanagedIn, MarshalMode.ManagedToUnmanagedRef, MarshalMode.ManagedToUnmanagedOut));
new AttributedMarshallingModelOptions(
runtimeMarshallingDisabled,
direction == MarshalDirection.ManagedToUnmanaged
? MarshalMode.ManagedToUnmanagedIn
: MarshalMode.UnmanagedToManagedOut,
direction == MarshalDirection.ManagedToUnmanaged
? MarshalMode.ManagedToUnmanagedRef
: MarshalMode.UnmanagedToManagedRef,
direction == MarshalDirection.ManagedToUnmanaged
? MarshalMode.ManagedToUnmanagedOut
: MarshalMode.UnmanagedToManagedIn));

generatorFactory = new ByValueContentsMarshalKindValidator(generatorFactory);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ internal sealed record IncrementalStubGenerationContext(
MethodSignatureDiagnosticLocations DiagnosticLocation,
SequenceEqualImmutableArray<FunctionPointerUnmanagedCallingConventionSyntax> CallingConvention,
VirtualMethodIndexData VtableIndexData,
MarshallingGeneratorFactoryKey<(TargetFramework TargetFramework, Version TargetFrameworkVersion)> GeneratorFactory,
MarshallingGeneratorFactoryKey<(TargetFramework TargetFramework, Version TargetFrameworkVersion)> ManagedToUnmanagedGeneratorFactory,
MarshallingGeneratorFactoryKey<(TargetFramework TargetFramework, Version TargetFrameworkVersion)> UnmanagedToManagedGeneratorFactory,
ManagedTypeInfo TypeKeyType,
ManagedTypeInfo TypeKeyOwner,
SequenceEqualImmutableArray<Diagnostic> Diagnostics);
Expand Down Expand Up @@ -301,7 +302,8 @@ private static IncrementalStubGenerationContext CalculateStubInformation(MethodD
new MethodSignatureDiagnosticLocations(syntax),
new SequenceEqualImmutableArray<FunctionPointerUnmanagedCallingConventionSyntax>(callConv, SyntaxEquivalentComparer.Instance),
virtualMethodIndexData,
ComInterfaceGeneratorHelpers.CreateGeneratorFactory(environment),
ComInterfaceGeneratorHelpers.CreateGeneratorFactory(environment, MarshalDirection.ManagedToUnmanaged),
ComInterfaceGeneratorHelpers.CreateGeneratorFactory(environment, MarshalDirection.UnmanagedToManaged),
typeKeyType,
typeKeyOwner,
new SequenceEqualImmutableArray<Diagnostic>(generatorDiagnostics.Diagnostics.ToImmutableArray()));
Expand Down Expand Up @@ -337,16 +339,16 @@ private static (MemberDeclarationSyntax, ImmutableArray<Diagnostic>) GenerateMan

// Generate stub code
var stubGenerator = new ManagedToNativeVTableMethodGenerator(
methodStub.GeneratorFactory.Key.TargetFramework,
methodStub.GeneratorFactory.Key.TargetFrameworkVersion,
methodStub.ManagedToUnmanagedGeneratorFactory.Key.TargetFramework,
methodStub.ManagedToUnmanagedGeneratorFactory.Key.TargetFrameworkVersion,
methodStub.SignatureContext.ElementTypeInformation,
methodStub.VtableIndexData.SetLastError,
methodStub.VtableIndexData.ImplicitThisParameter,
(elementInfo, ex) =>
{
diagnostics.ReportMarshallingNotSupported(methodStub.DiagnosticLocation, elementInfo, ex.NotSupportedDetails);
},
methodStub.GeneratorFactory.GeneratorFactory);
methodStub.ManagedToUnmanagedGeneratorFactory.GeneratorFactory);

BlockSyntax code = stubGenerator.GenerateStubBody(
methodStub.VtableIndexData.Index,
Expand All @@ -370,19 +372,6 @@ private static (MemberDeclarationSyntax, ImmutableArray<Diagnostic>) GenerateMan
methodStub.Diagnostics.Array.AddRange(diagnostics.Diagnostics));
}

private static bool ShouldVisitNode(SyntaxNode syntaxNode)
{
// We only support C# method declarations.
if (syntaxNode.Language != LanguageNames.CSharp
|| !syntaxNode.IsKind(SyntaxKind.MethodDeclaration))
{
return false;
}

// Filter out methods with no attributes early.
return ((MethodDeclarationSyntax)syntaxNode).AttributeLists.Count > 0;
}

private static Diagnostic? GetDiagnosticIfInvalidMethodForGeneration(MethodDeclarationSyntax methodSyntax, IMethodSymbol method)
{
// Verify the method has no generic types or defined implementation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

namespace Microsoft.Interop
{
public readonly record struct AttributedMarshallingModelOptions(bool RuntimeMarshallingDisabled, MarshalMode InMode, MarshalMode RefMode, MarshalMode OutMode);
public readonly record struct AttributedMarshallingModelOptions(bool RuntimeMarshallingDisabled, MarshalMode ManagedToUnmanagedMode, MarshalMode BidirectionalMode, MarshalMode UnmanagedToManagedMode);

public class AttributedMarshallingModelGeneratorFactory : IMarshallingGeneratorFactory
{
Expand Down Expand Up @@ -126,7 +126,7 @@ ExpressionSyntax GetExpressionForParam(TypePositionInfo paramInfo, out bool isIn
{
if (marshallingInfo is NativeLinearCollectionMarshallingInfo collectionInfo)
{
CustomTypeMarshallerData marshallerData = GetMarshallerDataForTypePositionInfo(collectionInfo.Marshallers, info);
CustomTypeMarshallerData marshallerData = GetMarshallerDataForTypePositionInfo(collectionInfo.Marshallers, info, context);
type = marshallerData.CollectionElementType;
marshallingInfo = marshallerData.CollectionElementMarshallingInfo;
}
Expand Down Expand Up @@ -200,16 +200,15 @@ private bool ValidateRuntimeMarshallingOptions(CustomTypeMarshallerData marshall
return false;
}

private CustomTypeMarshallerData GetMarshallerDataForTypePositionInfo(CustomTypeMarshallers marshallers, TypePositionInfo info)
private CustomTypeMarshallerData GetMarshallerDataForTypePositionInfo(CustomTypeMarshallers marshallers, TypePositionInfo info, StubCodeContext context)
{
if (info.IsManagedReturnPosition)
return marshallers.GetModeOrDefault(Options.OutMode);
MarshalDirection elementDirection = MarshallerHelpers.GetMarshalDirection(info, context);

return info.RefKind switch
return elementDirection switch
{
RefKind.None or RefKind.In => marshallers.GetModeOrDefault(Options.InMode),
RefKind.Ref => marshallers.GetModeOrDefault(Options.RefMode),
RefKind.Out => marshallers.GetModeOrDefault(Options.OutMode),
MarshalDirection.ManagedToUnmanaged => marshallers.GetModeOrDefault(Options.ManagedToUnmanagedMode),
MarshalDirection.Bidirectional => marshallers.GetModeOrDefault(Options.BidirectionalMode),
MarshalDirection.UnmanagedToManaged => marshallers.GetModeOrDefault(Options.UnmanagedToManagedMode),
_ => throw new UnreachableException()
};
}
Expand All @@ -218,7 +217,7 @@ private IMarshallingGenerator CreateCustomNativeTypeMarshaller(TypePositionInfo
{
ValidateCustomNativeTypeMarshallingSupported(info, context, marshalInfo);

CustomTypeMarshallerData marshallerData = GetMarshallerDataForTypePositionInfo(marshalInfo.Marshallers, info);
CustomTypeMarshallerData marshallerData = GetMarshallerDataForTypePositionInfo(marshalInfo.Marshallers, info, context);
if (!ValidateRuntimeMarshallingOptions(marshallerData))
{
throw new MarshallingNotSupportedException(info, context)
Expand Down Expand Up @@ -373,9 +372,10 @@ private static TypeSyntax ReplacePlaceholderSyntaxWithUnmanagedTypeSyntax(

private void ValidateCustomNativeTypeMarshallingSupported(TypePositionInfo info, StubCodeContext context, NativeMarshallingAttributeInfo marshalInfo)
{
MarshalDirection elementDirection = MarshallerHelpers.GetMarshalDirection(info, context);
// Marshalling out or return parameter, but no out marshaller is specified
if ((info.RefKind == RefKind.Out || info.IsManagedReturnPosition)
&& !marshalInfo.Marshallers.IsDefinedOrDefault(Options.OutMode))
if (elementDirection == MarshalDirection.UnmanagedToManaged
&& !marshalInfo.Marshallers.IsDefinedOrDefault(Options.UnmanagedToManagedMode))
{
throw new MarshallingNotSupportedException(info, context)
{
Expand All @@ -384,7 +384,7 @@ private void ValidateCustomNativeTypeMarshallingSupported(TypePositionInfo info,
}

// Marshalling ref parameter, but no ref marshaller is specified
if (info.RefKind == RefKind.Ref && !marshalInfo.Marshallers.IsDefinedOrDefault(Options.RefMode))
if (elementDirection == MarshalDirection.Bidirectional && !marshalInfo.Marshallers.IsDefinedOrDefault(Options.BidirectionalMode))
{
throw new MarshallingNotSupportedException(info, context)
{
Expand All @@ -393,20 +393,8 @@ private void ValidateCustomNativeTypeMarshallingSupported(TypePositionInfo info,
}

// Marshalling in parameter, but no in marshaller is specified
if (info.RefKind == RefKind.In
&& !marshalInfo.Marshallers.IsDefinedOrDefault(Options.InMode))
{
throw new MarshallingNotSupportedException(info, context)
{
NotSupportedDetails = SR.Format(SR.ManagedToUnmanagedMissingRequiredMarshaller, marshalInfo.EntryPointType.FullTypeName)
};
}

// Marshalling by value, but no in marshaller is specified
if (!info.IsByRef
&& !info.IsManagedReturnPosition
&& context.SingleFrameSpansNativeContext
&& !marshalInfo.Marshallers.IsDefinedOrDefault(Options.InMode))
if (elementDirection == MarshalDirection.ManagedToUnmanaged
&& !marshalInfo.Marshallers.IsDefinedOrDefault(Options.ManagedToUnmanagedMode))
{
throw new MarshallingNotSupportedException(info, context)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,14 @@ public IEnumerable<StatementSyntax> Generate(TypePositionInfo info, StubCodeCont
yield break;
}

MarshalDirection elementMarshalling = MarshallerHelpers.GetMarshalDirection(info, context);

switch (context.CurrentStage)
{
case StubCodeContext.Stage.Setup:
break;
case StubCodeContext.Stage.Marshal:
if (info.RefKind == RefKind.Ref)
if (elementMarshalling is MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional && info.IsByRef)
{
yield return ExpressionStatement(
AssignmentExpression(
Expand All @@ -82,11 +84,14 @@ public IEnumerable<StatementSyntax> Generate(TypePositionInfo info, StubCodeCont

break;
case StubCodeContext.Stage.Unmarshal:
yield return ExpressionStatement(
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
IdentifierName(managedIdentifier),
IdentifierName(nativeIdentifier)));
if (elementMarshalling is MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional && info.IsByRef)
{
yield return ExpressionStatement(
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
IdentifierName(managedIdentifier),
IdentifierName(nativeIdentifier)));
}
break;
default:
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,15 @@ public ValueBoundaryBehavior GetValueBoundaryBehavior(TypePositionInfo info, Stu

public IEnumerable<StatementSyntax> Generate(TypePositionInfo info, StubCodeContext context)
{
MarshalDirection elementMarshalDirection = MarshallerHelpers.GetMarshalDirection(info, context);
(string managedIdentifier, string nativeIdentifier) = context.GetIdentifiers(info);
switch (context.CurrentStage)
{
case StubCodeContext.Stage.Setup:
break;
case StubCodeContext.Stage.Marshal:
// <nativeIdentifier> = (<nativeType>)(<managedIdentifier> ? _trueValue : _falseValue);
if (info.RefKind != RefKind.Out)
if (elementMarshalDirection is MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional)
{
yield return ExpressionStatement(
AssignmentExpression(
Expand All @@ -75,7 +76,7 @@ public IEnumerable<StatementSyntax> Generate(TypePositionInfo info, StubCodeCont

break;
case StubCodeContext.Stage.Unmarshal:
if (info.IsManagedReturnPosition || (info.IsByRef && info.RefKind != RefKind.In))
if (elementMarshalDirection is MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional)
{
// <managedIdentifier> = <nativeIdentifier> == _trueValue;
// or
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,23 +82,30 @@ public IEnumerable<StatementSyntax> Generate(TypePositionInfo info, StubCodeCont
yield break;
}

MarshalDirection elementMarshalDirection = MarshallerHelpers.GetMarshalDirection(info, context);

switch (context.CurrentStage)
{
case StubCodeContext.Stage.Setup:
break;
case StubCodeContext.Stage.Marshal:
if ((info.IsByRef && info.RefKind != RefKind.Out) || !context.SingleFrameSpansNativeContext)
if (elementMarshalDirection is MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional)
{
yield return ExpressionStatement(
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
IdentifierName(nativeIdentifier),
IdentifierName(managedIdentifier)));
// There's an implicit conversion from char to ushort,
// so we simplify the generated code to just pass the char value directly
if (info.IsByRef)
{
yield return ExpressionStatement(
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
IdentifierName(nativeIdentifier),
IdentifierName(managedIdentifier)));
}
}

break;
case StubCodeContext.Stage.Unmarshal:
if (info.IsManagedReturnPosition || (info.IsByRef && info.RefKind != RefKind.In))
if (elementMarshalDirection is MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional)
{
yield return ExpressionStatement(
AssignmentExpression(
Expand Down
Loading