Skip to content

Commit

Permalink
Avoid caching compilation data and use value equality for SyntaxNodes (
Browse files Browse the repository at this point in the history
…#79051)

Fixes #78242

Creates separate types for InteropAttributeData: one that holds compilation data (InteropAttributeCompilationData), and one that holds only the data necessary for the model to create the generated code (InteropAttributeModelData).

This uncovered some issues with record equality in records that use SyntaxNode. For those, we need to override Equals or wrap the SyntaxNode in a type that overrides Equals to use IsEquivalentTo on the SyntaxNode.

There are probably more places where we use SyntaxNode that aren't caught in the current tests.

To make sure every record has the right equality, I wasn't sure if it would be better to override Equals for each of the records, or create a wrapper record struct for each SyntaxNode that implements the equality we want (and implicit casts to and from the SyntaxNode). Then we wouldn't have to explicitly override the equality in each record that has a SyntaxNode. I also overrode both Equals and GetHashCode, but I'm not confident in my GetHashCode implementation. It could also be done with IEquatable.Equals without needing GetHashCode, but that would require implementing the TypeSyntax equality for every type that inherits from ManagedTypeInfo.
  • Loading branch information
jtschuster authored Jan 3, 2023
1 parent 795fcec commit 55e0dea
Show file tree
Hide file tree
Showing 10 changed files with 86 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ namespace Microsoft.Interop
/// <summary>
/// VirtualMethodIndexAttribute data
/// </summary>
internal sealed record VirtualMethodIndexData(int Index) : InteropAttributeData
internal sealed record VirtualMethodIndexData(int Index) : InteropAttributeCompilationData
{
public bool ImplicitThisParameter { get; init; }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ namespace Microsoft.Interop
{
internal static class DefaultMarshallingInfoParser
{
public static MarshallingInfoParser Create(StubEnvironment env, IGeneratorDiagnostics diagnostics, IMethodSymbol method, InteropAttributeData interopAttributeData, AttributeData unparsedAttributeData)
public static MarshallingInfoParser Create(StubEnvironment env, IGeneratorDiagnostics diagnostics, IMethodSymbol method, InteropAttributeCompilationData interopAttributeData, AttributeData unparsedAttributeData)
{

// Compute the current default string encoding value.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,9 @@ private static bool HasUnsupportedMarshalAsInfo(TypePositionInfo info)
|| unmanagedType == UnmanagedType.SafeArray;
}

private static InteropAttributeData CreateInteropAttributeDataFromDllImport(DllImportData dllImportData)
private static InteropAttributeCompilationData CreateInteropAttributeDataFromDllImport(DllImportData dllImportData)
{
InteropAttributeData interopData = new();
InteropAttributeCompilationData interopData = new();
if (dllImportData.SetLastError)
{
interopData = interopData with { IsUserDefined = interopData.IsUserDefined | InteropAttributeMember.SetLastError, SetLastError = true };
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,31 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Runtime.InteropServices;
using Microsoft.CodeAnalysis;

namespace Microsoft.Interop
{
/// <summary>
/// LibraryImportAttribute data
/// Contains the data related to a LibraryImportAttribute, without references to Roslyn symbols.
/// See <seealso cref="LibraryImportCompilationData"/> for a type with a reference to the StringMarshallingCustomType
/// </summary>
internal sealed record LibraryImportData(string ModuleName) : InteropAttributeData
{
public string EntryPoint { get; init; }

public static LibraryImportData From(LibraryImportCompilationData libraryImport)
=> new LibraryImportData(libraryImport.ModuleName) with
{
EntryPoint = libraryImport.EntryPoint,
IsUserDefined = libraryImport.IsUserDefined,
SetLastError = libraryImport.SetLastError,
StringMarshalling = libraryImport.StringMarshalling
};
}

/// <summary>
/// Contains the data related to a LibraryImportAttribute, with references to Roslyn symbols.
/// Use <seealso cref="LibraryImportData"/> instead when using for incremental compilation state to avoid keeping a compilation alive
/// </summary>
internal sealed record LibraryImportCompilationData(string ModuleName) : InteropAttributeCompilationData
{
public string EntryPoint { get; init; }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
using System.Diagnostics;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
Expand Down Expand Up @@ -181,7 +180,7 @@ private static MemberDeclarationSyntax PrintGeneratedSource(
.WithBody(stubCode);
}

private static LibraryImportData? ProcessLibraryImportAttribute(AttributeData attrData)
private static LibraryImportCompilationData? ProcessLibraryImportAttribute(AttributeData attrData)
{
// Found the LibraryImport, but it has an error so report the error.
// This is most likely an issue with targeting an incorrect TFM.
Expand All @@ -198,7 +197,7 @@ private static MemberDeclarationSyntax PrintGeneratedSource(
ImmutableDictionary<string, TypedConstant> namedArguments = ImmutableDictionary.CreateRange(attrData.NamedArguments);

string? entryPoint = null;
if (namedArguments.TryGetValue(nameof(LibraryImportData.EntryPoint), out TypedConstant entryPointValue))
if (namedArguments.TryGetValue(nameof(LibraryImportCompilationData.EntryPoint), out TypedConstant entryPointValue))
{
if (entryPointValue.Value is not string)
{
Expand All @@ -207,7 +206,7 @@ private static MemberDeclarationSyntax PrintGeneratedSource(
entryPoint = (string)entryPointValue.Value!;
}

return new LibraryImportData(attrData.ConstructorArguments[0].Value!.ToString())
return new LibraryImportCompilationData(attrData.ConstructorArguments[0].Value!.ToString())
{
EntryPoint = entryPoint,
}.WithValuesFromNamedArguments(namedArguments);
Expand Down Expand Up @@ -261,9 +260,9 @@ private static IncrementalStubGenerationContext CalculateStubInformation(
var generatorDiagnostics = new GeneratorDiagnostics();

// Process the LibraryImport attribute
LibraryImportData libraryImportData =
LibraryImportCompilationData libraryImportData =
ProcessLibraryImportAttribute(generatedDllImportAttr!) ??
new LibraryImportData("INVALID_CSHARP_SYNTAX");
new LibraryImportCompilationData("INVALID_CSHARP_SYNTAX");

if (libraryImportData.IsUserDefined.HasFlag(InteropAttributeMember.StringMarshalling))
{
Expand Down Expand Up @@ -302,7 +301,7 @@ private static IncrementalStubGenerationContext CalculateStubInformation(
methodSyntaxTemplate,
new MethodSignatureDiagnosticLocations(originalSyntax),
new SequenceEqualImmutableArray<AttributeSyntax>(additionalAttributes.ToImmutableArray(), SyntaxEquivalentComparer.Instance),
libraryImportData,
LibraryImportData.From(libraryImportData),
LibraryImportGeneratorHelpers.CreateGeneratorFactory(environment, options),
new SequenceEqualImmutableArray<Diagnostic>(generatorDiagnostics.Diagnostics.ToImmutableArray())
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

using System;
using System.Collections.Immutable;
using System.Runtime.InteropServices;
using Microsoft.CodeAnalysis;

namespace Microsoft.Interop
Expand All @@ -22,9 +21,24 @@ public enum InteropAttributeMember
}

/// <summary>
/// Common data for all source-generated-interop trigger attributes
/// Common data for all source-generated-interop trigger attributes.
/// This type and derived types should not have any reference that would keep a compilation alive.
/// </summary>
public record InteropAttributeData
{
/// <summary>
/// Value set by the user on the original declaration.
/// </summary>
public InteropAttributeMember IsUserDefined { get; init; }
public bool SetLastError { get; init; }
public StringMarshalling StringMarshalling { get; init; }
}

/// <summary>
/// Common data for all source-generated-interop trigger attributes that also includes a reference to the Roslyn symbol for StringMarshallingCustomType.
/// See <seealso cref="InteropAttributeData"/> for a type that doesn't keep a compilation alive.
/// </summary>
public record InteropAttributeCompilationData
{
/// <summary>
/// Value set by the user on the original declaration.
Expand All @@ -37,14 +51,14 @@ public record InteropAttributeData

public static class InteropAttributeDataExtensions
{
public static T WithValuesFromNamedArguments<T>(this T t, ImmutableDictionary<string, TypedConstant> namedArguments) where T : InteropAttributeData
public static T WithValuesFromNamedArguments<T>(this T t, ImmutableDictionary<string, TypedConstant> namedArguments) where T : InteropAttributeCompilationData
{
InteropAttributeMember userDefinedValues = InteropAttributeMember.None;
bool setLastError = false;
StringMarshalling stringMarshalling = StringMarshalling.Custom;
INamedTypeSymbol? stringMarshallingCustomType = null;

if (namedArguments.TryGetValue(nameof(InteropAttributeData.SetLastError), out TypedConstant setLastErrorValue))
if (namedArguments.TryGetValue(nameof(InteropAttributeCompilationData.SetLastError), out TypedConstant setLastErrorValue))
{
userDefinedValues |= InteropAttributeMember.SetLastError;
if (setLastErrorValue.Value is not bool)
Expand All @@ -53,7 +67,7 @@ public static T WithValuesFromNamedArguments<T>(this T t, ImmutableDictionary<st
}
setLastError = (bool)setLastErrorValue.Value!;
}
if (namedArguments.TryGetValue(nameof(InteropAttributeData.StringMarshalling), out TypedConstant stringMarshallingValue))
if (namedArguments.TryGetValue(nameof(InteropAttributeCompilationData.StringMarshalling), out TypedConstant stringMarshallingValue))
{
userDefinedValues |= InteropAttributeMember.StringMarshalling;
// TypedConstant's Value property only contains primitive values.
Expand All @@ -64,7 +78,7 @@ public static T WithValuesFromNamedArguments<T>(this T t, ImmutableDictionary<st
// A boxed primitive can be unboxed to an enum with the same underlying type.
stringMarshalling = (StringMarshalling)stringMarshallingValue.Value!;
}
if (namedArguments.TryGetValue(nameof(InteropAttributeData.StringMarshallingCustomType), out TypedConstant stringMarshallingCustomTypeValue))
if (namedArguments.TryGetValue(nameof(InteropAttributeCompilationData.StringMarshallingCustomType), out TypedConstant stringMarshallingCustomTypeValue))
{
userDefinedValues |= InteropAttributeMember.StringMarshallingCustomType;
if (stringMarshallingCustomTypeValue.Value is not INamedTypeSymbol)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using System;
using System.Collections.Generic;
using System.Text;

namespace Microsoft.Interop
{
Expand All @@ -18,6 +15,19 @@ public abstract record ManagedTypeInfo(string FullTypeName, string DiagnosticFor
private TypeSyntax? _syntax;
public TypeSyntax Syntax => _syntax ??= SyntaxFactory.ParseTypeName(FullTypeName);

public virtual bool Equals(ManagedTypeInfo? other)
{
return other is not null
&& Syntax.IsEquivalentTo(other.Syntax)
&& FullTypeName == other.FullTypeName
&& DiagnosticFormattedName == other.DiagnosticFormattedName;
}

public override int GetHashCode()
{
return FullTypeName.GetHashCode() ^ DiagnosticFormattedName.GetHashCode();
}

protected ManagedTypeInfo(ManagedTypeInfo original)
{
FullTypeName = original.FullTypeName;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,24 @@ public readonly record struct CustomTypeMarshallerData(
public readonly record struct CustomTypeMarshallers(
ImmutableDictionary<MarshalMode, CustomTypeMarshallerData> Modes)
{
public bool Equals(CustomTypeMarshallers other)
{
// Check for equal count, then check if any KeyValuePairs exist in one 'Modes'
// but not the other (i.e. set equality on the set of items in the dictionary)
return Modes.Count == other.Modes.Count
&& !Modes.Except(other.Modes).Any();
}

public override int GetHashCode()
{
int hash = 0;
foreach (KeyValuePair<MarshalMode, CustomTypeMarshallerData> mode in Modes)
{
hash ^= mode.Key.GetHashCode() ^ mode.Value.GetHashCode();
}
return hash;
}

public CustomTypeMarshallerData GetModeOrDefault(MarshalMode mode)
{
CustomTypeMarshallerData data;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -964,7 +964,7 @@ public static partial void Method(
+ CustomCollectionMarshallingCodeSnippets<CodeSnippets>.Stateless.In
+ CustomCollectionMarshallingCodeSnippets<CodeSnippets>.CustomIntMarshaller;

public static string RecursiveCountElementNameOnReturnValue => $@"
public static string RecursiveCountElementNameOnReturnValue => $@"
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;
{DisableRuntimeMarshalling}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.Text;
using Microsoft.Interop.UnitTests;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.Interop.UnitTests;
using Xunit;
using static Microsoft.Interop.LibraryImportGenerator;

Expand Down Expand Up @@ -217,8 +215,7 @@ public static IEnumerable<object[]> CompilationObjectLivenessSources()
// Basic stub
yield return new[] { CodeSnippets.BasicParametersAndModifiers<int>() };
// Stub with custom string marshaller
// TODO: Compilation is held alive by the CustomStringMarshallingType property in LibraryImportData
// yield return new[] { CodeSnippets.CustomStringMarshallingParametersAndModifiers<string>() };
yield return new[] { CodeSnippets.CustomStringMarshallingParametersAndModifiers<string>() };
}

// This test requires precise GC to ensure that we're accurately testing that we aren't
Expand Down

0 comments on commit 55e0dea

Please sign in to comment.