Skip to content

Commit

Permalink
Remove the TypeKey concept as the primary user for the concept is una…
Browse files Browse the repository at this point in the history
…ble to use it effectively. (#79418)
  • Loading branch information
jkoritzinsky authored Jan 19, 2023
1 parent fa19917 commit 170587e
Show file tree
Hide file tree
Showing 13 changed files with 167 additions and 234 deletions.
53 changes: 24 additions & 29 deletions docs/design/libraries/ComInterfaceGenerator/VTableStubs.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,26 +80,29 @@ public readonly ref struct VirtualMethodTableInfo
}
}

public interface IUnmanagedVirtualMethodTableProvider<T> where T : IEquatable<T>
public interface IUnmanagedVirtualMethodTableProvider
{
protected VirtualMethodTableInfo GetVirtualMethodTableInfoForKey(T typeKey);
protected VirtualMethodTableInfo GetVirtualMethodTableInfoForKey(Type type);

public sealed VirtualMethodTableInfo GetVirtualMethodTableInfoForKey<TUnmanagedInterfaceType>()
where TUnmanagedInterfaceType : IUnmanagedInterfaceType<T>
where TUnmanagedInterfaceType : IUnmanagedInterfaceType<TUnmanagedInterfaceType>
{
return GetVirtualMethodTableInfoForKey(TUnmanagedInterfaceType.TypeKey);
// Dispatch from a non-virtual generic to a virtual non-generic with System.Type
// to avoid generic virtual method dispatch, which is very slow.
return GetVirtualMethodTableInfoForKey(typeof(TUnmanagedInterfaceType));
}
}

