Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable derived COM interfaces to hide base methods with the new keyword #101577

Merged
merged 7 commits into from
May 6, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,14 @@ internal sealed record ComInterfaceAndMethodsContext(ComInterfaceContext Interfa
/// </summary>
public IEnumerable<ComMethodContext> DeclaredMethods => Methods.Where(m => !m.IsInheritedMethod);

/// <summary>
/// COM methods that require shadowing declarations on the derived interface.
/// </summary>
public IEnumerable<ComMethodContext> ShadowingMethods => Methods.Where(m => m.IsInheritedMethod && !m.IsHiddenOnDerivedInterface);

/// <summary>
/// COM methods that are declared on an interface the interface inherits from.
/// </summary>
public IEnumerable<ComMethodContext> ShadowingMethods => Methods.Where(m => m.IsInheritedMethod);
public IEnumerable<ComMethodContext> InheritedMethods => Methods.Where(m => m.IsInheritedMethod);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
.WithComparer(SyntaxEquivalentComparer.Instance)
.SelectNormalized();

var shadowingMethods = interfaceAndMethodsContexts
var shadowingMethodDeclarations = interfaceAndMethodsContexts
.Select((data, ct) =>
{
var context = data.Interface.Info;
Expand Down Expand Up @@ -163,7 +163,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
.Zip(nativeToManagedVtableMethods)
.Zip(nativeToManagedVtables)
.Zip(iUnknownDerivedAttributeApplication)
.Zip(shadowingMethods)
.Zip(shadowingMethodDeclarations)
.Select(static (data, ct) =>
{
var ((((((interfaceContext, interfaceInfo), managedToNativeStubs), nativeToManagedStubs), nativeToManagedVtable), iUnknownDerivedAttribute), shadowingMethod) = data;
Expand Down Expand Up @@ -352,7 +352,7 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M

var containingSyntaxContext = new ContainingSyntaxContext(syntax);

var methodSyntaxTemplate = new ContainingSyntax(syntax.Modifiers.StripAccessibilityModifiers(), SyntaxKind.MethodDeclaration, syntax.Identifier, syntax.TypeParameterList);
var methodSyntaxTemplate = new ContainingSyntax(new SyntaxTokenList(syntax.Modifiers.Where(static m => !m.IsKind(SyntaxKind.NewKeyword))).StripAccessibilityModifiers(), SyntaxKind.MethodDeclaration, syntax.Identifier, syntax.TypeParameterList);

ImmutableArray<FunctionPointerUnmanagedCallingConventionSyntax> callConv = VirtualMethodPointerStubGenerator.GenerateCallConvSyntaxFromAttributes(
suppressGCTransitionAttribute,
Expand Down Expand Up @@ -423,11 +423,42 @@ private static ImmutableArray<ComInterfaceAndMethodsContext> GroupComContextsFor
var methodList = ImmutableArray.CreateBuilder<ComMethodContext>();
while (methodIndex < methods.Length && methods[methodIndex].OwningInterface == iface)
{
var method = methods[methodIndex];
if (method.MethodInfo.IsUserDefinedShadowingMethod)
{
bool shadowFound = false;
int shadowIndex = -1;
// Don't remove method, but make it so that it doesn't generate any stubs
for (int i = methodList.Count - 1; i > -1; i--)
{
var potentialShadowedMethod = methodList[i];
if (MethodEquals(method, potentialShadowedMethod))
{
shadowFound = true;
shadowIndex = i;
break;
}
}
if (shadowFound)
{
methodList[shadowIndex].IsHiddenOnDerivedInterface = true;
}
// We might not find the shadowed method if it's defined on a non-GeneratedComInterface-attributed interface. Thats okay and we can disregard it.
}
methodList.Add(methods[methodIndex++]);
}
contextList.Add(new(iface, methodList.ToImmutable().ToSequenceEqual()));
}
return contextList.ToImmutable();

static bool MethodEquals(ComMethodContext a, ComMethodContext b)
{
if (a.MethodInfo.MethodName != b.MethodInfo.MethodName)
return false;
if (a.GenerationContext.SignatureContext.ManagedParameters.SequenceEqual(b.GenerationContext.SignatureContext.ManagedParameters))
return true;
return false;
}
}

private static readonly InterfaceDeclarationSyntax ImplementationInterfaceTemplate = InterfaceDeclaration("InterfaceImplementation")
Expand All @@ -436,12 +467,12 @@ private static ImmutableArray<ComInterfaceAndMethodsContext> GroupComContextsFor
private static InterfaceDeclarationSyntax GenerateImplementationInterface(ComInterfaceAndMethodsContext interfaceGroup, CancellationToken _)
{
var definingType = interfaceGroup.Interface.Info.Type;
var shadowImplementations = interfaceGroup.ShadowingMethods.Select(m => (Method: m, ManagedToUnmanagedStub: m.ManagedToUnmanagedStub))
var shadowImplementations = interfaceGroup.InheritedMethods.Select(m => (Method: m, ManagedToUnmanagedStub: m.ManagedToUnmanagedStub))
.Where(p => p.ManagedToUnmanagedStub is GeneratedStubCodeContext)
.Select(ctx => ((GeneratedStubCodeContext)ctx.ManagedToUnmanagedStub).Stub.Node
.WithExplicitInterfaceSpecifier(
ExplicitInterfaceSpecifier(ParseName(definingType.FullTypeName))));
var inheritedStubs = interfaceGroup.ShadowingMethods.Select(m => m.UnreachableExceptionStub);
var inheritedStubs = interfaceGroup.InheritedMethods.Select(m => m.UnreachableExceptionStub);
return ImplementationInterfaceTemplate
.AddBaseListTypes(SimpleBaseType(definingType.Syntax))
.WithMembers(
Expand Down Expand Up @@ -560,7 +591,7 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTable(ComInterf
ParenthesizedExpression(
BinaryExpression(SyntaxKind.MultiplyExpression,
SizeOfExpression(PointerType(PredefinedType(Token(SyntaxKind.VoidKeyword)))),
LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(interfaceMethods.ShadowingMethods.Count() + 3))))))));
LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(interfaceMethods.InheritedMethods.Count() + 3))))))));
}

