Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Skip interfaces not publicly accessible in authoring scenarios #1394

Merged
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/Authoring/WinRT.SourceGenerator/DiagnosticUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,12 @@ private void CheckDeclarations()
foreach (var declaration in syntaxReceiver.Declarations)
{
var model = _context.Compilation.GetSemanticModel(declaration.SyntaxTree);
var symbol = model.GetDeclaredSymbol(declaration);

// Check symbol information for whether it is public to properly detect partial types
// which can leave out modifier.
if (model.GetDeclaredSymbol(declaration).DeclaredAccessibility != Accessibility.Public)
// which can leave out modifier. Also ignore nested types not effectively public
if (symbol.DeclaredAccessibility != Accessibility.Public ||
(symbol is ITypeSymbol typeSymbol && !typeSymbol.IsPubliclyAccessible()))
{
continue;
}
Expand Down
54 changes: 54 additions & 0 deletions src/Authoring/WinRT.SourceGenerator/Extensions/SymbolExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
using System.Collections.Generic;
using System.Linq;
using Microsoft.CodeAnalysis;

#nullable enable

namespace Generator;

/// <summary>
/// Extensions for symbol types.
/// </summary>
internal static class SymbolExtensions
{
/// <summary>
/// Checks whether a given type symbol is publicly accessible (ie. it's public and not nested in any non public type).
/// </summary>
/// <param name="type">The type symbol to check for public accessibility.</param>
/// <returns>Whether <paramref name="type"/> is publicly accessible.</returns>
public static bool IsPubliclyAccessible(this ITypeSymbol type)
{
for (ITypeSymbol? currentType = type; currentType is not null; currentType = currentType.ContainingType)
{
// If any type in the type hierarchy is not public, the type is not public.
// This makes sure to detect public types nested into eg. a private type.
if (currentType.DeclaredAccessibility is not Accessibility.Public)
{
return false;
}
}

return true;
}

/// <summary>
/// Checks whether a given symbol is an explicit interface implementation of a member of an internal interface (or more than one).
/// </summary>
/// <param name="symbol">The input member symbol to check.</param>
/// <returns>Whether <paramref name="symbol"/> is an explicit interface implementation of internal interfaces.</returns>
public static bool IsExplicitInterfaceImplementationOfInternalInterfaces(this ISymbol symbol)
{
static bool IsAnyContainingTypePublic(IEnumerable<ISymbol> symbols)
{
return symbols.Any(static symbol => symbol.ContainingType!.IsPubliclyAccessible());
}

return symbol switch
{
IMethodSymbol { ExplicitInterfaceImplementations: { Length: > 0 } methods } => !IsAnyContainingTypePublic(methods),
IPropertySymbol { ExplicitInterfaceImplementations: { Length: > 0 } properties } => !IsAnyContainingTypePublic(properties),
IEventSymbol { ExplicitInterfaceImplementations: { Length: > 0 } events } => !IsAnyContainingTypePublic(events),
_ => false
};
}
}
59 changes: 45 additions & 14 deletions src/Authoring/WinRT.SourceGenerator/WinRTTypeWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1204,21 +1204,38 @@ Symbol GetType(string type, bool isGeneric = false, int genericIndex = -1, bool

private IEnumerable<INamedTypeSymbol> GetInterfaces(INamedTypeSymbol symbol, bool includeInterfacesWithoutMappings = false)
{
HashSet<INamedTypeSymbol> interfaces = new HashSet<INamedTypeSymbol>();
foreach (var @interface in symbol.Interfaces)
HashSet<INamedTypeSymbol> interfaces = new();

// Gather all interfaces that are publicly accessible. We specifically need to exclude interfaces
// that are not public, as eg. those might be used for additional cloaked WinRT/COM interfaces.
// Ignoring them here makes sure that they're not processed to be part of the .winmd file.
void GatherPubliclyAccessibleInterfaces(ITypeSymbol symbol)
{
interfaces.Add(@interface);
interfaces.UnionWith(@interface.AllInterfaces);
foreach (var @interface in symbol.Interfaces)
{
if (@interface.IsPubliclyAccessible())
{
_ = interfaces.Add(@interface);
}

// We're not using AllInterfaces on purpose: we only want to gather all interfaces but not
// from the base type. That's handled below to skip types that are already WinRT projections.
foreach (var @interface2 in @interface.AllInterfaces)
{
if (@interface2.IsPubliclyAccessible())
{
_ = interfaces.Add(@interface2);
}
}
}
}

GatherPubliclyAccessibleInterfaces(symbol);

var baseType = symbol.BaseType;
while (baseType != null && !GeneratorHelper.IsWinRTType(baseType))
{
interfaces.UnionWith(baseType.Interfaces);
foreach (var @interface in baseType.Interfaces)
{
interfaces.UnionWith(@interface.AllInterfaces);
}
GatherPubliclyAccessibleInterfaces(baseType);

baseType = baseType.BaseType;
}
Expand Down Expand Up @@ -1911,6 +1928,13 @@ void AddComponentType(INamedTypeSymbol type, Action visitTypeDeclaration = null)
}
else
{
// Special case: skip members that are explicitly implementing internal interfaces.
// This allows implementing classic COM internal interfaces with non-WinRT signatures.
if (member.IsExplicitInterfaceImplementationOfInternalInterfaces())
{
continue;
}

if (member is IMethodSymbol method &&
(method.MethodKind == MethodKind.Ordinary ||
method.MethodKind == MethodKind.ExplicitInterfaceImplementation ||
Expand Down Expand Up @@ -2683,12 +2707,19 @@ typeDeclaration.Node is INamedTypeSymbol symbol &&
}
}

public bool IsPublic(ISymbol type)
public bool IsPublic(ISymbol symbol)
{
return type.DeclaredAccessibility == Accessibility.Public ||
type is IMethodSymbol method && !method.ExplicitInterfaceImplementations.IsDefaultOrEmpty ||
type is IPropertySymbol property && !property.ExplicitInterfaceImplementations.IsDefaultOrEmpty ||
type is IEventSymbol @event && [email protected];
// Check that the type has either public accessibility, or is an explicit interface implementation
if (symbol.DeclaredAccessibility == Accessibility.Public ||
symbol is IMethodSymbol method && !method.ExplicitInterfaceImplementations.IsDefaultOrEmpty ||
symbol is IPropertySymbol property && !property.ExplicitInterfaceImplementations.IsDefaultOrEmpty ||
symbol is IEventSymbol @event && [email protected])
{
// If we have a containing type, we also check that it's publicly accessible
return symbol.ContainingType is not { } containingType || containingType.IsPubliclyAccessible();
Copy link
Member

@manodasanW manodasanW Nov 22, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am trying to understand this, should this have been an && rather than a || or am I misunderstanding what is not {} does.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, is not { } matches if the expression is not an instance (it's like is null, but also declares a local). So:

  • If containingType is null, return true
  • Otherwise (ie. containingType != null), return true if it's publicly accessible

If we used &&, we'd be trying to access containingType when the first expression already matched, but if that's the case, then containingType is not defined (because the expression matches on is not). In fact, if you changed it to &&, the code would just not compile, as containingType would be uninitialized in that case 🙂

}

return false;
}

public void GetNamespaceAndTypename(string qualifiedName, out string @namespace, out string typename)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,5 +74,9 @@
name="AuthoringTest.TestClass"
threadingModel="both"
xmlns="urn:schemas-microsoft-com:winrt.v1" />
<activatableClass
name="AuthoringTest.TestMixedWinRTCOMWrapper"
threadingModel="both"
xmlns="urn:schemas-microsoft-com:winrt.v1" />
</file>
</assembly>
1 change: 1 addition & 0 deletions src/Tests/AuthoringConsumptionTest/pch.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
// conflict with Storyboard::GetCurrentTime
#undef GetCurrentTime

#include <Windows.h>
#include <winrt/Windows.Foundation.h>
#include <winrt/Windows.Foundation.Collections.h>

Expand Down
34 changes: 34 additions & 0 deletions src/Tests/AuthoringConsumptionTest/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -639,4 +639,38 @@ TEST(AuthoringTest, PartialClass)
EXPECT_EQ(partialStruct.X, 3);
EXPECT_EQ(partialStruct.Y, 4);
EXPECT_EQ(partialStruct.Z, 5);
}

TEST(AuthoringTest, MixedWinRTClassicCOM)
{
TestMixedWinRTCOMWrapper wrapper;

// Normal WinRT methods work as you'd expect
EXPECT_EQ(wrapper.HelloWorld(), L"Hello from mixed WinRT/COM");

// Verify we can grab the internal interface
IID internalInterface1Iid;
check_hresult(IIDFromString(L"{C7850559-8FF2-4E54-A237-6ED813F20CDC}", &internalInterface1Iid));
winrt::com_ptr<::IUnknown> unknown1 = wrapper.as<::IUnknown>();
winrt::com_ptr<::IUnknown> internalInterface1;
EXPECT_EQ(unknown1->QueryInterface(internalInterface1Iid, internalInterface1.put_void()), S_OK);

// Verify we can grab the nested public interface (in an internal type)
IID internalInterface2Iid;
check_hresult(IIDFromString(L"{8A08E18A-8D20-4E7C-9242-857BFE1E3159}", &internalInterface2Iid));
winrt::com_ptr<::IUnknown> unknown2 = wrapper.as<::IUnknown>();
winrt::com_ptr<::IUnknown> internalInterface2;
EXPECT_EQ(unknown2->QueryInterface(internalInterface2Iid, internalInterface2.put_void()), S_OK);

typedef int (__stdcall* GetNumber)(void*, int*);

int number;

// Validate the first call on IInternalInterface1
EXPECT_EQ(reinterpret_cast<GetNumber>((*reinterpret_cast<void***>(internalInterface1.get()))[3])(internalInterface1.get(), &number), S_OK);
EXPECT_EQ(number, 42);

// Validate the second call on IInternalInterface2
EXPECT_EQ(reinterpret_cast<GetNumber>((*reinterpret_cast<void***>(internalInterface2.get()))[3])(internalInterface2.get(), &number), S_OK);
EXPECT_EQ(number, 123);
}
141 changes: 141 additions & 0 deletions src/Tests/AuthoringTest/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,17 @@
using System.ComponentModel.DataAnnotations;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
using System.Windows.Input;
using Windows.Foundation;
using Windows.Foundation.Collections;
using Windows.Foundation.Metadata;
using Windows.Graphics.Effects;
using WinRT;
using WinRT.Interop;

#pragma warning disable CA1416

Expand Down Expand Up @@ -1569,4 +1574,140 @@ public partial struct PartialStruct
{
public double Z;
}

public sealed class TestMixedWinRTCOMWrapper : IGraphicsEffectSource, IPublicInterface, IInternalInterface1, SomeInternalType.IInternalInterface2
{
public string HelloWorld()
{
return "Hello from mixed WinRT/COM";
}

unsafe int IInternalInterface1.GetNumber(int* value)
{
*value = 42;

return 0;
}

unsafe int SomeInternalType.IInternalInterface2.GetNumber(int* value)
{
*value = 123;

return 0;
}
}

public interface IPublicInterface
{
string HelloWorld();
}

// Internal, classic COM interface
[global::System.Runtime.InteropServices.Guid("C7850559-8FF2-4E54-A237-6ED813F20CDC")]
[WindowsRuntimeType]
[WindowsRuntimeHelperType(typeof(IInternalInterface1))]
internal unsafe interface IInternalInterface1
{
int GetNumber(int* value);

[global::System.Runtime.InteropServices.Guid("C7850559-8FF2-4E54-A237-6ED813F20CDC")]
public struct Vftbl
{
public static readonly IntPtr AbiToProjectionVftablePtr = InitVtbl();

private static IntPtr InitVtbl()
{
Vftbl* lpVtbl = (Vftbl*)ComWrappersSupport.AllocateVtableMemory(typeof(Vftbl), sizeof(Vftbl));

lpVtbl->IUnknownVftbl = IUnknownVftbl.AbiToProjectionVftbl;
lpVtbl->GetNumber = &GetDeviceFromAbi;

return (IntPtr)lpVtbl;
}

private IUnknownVftbl IUnknownVftbl;
private delegate* unmanaged[Stdcall]<void*, int*, int> GetNumber;

[UnmanagedCallersOnly(CallConvs = new[] { typeof(CallConvStdcall) })]
private static int GetDeviceFromAbi(void* thisPtr, int* value)
Sergio0694 marked this conversation as resolved.
Show resolved Hide resolved
{
try
{
return ComWrappersSupport.FindObject<IInternalInterface1>((IntPtr)thisPtr).GetNumber(value);
}
catch (Exception e)
{
ExceptionHelpers.SetErrorInfo(e);

return Marshal.GetHRForException(e);
}
}
}
}

internal struct SomeInternalType
{
// Nested, classic COM interface
[global::System.Runtime.InteropServices.Guid("8A08E18A-8D20-4E7C-9242-857BFE1E3159")]
[WindowsRuntimeType]
[WindowsRuntimeHelperType(typeof(IInternalInterface2))]
public unsafe interface IInternalInterface2
{
int GetNumber(int* value);

[global::System.Runtime.InteropServices.Guid("8A08E18A-8D20-4E7C-9242-857BFE1E3159")]
public struct Vftbl
{
public static readonly IntPtr AbiToProjectionVftablePtr = InitVtbl();

private static IntPtr InitVtbl()
{
Vftbl* lpVtbl = (Vftbl*)ComWrappersSupport.AllocateVtableMemory(typeof(Vftbl), sizeof(Vftbl));

lpVtbl->IUnknownVftbl = IUnknownVftbl.AbiToProjectionVftbl;
lpVtbl->GetNumber = &GetDeviceFromAbi;

return (IntPtr)lpVtbl;
}

private IUnknownVftbl IUnknownVftbl;
private delegate* unmanaged[Stdcall]<void*, int*, int> GetNumber;

[UnmanagedCallersOnly(CallConvs = new[] { typeof(CallConvStdcall) })]
private static int GetDeviceFromAbi(void* thisPtr, int* value)
{
try
{
return ComWrappersSupport.FindObject<IInternalInterface2>((IntPtr)thisPtr).GetNumber(value);
}
catch (Exception e)
{
ExceptionHelpers.SetErrorInfo(e);

return Marshal.GetHRForException(e);
}
}
}
}
}
}

namespace ABI.AuthoringTest
{
internal static class IInternalInterface1Methods
{
public static Guid IID => typeof(global::AuthoringTest.IInternalInterface1).GUID;

public static IntPtr AbiToProjectionVftablePtr => global::AuthoringTest.IInternalInterface1.Vftbl.AbiToProjectionVftablePtr;
}

internal struct SomeInternalType
{
internal static class IInternalInterface2Methods
{
public static Guid IID => typeof(global::AuthoringTest.SomeInternalType.IInternalInterface2).GUID;

public static IntPtr AbiToProjectionVftablePtr => global::AuthoringTest.SomeInternalType.IInternalInterface2.Vftbl.AbiToProjectionVftablePtr;
}
}
}
Loading