public interface IUnmanagedInterfaceType<T> where T : IEquatable<T>
public interface IUnmanagedInterfaceType<TUnmanagedInterfaceType> where TUnmanagedInterfaceType : IUnmanagedInterfaceType<TUnmanagedInterfaceType>
{
public abstract static T TypeKey { get; }
}
```

## Required API Shapes

The user will be required to implement `IUnmanagedVirtualMethodTableProvider<T>` on the type that provides the method tables, and `IUnmanagedInterfaceType<T>` on the type that defines the unmanaged interface. The `T` types must match between the two interfaces. This mechanism is designed to enable each native API platform to provide their own casting key, for example `IID`s in COM, without interfering with each other or requiring using reflection-based types like `System.Type`.
The user will be required to implement `IUnmanagedVirtualMethodTableProvider` on the type that provides the method tables, and `IUnmanagedInterfaceType<TUnmanagedInterfaceType>` on the type that defines the unmanaged interface. The `TUnmanagedInterfaceType` follows the same design principles as the generic math designs as somewhat of a "self" type to enable us to use the derived interface type in any additional APIs we add to support unmanaged-to-managed stubs.

Previously, each of these interface types were also generic on another type `T`. The `T` types were required to match between the two interfaces. This mechanism was designed to enable each native API platform to provide their own casting key, for example `IID`s in COM, without interfering with each other or requiring using reflection-based types like `System.Type`. However, practical implementation showed that providing just a "type key" was not enough information to cover any non-trivial scenarios (like COM) efficiently without effectively forcing a two-level lookup model or hard-coding type support in the `IUnmanagedVirtualMethodTableProvider<T>` implementation. Additionally, we determined that using reflection to get to attributes is considered "okay" and using generic attributes would enable APIs that build on this model like COM to effectively retrieve information from the `System.Type` instance without causing additional problems.

## Example Usage

Expand Down Expand Up @@ -160,11 +163,8 @@ using System.Runtime.InteropServices;
[assembly:DisableRuntimeMarshalling]
// Define the interface of the native API
partial interface INativeAPI : IUnmanagedInterfaceType<NoCasting>
partial interface INativeAPI : IUnmanagedInterfaceType<INativeAPI>
{
// There is no concept of casting for this API, but providing a type key is still required by the generator.
// Use an empty readonly record struct to provide a type that implements IEquatable<T> but contains no data.
static NoCasting IUnmanagedInterfaceType.TypeKey => default;
[VirtualMethodIndex(0, ImplicitThisParameter = false, Direction = CustomTypeMarshallerDirection.In)]
int GetVersion();
Expand All @@ -176,11 +176,8 @@ partial interface INativeAPI : IUnmanagedInterfaceType<NoCasting>
int Multiply(int x, int y);
}
// Define the key for native "casting" support for our scenario
readonly record struct NoCasting {}
// Define our runtime wrapper type for the native interface.
unsafe class NativeAPI : IUnmanagedVirtualMethodTableProvider<NoCasting>, INativeAPI.Native
unsafe class NativeAPI : IUnmanagedVirtualMethodTableProvider, INativeAPI.Native
{
private CNativeAPI* _nativeAPI;
Expand All @@ -192,7 +189,7 @@ unsafe class NativeAPI : IUnmanagedVirtualMethodTableProvider<NoCasting>, INativ
}
}
VirtualMethodTableInfo IUnmanagedVirtualMethodTableProvider<NoCasting>.GetVirtualMethodTableInfoForKey(NoCasting _)
VirtualMethodTableInfo IUnmanagedVirtualMethodTableProvider.GetVirtualMethodTableInfoForKey(Type _)
{
return new(IntPtr.Zero, MemoryMarshal.Cast<CNativeAPI, IntPtr>(new ReadOnlySpan<CNativeAPI>(_nativeAPI, 1)));
}
Expand Down Expand Up @@ -229,7 +226,7 @@ partial interface INativeAPI
{
int INativeAPI.GetVersion()
{
var (_, vtable) = ((IUnmanagedVirtualMethodTableProvider<NoCasting>)this).GetVirtualMethodTableInfoForKey<INativeAPI>();
var (_, vtable) = ((IUnmanagedVirtualMethodTableProvider)this).GetVirtualMethodTableInfoForKey<INativeAPI>();
int retVal;
retVal = ((delegate* unmanaged<int>)vtable[0])();
return retVal;
Expand All @@ -242,7 +239,7 @@ partial interface INativeAPI
{
int INativeAPI.Add(int x, int y)
{
var (_, vtable) = ((IUnmanagedVirtualMethodTableProvider<NoCasting>)this).GetVirtualMethodTableInfoForKey<INativeAPI>();
var (_, vtable) = ((IUnmanagedVirtualMethodTableProvider)this).GetVirtualMethodTableInfoForKey<INativeAPI>();
int retVal;
retVal = ((delegate* unmanaged<int, int, int>)vtable[1])(x, y);
return retVal;
Expand All @@ -255,7 +252,7 @@ partial interface INativeAPI
{
int INativeAPI.Multiply(int x, int y)
{
var (_, vtable) = ((IUnmanagedVirtualMethodTableProvider<NoCasting>)this).GetVirtualMethodTableInfoForKey<INativeAPI>();
var (_, vtable) = ((IUnmanagedVirtualMethodTableProvider)this).GetVirtualMethodTableInfoForKey<INativeAPI>();
int retVal;
retVal = ((delegate* unmanaged<int, int, int>)vtable[2])(x, y);
return retVal;
Expand All @@ -266,7 +263,7 @@ partial interface INativeAPI
// LibraryImport-generated code omitted for brevity
```

As this generator is primarily designed to provide building blocks for future work, it has a larger requirement on user-written code. In particular, this generator does not provide any support for authoring a runtime wrapper object that stores the native pointers for the underlying object or the virtual method table. However, this lack of support also provides significant flexibility for developers. The only requirement for the runtime wrapper object type is that it implements `IUnmanagedVirtualMethodTableProvider<T>` with a `T` matching the `TypeKey` type of the native interface.
As this generator is primarily designed to provide building blocks for future work, it has a larger requirement on user-written code. In particular, this generator does not provide any support for authoring a runtime wrapper object that stores the native pointers for the underlying object or the virtual method table. However, this lack of support also provides significant flexibility for developers. The only requirement for the runtime wrapper object type is that it implements `IUnmanagedVirtualMethodTableProvider`.

The emitted interface implementation can be used in two ways:

Expand All @@ -290,10 +287,8 @@ struct IUnknown
using System;
using System.Runtime.InteropServices;