var vtableSlotAssignments = VirtualMethodPointerStubGenerator.GenerateVirtualMethodTableSlotAssignments(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,15 @@ public ComMethodContext(Builder builder, ComInterfaceContext owningInterface, In

public bool IsInheritedMethod => OriginalDeclaringInterface != OwningInterface;

public bool IsHiddenOnDerivedInterface { get; set; }

private GeneratedMethodContextBase? _managedToUnmanagedStub;

public GeneratedMethodContextBase ManagedToUnmanagedStub => _managedToUnmanagedStub ??= CreateManagedToUnmanagedStub();

private GeneratedMethodContextBase CreateManagedToUnmanagedStub()
{
if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional))
if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional) || IsHiddenOnDerivedInterface)
{
return new SkippedStubContext(OriginalDeclaringInterface.Info.Type);
}
Expand All @@ -89,7 +91,7 @@ private GeneratedMethodContextBase CreateManagedToUnmanagedStub()

private GeneratedMethodContextBase CreateUnmanagedToManagedStub()
{
if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional))
if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional) || IsHiddenOnDerivedInterface)
{
return new SkippedStubContext(GenerationContext.OriginalDefiningType);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ namespace Microsoft.Interop
internal sealed record ComMethodInfo(
MethodDeclarationSyntax Syntax,
string MethodName,
SequenceEqualImmutableArray<AttributeInfo> Attributes)
SequenceEqualImmutableArray<AttributeInfo> Attributes,
bool IsUserDefinedShadowingMethod)
{
/// <summary>
/// Returns a list of tuples of ComMethodInfo, IMethodSymbol, and Diagnostic. If ComMethodInfo is null, Diagnostic will not be null, and vice versa.
Expand Down Expand Up @@ -123,7 +124,9 @@ internal sealed record ComMethodInfo(
{
attributeInfos.Add(AttributeInfo.From(attr));
}
var comMethodInfo = new ComMethodInfo(comMethodDeclaringSyntax, method.Name, attributeInfos.MoveToImmutable().ToSequenceEqual());

bool shadowsBaseMethod = comMethodDeclaringSyntax.Modifiers.Any(SyntaxKind.NewKeyword);
var comMethodInfo = new ComMethodInfo(comMethodDeclaringSyntax, method.Name, attributeInfos.MoveToImmutable().ToSequenceEqual(), shadowsBaseMethod);
jtschuster marked this conversation as resolved.
Show resolved Hide resolved
return DiagnosticOr<(ComMethodInfo, IMethodSymbol)>.From((comMethodInfo, method));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -386,5 +386,39 @@ public void IStringArrayMarshallingFails_Failing()
obj.ByValueOutParam(strings);
});
}

