Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid caching compilation data and use value equality for SyntaxNodes #79051

Merged
merged 7 commits into from
Jan 3, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ namespace Microsoft.Interop
/// <summary>
/// VirtualMethodIndexAttribute data
/// </summary>
internal sealed record VirtualMethodIndexData(int Index) : InteropAttributeData
internal sealed record VirtualMethodIndexData(int Index) : InteropAttributeCompilationData
{
public bool ImplicitThisParameter { get; init; }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ namespace Microsoft.Interop
{
internal static class DefaultMarshallingInfoParser
{
public static MarshallingInfoParser Create(StubEnvironment env, IGeneratorDiagnostics diagnostics, IMethodSymbol method, InteropAttributeData interopAttributeData, AttributeData unparsedAttributeData)
public static MarshallingInfoParser Create(StubEnvironment env, IGeneratorDiagnostics diagnostics, IMethodSymbol method, InteropAttributeCompilationData interopAttributeData, AttributeData unparsedAttributeData)
{

// Compute the current default string encoding value.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,9 @@ private static bool HasUnsupportedMarshalAsInfo(TypePositionInfo info)
|| unmanagedType == UnmanagedType.SafeArray;
}

private static InteropAttributeData CreateInteropAttributeDataFromDllImport(DllImportData dllImportData)
private static InteropAttributeCompilationData CreateInteropAttributeDataFromDllImport(DllImportData dllImportData)
{
InteropAttributeData interopData = new();
InteropAttributeCompilationData interopData = new();
if (dllImportData.SetLastError)
{
interopData = interopData with { IsUserDefined = interopData.IsUserDefined | InteropAttributeMember.SetLastError, SetLastError = true };
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,27 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Runtime.InteropServices;
using Microsoft.CodeAnalysis;

namespace Microsoft.Interop
{
/// <summary>
/// LibraryImportAttribute data
jtschuster marked this conversation as resolved.
Show resolved Hide resolved
/// </summary>
internal sealed record LibraryImportData(string ModuleName) : InteropAttributeData
internal sealed record LibraryImportCompilationData(string ModuleName) : InteropAttributeCompilationData
{
public string EntryPoint { get; init; }
}

internal sealed record LibraryImportData(string ModuleName) : InteropAttributeModelData
jtschuster marked this conversation as resolved.
Show resolved Hide resolved
{
public string EntryPoint { get; init; }

public static LibraryImportData From(LibraryImportCompilationData libraryImport)
=> new LibraryImportData(libraryImport.ModuleName) with
{
EntryPoint = libraryImport.EntryPoint,
IsUserDefined = libraryImport.IsUserDefined,
SetLastError = libraryImport.SetLastError,
StringMarshalling = libraryImport.StringMarshalling
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
using System.Diagnostics;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
Expand Down Expand Up @@ -181,7 +180,7 @@ private static MemberDeclarationSyntax PrintGeneratedSource(
.WithBody(stubCode);
}

private static LibraryImportData? ProcessLibraryImportAttribute(AttributeData attrData)
private static LibraryImportCompilationData? ProcessLibraryImportAttribute(AttributeData attrData)
{
// Found the LibraryImport, but it has an error so report the error.
// This is most likely an issue with targeting an incorrect TFM.
Expand All @@ -198,7 +197,7 @@ private static MemberDeclarationSyntax PrintGeneratedSource(
ImmutableDictionary<string, TypedConstant> namedArguments = ImmutableDictionary.CreateRange(attrData.NamedArguments);

string? entryPoint = null;
if (namedArguments.TryGetValue(nameof(LibraryImportData.EntryPoint), out TypedConstant entryPointValue))
if (namedArguments.TryGetValue(nameof(LibraryImportCompilationData.EntryPoint), out TypedConstant entryPointValue))
{
if (entryPointValue.Value is not string)
{
Expand All @@ -207,7 +206,7 @@ private static MemberDeclarationSyntax PrintGeneratedSource(
entryPoint = (string)entryPointValue.Value!;
}

return new LibraryImportData(attrData.ConstructorArguments[0].Value!.ToString())
return new LibraryImportCompilationData(attrData.ConstructorArguments[0].Value!.ToString())
{
EntryPoint = entryPoint,
}.WithValuesFromNamedArguments(namedArguments);
Expand Down Expand Up @@ -261,9 +260,9 @@ private static IncrementalStubGenerationContext CalculateStubInformation(
var generatorDiagnostics = new GeneratorDiagnostics();

// Process the LibraryImport attribute
LibraryImportData libraryImportData =
LibraryImportCompilationData libraryImportData =
ProcessLibraryImportAttribute(generatedDllImportAttr!) ??
new LibraryImportData("INVALID_CSHARP_SYNTAX");
new LibraryImportCompilationData("INVALID_CSHARP_SYNTAX");

if (libraryImportData.IsUserDefined.HasFlag(InteropAttributeMember.StringMarshalling))
{
Expand Down Expand Up @@ -302,7 +301,7 @@ private static IncrementalStubGenerationContext CalculateStubInformation(
methodSyntaxTemplate,
new MethodSignatureDiagnosticLocations(originalSyntax),
new SequenceEqualImmutableArray<AttributeSyntax>(additionalAttributes.ToImmutableArray(), SyntaxEquivalentComparer.Instance),
libraryImportData,
LibraryImportData.From(libraryImportData),
LibraryImportGeneratorHelpers.CreateGeneratorFactory(environment, options),
new SequenceEqualImmutableArray<Diagnostic>(generatorDiagnostics.Diagnostics.ToImmutableArray())
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

using System;
using System.Collections.Immutable;
using System.Runtime.InteropServices;
using Microsoft.CodeAnalysis;

namespace Microsoft.Interop
Expand All @@ -24,27 +23,30 @@ public enum InteropAttributeMember
/// <summary>
/// Common data for all source-generated-interop trigger attributes
/// </summary>
public record InteropAttributeData
public record InteropAttributeCompilationData : InteropAttributeModelData
jtschuster marked this conversation as resolved.
Show resolved Hide resolved
{
public INamedTypeSymbol? StringMarshallingCustomType { get; init; }
}
public record InteropAttributeModelData
jtschuster marked this conversation as resolved.
Show resolved Hide resolved
{
/// <summary>
/// Value set by the user on the original declaration.
/// </summary>
public InteropAttributeMember IsUserDefined { get; init; }
public bool SetLastError { get; init; }
public StringMarshalling StringMarshalling { get; init; }
public INamedTypeSymbol? StringMarshallingCustomType { get; init; }
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd add remarks to the InteropAttributeData doc about how it should not include any symbols / types that would keep a compilation alive.


public static class InteropAttributeDataExtensions
{
public static T WithValuesFromNamedArguments<T>(this T t, ImmutableDictionary<string, TypedConstant> namedArguments) where T : InteropAttributeData
public static T WithValuesFromNamedArguments<T>(this T t, ImmutableDictionary<string, TypedConstant> namedArguments) where T : InteropAttributeCompilationData
{
InteropAttributeMember userDefinedValues = InteropAttributeMember.None;
bool setLastError = false;
StringMarshalling stringMarshalling = StringMarshalling.Custom;
INamedTypeSymbol? stringMarshallingCustomType = null;

if (namedArguments.TryGetValue(nameof(InteropAttributeData.SetLastError), out TypedConstant setLastErrorValue))
if (namedArguments.TryGetValue(nameof(InteropAttributeCompilationData.SetLastError), out TypedConstant setLastErrorValue))
{
userDefinedValues |= InteropAttributeMember.SetLastError;
if (setLastErrorValue.Value is not bool)
Expand All @@ -53,7 +55,7 @@ public static T WithValuesFromNamedArguments<T>(this T t, ImmutableDictionary<st
}
setLastError = (bool)setLastErrorValue.Value!;
}
if (namedArguments.TryGetValue(nameof(InteropAttributeData.StringMarshalling), out TypedConstant stringMarshallingValue))
if (namedArguments.TryGetValue(nameof(InteropAttributeCompilationData.StringMarshalling), out TypedConstant stringMarshallingValue))
{
userDefinedValues |= InteropAttributeMember.StringMarshalling;
// TypedConstant's Value property only contains primitive values.
Expand All @@ -64,7 +66,7 @@ public static T WithValuesFromNamedArguments<T>(this T t, ImmutableDictionary<st
// A boxed primitive can be unboxed to an enum with the same underlying type.
stringMarshalling = (StringMarshalling)stringMarshallingValue.Value!;
}
if (namedArguments.TryGetValue(nameof(InteropAttributeData.StringMarshallingCustomType), out TypedConstant stringMarshallingCustomTypeValue))
if (namedArguments.TryGetValue(nameof(InteropAttributeCompilationData.StringMarshallingCustomType), out TypedConstant stringMarshallingCustomTypeValue))
{
userDefinedValues |= InteropAttributeMember.StringMarshallingCustomType;
if (stringMarshallingCustomTypeValue.Value is not INamedTypeSymbol)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using System;
using System.Collections.Generic;
using System.Text;

namespace Microsoft.Interop
{
Expand All @@ -18,6 +15,19 @@ public abstract record ManagedTypeInfo(string FullTypeName, string DiagnosticFor
private TypeSyntax? _syntax;
public TypeSyntax Syntax => _syntax ??= SyntaxFactory.ParseTypeName(FullTypeName);

public virtual bool Equals(ManagedTypeInfo other)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should ManagedTypeInfo? be the parameter? Asking because the implementation checks for null.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There shouldn't be a null check, thanks for pointing that out!

{
return other is not null
&& Syntax.IsEquivalentTo(other.Syntax)
&& FullTypeName == other.FullTypeName
&& DiagnosticFormattedName == other.DiagnosticFormattedName;
}

public override int GetHashCode()
{
return Syntax.GetHashCode() ^ FullTypeName.GetHashCode() ^ DiagnosticFormattedName.GetHashCode();
jtschuster marked this conversation as resolved.
Show resolved Hide resolved
}

protected ManagedTypeInfo(ManagedTypeInfo original)
{
FullTypeName = original.FullTypeName;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,22 @@ public readonly record struct CustomTypeMarshallerData(
public readonly record struct CustomTypeMarshallers(
ImmutableDictionary<MarshalMode, CustomTypeMarshallerData> Modes)
{
public bool Equals(CustomTypeMarshallers other)
{
return Modes.Count == other.Modes.Count
&& !Modes.Except(other.Modes).Any();
jtschuster marked this conversation as resolved.
Show resolved Hide resolved
}

public override int GetHashCode()
{
int hash = 0;
foreach (KeyValuePair<MarshalMode, CustomTypeMarshallerData> mode in Modes)
{
hash ^= mode.Key.GetHashCode() ^ mode.Value.GetHashCode();
}
return hash;
}

public CustomTypeMarshallerData GetModeOrDefault(MarshalMode mode)
{
CustomTypeMarshallerData data;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -964,7 +964,7 @@ public static partial void Method(
+ CustomCollectionMarshallingCodeSnippets<CodeSnippets>.Stateless.In
+ CustomCollectionMarshallingCodeSnippets<CodeSnippets>.CustomIntMarshaller;

public static string RecursiveCountElementNameOnReturnValue => $@"
public static string RecursiveCountElementNameOnReturnValue => $@"
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;
{DisableRuntimeMarshalling}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.Text;
using Microsoft.Interop.UnitTests;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.Interop.UnitTests;
using Xunit;
using static Microsoft.Interop.LibraryImportGenerator;

Expand Down Expand Up @@ -212,13 +210,20 @@ public async Task ChangingMarshallingAttributes_SameStrategy_DoesNotRegenerate()
});
}

public static IEnumerable<object[]> CompilationObjectLivenessSources()
{
// Basic stub
yield return new[] { CodeSnippets.BasicParametersAndModifiers<int>() };
// Stub with custom string marshaller
yield return new[] { CodeSnippets.CustomStringMarshallingParametersAndModifiers<string>() };
}

// This test requires precise GC to ensure that we're accurately testing that we aren't
// keeping the Compilation alive.
[ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsPreciseGcSupported))]
public async Task GeneratorRun_WithNewCompilation_DoesNotKeepOldCompilationAlive()
[MemberData(nameof(CompilationObjectLivenessSources))]
[ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsPreciseGcSupported))]
public async Task GeneratorRun_WithNewCompilation_DoesNotKeepOldCompilationAlive(string source)
{
string source = $"namespace NS{{{CodeSnippets.BasicParametersAndModifiers<int>()}}}";

SyntaxTree syntaxTree = CSharpSyntaxTree.ParseText(source, new CSharpParseOptions(LanguageVersion.Preview));

Compilation comp1 = await TestUtils.CreateCompilation(new[] { syntaxTree });
Expand Down