interface IUnknown: IUnmanagedInterfaceType<Guid>
interface IUnknown: IUnmanagedInterfaceType<IUnknown>
{
static Guid IUnmanagedTypeInterfaceType<Guid>.TypeKey => Guid.Parse("00000000-0000-0000-C000-000000000046");

[UnmanagedCallConv(CallConvs = new[] { typeof(CallConvStdcall), typeof(CallConvMemberFunction) })]
[VirtualMethodIndex(0)]
int QueryInterface(in Guid riid, out IntPtr ppvObject);
Expand All @@ -307,7 +302,7 @@ interface IUnknown: IUnmanagedInterfaceType<Guid>
uint Release();
}

class BaseIUnknownComObject : IUnmanagedVirtualMethodTableProvider<Guid>, IDynamicInterfaceCastable
class BaseIUnknownComObject : IUnmanagedVirtualMethodTableProvider, IDynamicInterfaceCastable
{
private IntPtr _unknownPtr;

Expand All @@ -316,9 +311,9 @@ class BaseIUnknownComObject : IUnmanagedVirtualMethodTableProvider<Guid>, IDynam
_unknownPtr = unknown;
}

unsafe VirtualMethodTableInfo IUnmanagedVirtualMethodTableProvider<Guid>.GetVirtualMethodTableInfoForKey(Guid iid)
unsafe VirtualMethodTableInfo IUnmanagedVirtualMethodTableProvider.GetVirtualMethodTableInfoForKey(Type type)
{
if (iid == IUnknown.TypeKey)
if (type == typeof(IUnknown))
{
return new VirtualMethodTableInfo(_unknownPtr, new ReadOnlySpan<IntPtr>(**(IntPtr***)_unknownPtr), 3);
}
Expand Down Expand Up @@ -358,7 +353,7 @@ partial interface IUnknown
{
int IUnknown.QueryInterface(in Guid riid, out IntPtr ppvObject)
{
var (thisPtr, vtable) = ((IUnmanagedVirtualMethodTableProvider<Guid>)this).GetVirtualMethodTableInfoForKey<IUnknown>();
var (thisPtr, vtable) = ((IUnmanagedVirtualMethodTableProvider)this).GetVirtualMethodTableInfoForKey<IUnknown>();
int retVal;
fixed (Guid* riid__gen_native = &riid)
fixed (IntPtr* ppvObject__gen_native = &ppvObject)
Expand All @@ -375,7 +370,7 @@ partial interface IUnknown
{
uint IUnknown.AddRef()
{
var (thisPtr, vtable) = ((IUnmanagedVirtualMethodTableProvider<Guid>)this).GetVirtualMethodTableInfoForKey<IUnknown>();
var (thisPtr, vtable) = ((IUnmanagedVirtualMethodTableProvider)this).GetVirtualMethodTableInfoForKey<IUnknown>();
uint retVal;
retVal = ((delegate* unmanaged[Stdcall, MemberFunction]<IntPtr, uint>)vtable[1])(thisPtr);
return retVal;
Expand All @@ -388,7 +383,7 @@ partial interface IUnknown
{
uint IUnknown.Release()
{
var (thisPtr, vtable) = ((IUnmanagedVirtualMethodTableProvider<Guid>)this).GetVirtualMethodTableInfoForKey<IUnknown>();
var (thisPtr, vtable) = ((IUnmanagedVirtualMethodTableProvider)this).GetVirtualMethodTableInfoForKey<IUnknown>();
uint retVal;
retVal = ((delegate* unmanaged[Stdcall, MemberFunction]<IntPtr, uint>)vtable[2])(thisPtr);
return retVal;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,11 @@ public ManagedToNativeVTableMethodGenerator(
/// <remarks>
/// The generated code assumes it will be in an unsafe context.
/// </remarks>
public BlockSyntax GenerateStubBody(int index, ImmutableArray<FunctionPointerUnmanagedCallingConventionSyntax> callConv, TypeSyntax containingTypeName, ManagedTypeInfo typeKeyType)
public BlockSyntax GenerateStubBody(int index, ImmutableArray<FunctionPointerUnmanagedCallingConventionSyntax> callConv, TypeSyntax containingTypeName)
{
var setupStatements = new List<StatementSyntax>
{
// var (<thisParameter>, <virtualMethodTable>) = ((IUnmanagedVirtualMethodTableProvider<<typeKeyType>>)this).GetVirtualMethodTableInfoForKey<<containingTypeName>>();
// var (<thisParameter>, <virtualMethodTable>) = ((IUnmanagedVirtualMethodTableProvider)this).GetVirtualMethodTableInfoForKey<<containingTypeName>>();
ExpressionStatement(
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
Expand All @@ -119,11 +119,7 @@ public BlockSyntax GenerateStubBody(int index, ImmutableArray<FunctionPointerUnm
SyntaxKind.SimpleMemberAccessExpression,
ParenthesizedExpression(
CastExpression(
GenericName(
Identifier(TypeNames.IUnmanagedVirtualMethodTableProvider))
.WithTypeArgumentList(
TypeArgumentList(
SingletonSeparatedList(typeKeyType.Syntax))),
ParseTypeName(TypeNames.IUnmanagedVirtualMethodTableProvider),
ThisExpression())),
GenericName(
Identifier("GetVirtualMethodTableInfoForKey"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@

namespace Microsoft.Interop
{
internal sealed record NativeThisInfo(ManagedTypeInfo TypeKeyType) : MarshallingInfo;
internal sealed record NativeThisInfo : MarshallingInfo
{
public static readonly NativeThisInfo Instance = new();
}

internal sealed class NativeToManagedThisMarshallerFactory : IMarshallingGeneratorFactory
{
Expand All @@ -20,14 +23,10 @@ public NativeToManagedThisMarshallerFactory(IMarshallingGeneratorFactory inner)
}

public IMarshallingGenerator Create(TypePositionInfo info, StubCodeContext context)
=> info.MarshallingAttributeInfo is NativeThisInfo(ManagedTypeInfo typeKeyType) ? new Marshaller(typeKeyType) : _inner.Create(info, context);
=> info.MarshallingAttributeInfo is NativeThisInfo ? new Marshaller() : _inner.Create(info, context);

private sealed class Marshaller : IMarshallingGenerator
{
private readonly ManagedTypeInfo _typeKeyType;

public Marshaller(ManagedTypeInfo typeKeyType) => _typeKeyType = typeKeyType;

public ManagedTypeInfo AsNativeType(TypePositionInfo info) => new PointerTypeInfo("void*", "void*", false);
public IEnumerable<StatementSyntax> Generate(TypePositionInfo info, StubCodeContext context)
{
Expand All @@ -44,10 +43,7 @@ public IEnumerable<StatementSyntax> Generate(TypePositionInfo info, StubCodeCont
IdentifierName(managedIdentifier),
InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
GenericName(Identifier(TypeNames.IUnmanagedVirtualMethodTableProvider),
TypeArgumentList(
SingletonSeparatedList(
_typeKeyType.Syntax))),
ParseTypeName(TypeNames.IUnmanagedVirtualMethodTableProvider),
GenericName(Identifier("GetObjectForUnmanagedWrapper"),
TypeArgumentList(
SingletonSeparatedList(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ internal sealed record IncrementalStubGenerationContext(
MarshallingInfo ExceptionMarshallingInfo,
MarshallingGeneratorFactoryKey<(TargetFramework TargetFramework, Version TargetFrameworkVersion)> ManagedToUnmanagedGeneratorFactory,
MarshallingGeneratorFactoryKey<(TargetFramework TargetFramework, Version TargetFrameworkVersion)> UnmanagedToManagedGeneratorFactory,
ManagedTypeInfo TypeKeyType,
ManagedTypeInfo TypeKeyOwner,
SequenceEqualImmutableArray<Diagnostic> Diagnostics);

Expand Down Expand Up @@ -348,20 +347,15 @@ private static IncrementalStubGenerationContext CalculateStubInformation(MethodD

ImmutableArray<FunctionPointerUnmanagedCallingConventionSyntax> callConv = GenerateCallConvSyntaxFromAttributes(suppressGCTransitionAttribute, unmanagedCallConvAttribute);

var typeKeyOwner = ManagedTypeInfo.CreateTypeInfoForTypeSymbol(symbol.ContainingType);
ManagedTypeInfo typeKeyType = SpecialTypeInfo.Byte;
var interfaceType = ManagedTypeInfo.CreateTypeInfoForTypeSymbol(symbol.ContainingType);

INamedTypeSymbol? iUnmanagedInterfaceTypeInstantiation = symbol.ContainingType.AllInterfaces.FirstOrDefault(iface => SymbolEqualityComparer.Default.Equals(iface.OriginalDefinition, iUnmanagedInterfaceTypeType));
if (iUnmanagedInterfaceTypeInstantiation is null)
INamedTypeSymbol expectedUnmanagedInterfaceType = iUnmanagedInterfaceTypeType.Construct(symbol.ContainingType);

bool implementsIUnmanagedInterfaceOfSelf = symbol.ContainingType.AllInterfaces.Any(iface => SymbolEqualityComparer.Default.Equals(iface, expectedUnmanagedInterfaceType));
if (!implementsIUnmanagedInterfaceOfSelf)
{
// TODO: Report invalid configuration
}
else
{
// The type key is the second generic type parameter, so we need to get the info for the
// second argument.
typeKeyType = ManagedTypeInfo.CreateTypeInfoForTypeSymbol(iUnmanagedInterfaceTypeInstantiation.TypeArguments[1]);
}

MarshallingInfo exceptionMarshallingInfo = CreateExceptionMarshallingInfo(virtualMethodIndexAttr, symbol, environment.Compilation, generatorDiagnostics, virtualMethodIndexData);

Expand All @@ -375,8 +369,7 @@ private static IncrementalStubGenerationContext CalculateStubInformation(MethodD
exceptionMarshallingInfo,
ComInterfaceGeneratorHelpers.CreateGeneratorFactory(environment, MarshalDirection.ManagedToUnmanaged),
ComInterfaceGeneratorHelpers.CreateGeneratorFactory(environment, MarshalDirection.UnmanagedToManaged),
typeKeyType,
typeKeyOwner,
interfaceType,
new SequenceEqualImmutableArray<Diagnostic>(generatorDiagnostics.Diagnostics.ToImmutableArray()));
}

Expand Down Expand Up @@ -442,8 +435,7 @@ private static (MemberDeclarationSyntax, ImmutableArray<Diagnostic>) GenerateMan
BlockSyntax code = stubGenerator.GenerateStubBody(
methodStub.VtableIndexData.Index,
methodStub.CallingConvention.Array,
methodStub.TypeKeyOwner.Syntax,
methodStub.TypeKeyType);
methodStub.TypeKeyOwner.Syntax);

return (
methodStub.ContainingSyntaxContext.AddContainingSyntax(
Expand Down Expand Up @@ -518,7 +510,7 @@ private static ImmutableArray<TypePositionInfo> AddImplicitElementInfos(Incremen

var elements = ImmutableArray.CreateBuilder<TypePositionInfo>(originalElements.Length + 2);

elements.Add(new TypePositionInfo(methodStub.TypeKeyOwner, new NativeThisInfo(methodStub.TypeKeyType))
elements.Add(new TypePositionInfo(methodStub.TypeKeyOwner, NativeThisInfo.Instance)
{
InstanceIdentifier = ThisParameterIdentifier,
NativeIndex = 0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public static class TypeNames

public const string IUnmanagedVirtualMethodTableProvider = "System.Runtime.InteropServices.IUnmanagedVirtualMethodTableProvider";

public const string IUnmanagedInterfaceType_Metadata = "System.Runtime.InteropServices.IUnmanagedInterfaceType`2";
public const string IUnmanagedInterfaceType_Metadata = "System.Runtime.InteropServices.IUnmanagedInterfaceType`1";

public const string System_Span_Metadata = "System.Span`1";
public const string System_Span = "System.Span";
Expand Down
Loading

0 comments on commit 170587e

Please sign in to comment.