[Fact]
public unsafe void IHideWorksAsExpected()
{
IHide obj = CreateWrapper<HideBaseMethods, IHide3>();

// IHide.SameMethod should be index 3
Assert.Equal(3, obj.SameMethod());
Assert.Equal(4, obj.DifferentMethod());

IHide2 obj2 = (IHide2)obj;

// IHide2.SameMethod should be index 5
Assert.Equal(5, obj2.SameMethod());
Assert.Equal(4, obj2.DifferentMethod());
Assert.Equal(6, obj2.DifferentMethod2());

IHide3 obj3 = (IHide3)obj;
// IHide3.SameMethod should be index 7
Assert.Equal(7, obj3.SameMethod());
Assert.Equal(4, obj3.DifferentMethod());
Assert.Equal(6, obj3.DifferentMethod2());
Assert.Equal(8, obj3.DifferentMethod3());

// Ensure each VTable method points to the correct method on HideBaseMethods
for (int i = 3; i < 9; i++)
{
var (__this, __vtable_native) = ((global::System.Runtime.InteropServices.Marshalling.IUnmanagedVirtualMethodTableProvider)obj3).GetVirtualMethodTableInfoForKey(typeof(global::SharedTypes.ComInterfaces.IHide3));
int __retVal;
int __invokeRetVal;
__invokeRetVal = ((delegate* unmanaged[MemberFunction]<void*, int*, int>)__vtable_native[i])(__this, &__retVal);
Assert.Equal(i, __retVal);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
Expand Down Expand Up @@ -109,8 +110,6 @@ public static StatelessCollectionAllShapes<TManagedElement> AllocateContainerFor
}
""";

public static readonly string DisableRuntimeMarshalling = "[assembly:System.Runtime.CompilerServices.DisableRuntimeMarshalling]";
public static readonly string UsingSystemRuntimeInteropServicesMarshalling = "using System.Runtime.InteropServices.Marshalling;";
public const string IntMarshaller = """
[CustomMarshaller(typeof(int), MarshalMode.Default, typeof(IntMarshaller))]
internal static class IntMarshaller
Expand Down Expand Up @@ -433,6 +432,87 @@ partial interface INativeAPI
{{_attributeProvider.AdditionalUserRequiredInterfaces("INativeAPI")}}
""";

public string DerivedComInterfaceTypeWithShadowingMethod => $$"""
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;

{{GeneratedComInterface()}}
partial interface IComInterface
{
void Method();
}
{{GeneratedComInterface()}}
partial interface IComInterface2 : IComInterface
{
new void Method();
}
""";

public string DerivedComInterfaceTypeShadowsNonComMethod => $$"""
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;

{{GeneratedComInterface()}}
partial interface IComInterface
{
void Method();
}
interface IOtherInterface
{
void Method2();
}
{{GeneratedComInterface()}}
partial interface IComInterface2 : IComInterface
{
new void Method2();
}
""";

public string DerivedComInterfaceTypeShadowsComAndNonComMethod => $$"""
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;

{{GeneratedComInterface()}}
partial interface IComInterface
{
void Method();
}
interface IOtherInterface
{
void Method();
}
{{GeneratedComInterface()}}
partial interface IComInterface2 : IComInterface
{
new void Method();
}
""";

public string DerivedComInterfaceTypeTwoLevelShadows => $$"""
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;

{{GeneratedComInterface()}}
partial interface IComInterface
{
void Method();
}
{{GeneratedComInterface()}}
partial interface IComInterface1: IComInterface
{
new void Method1();
}
{{GeneratedComInterface()}}
partial interface IComInterface2 : IComInterface1
{
new void Method();
}
""";

public string DerivedComInterfaceType => $$"""
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;
using Microsoft.DotNet.XUnitExtensions.Attributes;
using Microsoft.Interop.UnitTests;
using Xunit;
using VerifyComInterfaceGenerator = Microsoft.Interop.UnitTests.Verifiers.CSharpSourceGeneratorVerifier<Microsoft.Interop.ComInterfaceGenerator>;
Expand Down Expand Up @@ -336,6 +335,10 @@ public static IEnumerable<object[]> ComInterfaceSnippetsToCompile()
{
CodeSnippets codeSnippets = new(new GeneratedComInterfaceAttributeProvider());
yield return new object[] { ID(), codeSnippets.DerivedComInterfaceType };
yield return new object[] { ID(), codeSnippets.DerivedComInterfaceTypeWithShadowingMethod };
yield return new object[] { ID(), codeSnippets.DerivedComInterfaceTypeShadowsNonComMethod };
yield return new object[] { ID(), codeSnippets.DerivedComInterfaceTypeShadowsComAndNonComMethod };
yield return new object[] { ID(), codeSnippets.DerivedComInterfaceTypeTwoLevelShadows};
yield return new object[] { ID(), codeSnippets.DerivedWithParametersDeclaredInOtherNamespace };
yield return new object[] { ID(), codeSnippets.ComInterfaceParameters };
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;

namespace SharedTypes.ComInterfaces
{
[GeneratedComInterface]
[Guid("023EA72A-ECAA-4B65-9D96-2122CFADE16C")]
internal partial interface IHide
{
int SameMethod();
int DifferentMethod();
}

[GeneratedComInterface]
[Guid("5293B3B1-4994-425C-803E-A21A5011E077")]
internal partial interface IHide2 : IHide
{
new int SameMethod();
int DifferentMethod2();
}

internal interface UnrelatedInterfaceWithSameMethod
{
int SameMethod();
int DifferentMethod3();
}

[GeneratedComInterface]
[Guid("5DD35432-4987-488D-94F1-7682D7E4405C")]
internal partial interface IHide3 : IHide2, UnrelatedInterfaceWithSameMethod
{
new int SameMethod();
new int DifferentMethod3();
}

[GeneratedComClass]
[Guid("2D36BD6D-C80E-4F00-86E9-8D1B4A0CB59A")]
/// <summary>
/// Implements IHides3 and returns the expected VTable index for each method.
/// </summary>
internal partial class HideBaseMethods : IHide3
{
int IHide.SameMethod() => 3;
int IHide.DifferentMethod() => 4;
int IHide2.SameMethod() => 5;
int IHide2.DifferentMethod2() => 6;
int IHide3.SameMethod() => 7;
int IHide3.DifferentMethod3() => 8;
int UnrelatedInterfaceWithSameMethod.SameMethod() => -1;
int UnrelatedInterfaceWithSameMethod.DifferentMethod3() => -1;
}
}
Loading