From 26cfb288bf73845d39a27eb759bca76d3c87d343 Mon Sep 17 00:00:00 2001 From: Tim M <49349513+TimothyMakkison@users.noreply.github.com> Date: Thu, 10 Oct 2024 18:42:23 +0100 Subject: [PATCH] feat: add incremental generator (#1864) --- InterfaceStubGenerator.Shared/Emitter.cs | 414 ++++++++++ .../ImmutableEquatableArray.cs | 93 +++ .../IncrementalValuesProviderExtensions.cs | 68 ++ .../InterfaceStubGenerator.Shared.projitems | 10 + .../InterfaceStubGenerator.cs | 768 +----------------- .../IsExternalInit.cs | 20 + .../Models/ContextGenerationModel.cs | 7 + .../Models/InterfaceModel.cs | 25 + .../Models/MethodModel.cs | 18 + .../Models/ParameterModel.cs | 8 + .../Models/TypeConstraint.cs | 19 + InterfaceStubGenerator.Shared/Parser.cs | 520 ++++++++++++ .../Incremental/FunctionTest.cs | 62 +- .../Incremental/GenericTest.cs | 8 +- .../IncrementalGeneratorRunReasons.cs | 2 +- .../Incremental/IncrementalTest.cs | 7 +- .../Incremental/InheritanceTest.cs | 7 +- Refit.GeneratorTests/InterfaceTests.cs | 2 +- 18 files changed, 1314 insertions(+), 744 deletions(-) create mode 100644 InterfaceStubGenerator.Shared/Emitter.cs create mode 100644 InterfaceStubGenerator.Shared/ImmutableEquatableArray.cs create mode 100644 InterfaceStubGenerator.Shared/IncrementalValuesProviderExtensions.cs create mode 100644 InterfaceStubGenerator.Shared/IsExternalInit.cs create mode 100644 InterfaceStubGenerator.Shared/Models/ContextGenerationModel.cs create mode 100644 InterfaceStubGenerator.Shared/Models/InterfaceModel.cs create mode 100644 InterfaceStubGenerator.Shared/Models/MethodModel.cs create mode 100644 InterfaceStubGenerator.Shared/Models/ParameterModel.cs create mode 100644 InterfaceStubGenerator.Shared/Models/TypeConstraint.cs create mode 100644 InterfaceStubGenerator.Shared/Parser.cs diff --git a/InterfaceStubGenerator.Shared/Emitter.cs b/InterfaceStubGenerator.Shared/Emitter.cs new file mode 100644 index 000000000..88b49e5fe --- /dev/null +++ b/InterfaceStubGenerator.Shared/Emitter.cs @@ -0,0 +1,414 @@ +using System.Linq; +using System.Text; +using Microsoft.CodeAnalysis.Text; + +namespace Refit.Generator; + +internal static class Emitter +{ + private const string TypeParameterVariableName = "______typeParameters"; + + public static void EmitSharedCode( + ContextGenerationModel model, + Action addSource + ) + { + if (model.Interfaces.Count == 0) + return; + + var attributeText = $$""" + + #pragma warning disable + namespace {{model.RefitInternalNamespace}} + { + [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + [global::System.AttributeUsage (global::System.AttributeTargets.Class | global::System.AttributeTargets.Struct | global::System.AttributeTargets.Enum | global::System.AttributeTargets.Constructor | global::System.AttributeTargets.Method | global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Event | global::System.AttributeTargets.Interface | global::System.AttributeTargets.Delegate)] + sealed class PreserveAttribute : global::System.Attribute + { + // + // Fields + // + public bool AllMembers; + + public bool Conditional; + } + } + #pragma warning restore + + """; + // add the attribute text + addSource("PreserveAttribute.g.cs", SourceText.From(attributeText, Encoding.UTF8)); + + var generatedClassText = $$""" + + #pragma warning disable + namespace Refit.Implementation + { + + /// + [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] + [global::System.Diagnostics.DebuggerNonUserCode] + [{{model.PreserveAttributeDisplayName}}] + [global::System.Reflection.Obfuscation(Exclude=true)] + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + internal static partial class Generated + { + #if NET5_0_OR_GREATER + [System.Runtime.CompilerServices.ModuleInitializer] + [System.Diagnostics.CodeAnalysis.DynamicDependency(System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.All, typeof(global::Refit.Implementation.Generated))] + public static void Initialize() + { + } + #endif + } + } + #pragma warning restore + + """; + addSource("Generated.g.cs", SourceText.From(generatedClassText, Encoding.UTF8)); + } + + public static string EmitInterface(InterfaceModel model) + { + var source = new StringBuilder(); + + // if nullability is supported emit the nullable directive + if (model.Nullability != Nullability.None) + { + source.Append("#nullable "); + source.Append(model.Nullability == Nullability.Enabled ? "enable" : "disable"); + } + + source.Append( + $@" +#pragma warning disable +namespace Refit.Implementation +{{ + + partial class Generated + {{ + + /// + [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] + [global::System.Diagnostics.DebuggerNonUserCode] + [{model.PreserveAttributeDisplayName}] + [global::System.Reflection.Obfuscation(Exclude=true)] + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + partial class {model.Ns}{model.ClassDeclaration} + : {model.InterfaceDisplayName}{GenerateConstraints(model.Constraints, false)} + + {{ + /// + public global::System.Net.Http.HttpClient Client {{ get; }} + readonly global::Refit.IRequestBuilder requestBuilder; + + /// + public {model.Ns}{model.ClassSuffix}(global::System.Net.Http.HttpClient client, global::Refit.IRequestBuilder requestBuilder) + {{ + Client = client; + this.requestBuilder = requestBuilder; + }} +" + ); + + var memberNames = new HashSet(model.MemberNames); + + // Handle Refit Methods + foreach (var method in model.RefitMethods) + { + WriteRefitMethod(source, method, true, memberNames); + } + + foreach (var method in model.DerivedRefitMethods) + { + WriteRefitMethod(source, method, false, memberNames); + } + + // Handle non-refit Methods that aren't static or properties or have a method body + foreach (var method in model.NonRefitMethods) + { + WriteNonRefitMethod(source, method); + } + + // Handle Dispose + if (model.DisposeMethod) + { + WriteDisposableMethod(source); + } + + source.Append( + @" + } + } +} + +#pragma warning restore +" + ); + return source.ToString(); + } + + /// + /// Generates the body of the Refit method + /// + /// + /// + /// True if directly from the type we're generating for, false for methods found on base interfaces + /// Contains the unique member names in the interface scope. + private static void WriteRefitMethod( + StringBuilder source, + MethodModel methodModel, + bool isTopLevel, + HashSet memberNames + ) + { + var parameterTypesExpression = GenerateTypeParameterExpression( + source, + methodModel, + memberNames + ); + + var returnType = methodModel.ReturnType; + var (isAsync, @return, configureAwait) = methodModel.ReturnTypeMetadata switch + { + ReturnTypeInfo.AsyncVoid => (true, "await (", ").ConfigureAwait(false)"), + ReturnTypeInfo.AsyncResult => (true, "return await (", ").ConfigureAwait(false)"), + ReturnTypeInfo.Return => (false, "return ", ""), + _ + => throw new ArgumentOutOfRangeException( + nameof(methodModel.ReturnTypeMetadata), + methodModel.ReturnTypeMetadata, + "Unsupported value." + ) + }; + + WriteMethodOpening(source, methodModel, !isTopLevel, isAsync); + + // Build the list of args for the array + var argArray = methodModel + .Parameters.AsArray() + .Select(static param => $"@{param.MetadataName}") + .ToArray(); + + // List of generic arguments + var genericArray = methodModel + .Constraints.AsArray() + .Select(static typeParam => $"typeof({typeParam.DeclaredName})") + .ToArray(); + + var argumentsArrayString = + argArray.Length == 0 + ? "global::System.Array.Empty()" + : $"new object[] {{ {string.Join(", ", argArray)} }}"; + + var genericString = + genericArray.Length > 0 + ? $", new global::System.Type[] {{ {string.Join(", ", genericArray)} }}" + : string.Empty; + + source.Append( + @$" + var ______arguments = {argumentsArrayString}; + var ______func = requestBuilder.BuildRestResultFuncForMethod(""{methodModel.Name}"", {parameterTypesExpression}{genericString} ); + + {@return}({returnType})______func(this.Client, ______arguments){configureAwait}; + " + ); + + WriteMethodClosing(source); + } + + private static void WriteNonRefitMethod(StringBuilder source, MethodModel methodModel) + { + WriteMethodOpening(source, methodModel, true); + + source.Append( + @" + throw new global::System.NotImplementedException(""Either this method has no Refit HTTP method attribute or you've used something other than a string literal for the 'path' argument.""); + " + ); + + WriteMethodClosing(source); + } + + // TODO: This assumes that the Dispose method is a void that takes no parameters. + // The previous version did not. + // Does the bool overload cause an issue here. + private static void WriteDisposableMethod(StringBuilder source) + { + source.Append( + """ + + + /// + void global::System.IDisposable.Dispose() + { + Client?.Dispose(); + } + """ + ); + } + + private static string GenerateTypeParameterExpression( + StringBuilder source, + MethodModel methodModel, + HashSet memberNames + ) + { + // use Array.Empty if method has no parameters. + if (methodModel.Parameters.Count == 0) + return "global::System.Array.Empty()"; + + // if one of the parameters is/contains a type parameter then it cannot be cached as it will change type between calls. + if (methodModel.Parameters.Any(x => x.IsGeneric)) + { + var typeEnumerable = methodModel.Parameters.Select(param => $"typeof({param.Type})"); + return $"new global::System.Type[] {{ {string.Join(", ", typeEnumerable)} }}"; + } + + // find a name and generate field declaration. + var typeParameterFieldName = UniqueName(TypeParameterVariableName, memberNames); + var types = string.Join(", ", methodModel.Parameters.Select(x => $"typeof({x.Type})")); + source.Append( + $$""" + + + private static readonly global::System.Type[] {{typeParameterFieldName}} = new global::System.Type[] {{{types}} }; + """ + ); + + return typeParameterFieldName; + } + + private static void WriteMethodOpening( + StringBuilder source, + MethodModel methodModel, + bool isExplicitInterface, + bool isAsync = false + ) + { + var visibility = !isExplicitInterface ? "public " : string.Empty; + var async = isAsync ? "async " : ""; + + source.Append( + @$" + + /// + {visibility}{async}{methodModel.ReturnType} " + ); + + if (isExplicitInterface) + { + source.Append(@$"{methodModel.ContainingType}."); + } + source.Append(@$"{methodModel.DeclaredMethod}("); + + if (methodModel.Parameters.Count > 0) + { + var list = new List(); + foreach (var param in methodModel.Parameters) + { + var annotation = param.Annotation; + + list.Add($@"{param.Type}{(annotation ? '?' : string.Empty)} @{param.MetadataName}"); + } + + source.Append(string.Join(", ", list)); + } + + source.Append( + @$"){GenerateConstraints(methodModel.Constraints, isExplicitInterface)} + {{" + ); + } + + private static void WriteMethodClosing(StringBuilder source) => source.Append(@" }"); + + private static string UniqueName(string name, HashSet methodNames) + { + var candidateName = name; + var counter = 0; + while (methodNames.Contains(candidateName)) + { + candidateName = $"{name}{counter}"; + counter++; + } + + methodNames.Add(candidateName); + return candidateName; + } + + private static string GenerateConstraints( + ImmutableEquatableArray typeParameters, + bool isOverrideOrExplicitImplementation + ) + { + var source = new StringBuilder(); + // Need to loop over the constraints and create them + foreach (var typeParameter in typeParameters) + { + WriteConstraintsForTypeParameter( + source, + typeParameter, + isOverrideOrExplicitImplementation + ); + } + + return source.ToString(); + } + + private static void WriteConstraintsForTypeParameter( + StringBuilder source, + TypeConstraint typeParameter, + bool isOverrideOrExplicitImplementation + ) + { + // Explicit interface implementations and overrides can only have class or struct constraints + + var parameters = new List(); + var knownConstraints = typeParameter.KnownTypeConstraint; + if (knownConstraints.HasFlag(KnownTypeConstraint.Class)) + { + parameters.Add("class"); + } + if ( + knownConstraints.HasFlag(KnownTypeConstraint.Unmanaged) + && !isOverrideOrExplicitImplementation + ) + { + parameters.Add("unmanaged"); + } + if (knownConstraints.HasFlag(KnownTypeConstraint.Struct)) + { + parameters.Add("struct"); + } + if ( + knownConstraints.HasFlag(KnownTypeConstraint.NotNull) + && !isOverrideOrExplicitImplementation + ) + { + parameters.Add("notnull"); + } + if (!isOverrideOrExplicitImplementation) + { + parameters.AddRange(typeParameter.Constraints); + } + + // new constraint has to be last + if ( + knownConstraints.HasFlag(KnownTypeConstraint.New) && !isOverrideOrExplicitImplementation + ) + { + parameters.Add("new()"); + } + + if (parameters.Count > 0) + { + source.Append( + @$" + where {typeParameter.TypeName} : {string.Join(", ", parameters)}" + ); + } + } +} diff --git a/InterfaceStubGenerator.Shared/ImmutableEquatableArray.cs b/InterfaceStubGenerator.Shared/ImmutableEquatableArray.cs new file mode 100644 index 000000000..23c1d9009 --- /dev/null +++ b/InterfaceStubGenerator.Shared/ImmutableEquatableArray.cs @@ -0,0 +1,93 @@ +using System.Collections; + +namespace Refit.Generator; + +internal static class ImmutableEquatableArray +{ + public static ImmutableEquatableArray Empty() + where T : IEquatable => ImmutableEquatableArray.Empty; + + public static ImmutableEquatableArray ToImmutableEquatableArray( + this IEnumerable? values + ) + where T : IEquatable => values == null ? Empty() : new(values); +} + +/// +/// Provides an immutable list implementation which implements sequence equality. +/// +internal sealed class ImmutableEquatableArray + : IEquatable>, + IReadOnlyList + where T : IEquatable +{ + public static ImmutableEquatableArray Empty { get; } = new(Array.Empty()); + + private readonly T[] _values; + public T this[int index] => _values[index]; + public int Count => _values.Length; + + public ImmutableEquatableArray(T[] values) => _values = values; + + public ImmutableEquatableArray(IEnumerable values) => _values = values.ToArray(); + + public T[] AsArray() => _values; + + public bool Equals(ImmutableEquatableArray? other) => + other != null && ((ReadOnlySpan)_values).SequenceEqual(other._values); + + public override bool Equals(object? obj) => + obj is ImmutableEquatableArray other && Equals(other); + + public override int GetHashCode() + { + var hash = 0; + foreach (T value in _values) + { + hash = Combine(hash, value.GetHashCode()); + } + + static int Combine(int h1, int h2) + { + // RyuJIT optimizes this to use the ROL instruction + // Related GitHub pull request: https://github.com/dotnet/coreclr/pull/1830 + uint rol5 = ((uint)h1 << 5) | ((uint)h1 >> 27); + return ((int)rol5 + h1) ^ h2; + } + + return hash; + } + + public Enumerator GetEnumerator() => new(_values); + + IEnumerator IEnumerable.GetEnumerator() => ((IEnumerable)_values).GetEnumerator(); + + IEnumerator IEnumerable.GetEnumerator() => _values.GetEnumerator(); + + public struct Enumerator + { + private readonly T[] _values; + private int _index; + + internal Enumerator(T[] values) + { + _values = values; + _index = -1; + } + + public bool MoveNext() + { + var newIndex = _index + 1; + + if ((uint)newIndex < (uint)_values.Length) + { + _index = newIndex; + return true; + } + + return false; + } + + public readonly T Current => _values[_index]; + } +} diff --git a/InterfaceStubGenerator.Shared/IncrementalValuesProviderExtensions.cs b/InterfaceStubGenerator.Shared/IncrementalValuesProviderExtensions.cs new file mode 100644 index 000000000..0766c4c93 --- /dev/null +++ b/InterfaceStubGenerator.Shared/IncrementalValuesProviderExtensions.cs @@ -0,0 +1,68 @@ +#if ROSLYN_4 +using System.Text; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.Text; + +namespace Refit.Generator; + +internal static class IncrementalValuesProviderExtensions +{ + /// + /// Registers an output node into an to output a diagnostic. + /// + /// The input instance. + /// The input sequence of diagnostics. + public static void ReportDiagnostics( + this IncrementalGeneratorInitializationContext context, + IncrementalValuesProvider diagnostic + ) + { + context.RegisterSourceOutput( + diagnostic, + static (context, diagnostic) => context.ReportDiagnostic(diagnostic) + ); + } + + /// + /// Registers an output node into an to output diagnostics. + /// + /// The input instance. + /// The input sequence of diagnostics. + public static void ReportDiagnostics( + this IncrementalGeneratorInitializationContext context, + IncrementalValueProvider> diagnostics + ) + { + context.RegisterSourceOutput( + diagnostics, + static (context, diagnostics) => + { + foreach (var diagnostic in diagnostics) + { + context.ReportDiagnostic(diagnostic); + } + } + ); + } + + /// + /// Registers an implementation source output for the provided mappers. + /// + /// The context, on which the output is registered. + /// The interfaces stubs. + public static void EmitSource( + this IncrementalGeneratorInitializationContext context, + IncrementalValuesProvider model + ) + { + context.RegisterImplementationSourceOutput( + model, + static (spc, model) => + { + var mapperText = Emitter.EmitInterface(model); + spc.AddSource(model.FileName, SourceText.From(mapperText, Encoding.UTF8)); + } + ); + } +} +#endif diff --git a/InterfaceStubGenerator.Shared/InterfaceStubGenerator.Shared.projitems b/InterfaceStubGenerator.Shared/InterfaceStubGenerator.Shared.projitems index 7b3b9302c..ea53907fe 100644 --- a/InterfaceStubGenerator.Shared/InterfaceStubGenerator.Shared.projitems +++ b/InterfaceStubGenerator.Shared/InterfaceStubGenerator.Shared.projitems @@ -10,7 +10,17 @@ + + + + + + + + + + \ No newline at end of file diff --git a/InterfaceStubGenerator.Shared/InterfaceStubGenerator.cs b/InterfaceStubGenerator.Shared/InterfaceStubGenerator.cs index 0b627f162..d33ffc17d 100644 --- a/InterfaceStubGenerator.Shared/InterfaceStubGenerator.cs +++ b/InterfaceStubGenerator.Shared/InterfaceStubGenerator.cs @@ -43,735 +43,37 @@ public void Execute(GeneratorExecutionContext context) out var refitInternalNamespace ); - GenerateInterfaceStubs( - context, - static (context, diagnostic) => context.ReportDiagnostic(diagnostic), - static (context, hintName, sourceText) => context.AddSource(hintName, sourceText), + var parseStep = Parser.GenerateInterfaceStubs( (CSharpCompilation)context.Compilation, refitInternalNamespace, receiver.CandidateMethods.ToImmutableArray(), - receiver.CandidateInterfaces.ToImmutableArray() + receiver.CandidateInterfaces.ToImmutableArray(), + context.CancellationToken ); - } -#endif - - /// - /// Generates the interface stubs. - /// - /// The type of the context. - /// The context. - /// The report diagnostic. - /// The add source. - /// The compilation. - /// The refit internal namespace. - /// The candidate methods. - /// The candidate interfaces. - /// - public void GenerateInterfaceStubs( - TContext context, - Action reportDiagnostic, - Action addSource, - CSharpCompilation compilation, - string? refitInternalNamespace, - ImmutableArray candidateMethods, - ImmutableArray candidateInterfaces - ) - { - if (compilation == null) - throw new ArgumentNullException(nameof(compilation)); - - if (reportDiagnostic == null) - throw new ArgumentNullException(nameof(reportDiagnostic)); - - if (addSource == null) - throw new ArgumentNullException(nameof(addSource)); - - refitInternalNamespace = - $"{refitInternalNamespace ?? string.Empty}RefitInternalGenerated"; - - // we're going to create a new compilation that contains the attribute. - // TODO: we should allow source generators to provide source during initialize, so that this step isn't required. - var options = (CSharpParseOptions)compilation.SyntaxTrees[0].Options; - - var disposableInterfaceSymbol = compilation.GetTypeByMetadataName( - "System.IDisposable" - )!; - var httpMethodBaseAttributeSymbol = compilation.GetTypeByMetadataName( - "Refit.HttpMethodAttribute" - ); - - if (httpMethodBaseAttributeSymbol == null) - { - reportDiagnostic(context, Diagnostic.Create(DiagnosticDescriptors.RefitNotReferenced, null)); - return; - } - - // Check the candidates and keep the ones we're actually interested in - -#pragma warning disable RS1024 // Compare symbols correctly - var interfaceToNullableEnabledMap = new Dictionary( - SymbolEqualityComparer.Default - ); -#pragma warning restore RS1024 // Compare symbols correctly - var methodSymbols = new List(); - foreach (var group in candidateMethods.GroupBy(m => m.SyntaxTree)) - { - var model = compilation.GetSemanticModel(group.Key); - foreach (var method in group) - { - // Get the symbol being declared by the method - var methodSymbol = model.GetDeclaredSymbol(method); - if (IsRefitMethod(methodSymbol, httpMethodBaseAttributeSymbol)) - { - var isAnnotated = - compilation.Options.NullableContextOptions - == NullableContextOptions.Enable - || model.GetNullableContext(method.SpanStart) - == NullableContext.Enabled; - interfaceToNullableEnabledMap[methodSymbol!.ContainingType] = isAnnotated; - - methodSymbols.Add(methodSymbol!); - } - } - } - - var interfaces = methodSymbols - .GroupBy( - m => m.ContainingType, - SymbolEqualityComparer.Default - ) - .ToDictionary, INamedTypeSymbol, List>( - g => g.Key, - v => [.. v], - SymbolEqualityComparer.Default - ); - - // Look through the candidate interfaces - var interfaceSymbols = new List(); - foreach (var group in candidateInterfaces.GroupBy(i => i.SyntaxTree)) - { - var model = compilation.GetSemanticModel(group.Key); - foreach (var iface in group) - { - // get the symbol belonging to the interface - var ifaceSymbol = model.GetDeclaredSymbol(iface); - - // See if we already know about it, might be a dup - if (ifaceSymbol is null || interfaces.ContainsKey(ifaceSymbol)) - continue; - - // The interface has no refit methods, but its base interfaces might - var hasDerivedRefit = ifaceSymbol - .AllInterfaces.SelectMany(i => i.GetMembers().OfType()) - .Any(m => IsRefitMethod(m, httpMethodBaseAttributeSymbol)); - - if (hasDerivedRefit) - { - // Add the interface to the generation list with an empty set of methods - // The logic already looks for base refit methods - interfaces.Add(ifaceSymbol, []); - var isAnnotated = - model.GetNullableContext(iface.SpanStart) == NullableContext.Enabled; - - interfaceToNullableEnabledMap[ifaceSymbol] = isAnnotated; - } - } - } - - // Bail out if there aren't any interfaces to generate code for. This may be the case with transitives - if (interfaces.Count == 0) - return; - - var supportsNullable = options.LanguageVersion >= LanguageVersion.CSharp8; - - var keyCount = new Dictionary(); - - var attributeText = - @$" -#pragma warning disable -namespace {refitInternalNamespace} -{{ - [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] - [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] - [global::System.AttributeUsage (global::System.AttributeTargets.Class | global::System.AttributeTargets.Struct | global::System.AttributeTargets.Enum | global::System.AttributeTargets.Constructor | global::System.AttributeTargets.Method | global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Event | global::System.AttributeTargets.Interface | global::System.AttributeTargets.Delegate)] - sealed class PreserveAttribute : global::System.Attribute - {{ - // - // Fields - // - public bool AllMembers; - - public bool Conditional; - }} -}} -#pragma warning restore -"; - - compilation = compilation.AddSyntaxTrees( - CSharpSyntaxTree.ParseText(SourceText.From(attributeText, Encoding.UTF8), options) - ); - - // add the attribute text - addSource( - context, - "PreserveAttribute.g.cs", - SourceText.From(attributeText, Encoding.UTF8) - ); - - // get the newly bound attribute - var preserveAttributeSymbol = compilation.GetTypeByMetadataName( - $"{refitInternalNamespace}.PreserveAttribute" - )!; - - var generatedClassText = - @$" -#pragma warning disable -namespace Refit.Implementation -{{ - - /// - [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] - [global::System.Diagnostics.DebuggerNonUserCode] - [{preserveAttributeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}] - [global::System.Reflection.Obfuscation(Exclude=true)] - [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] - internal static partial class Generated - {{ -#if NET5_0_OR_GREATER - [System.Runtime.CompilerServices.ModuleInitializer] - [System.Diagnostics.CodeAnalysis.DynamicDependency(System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.All, typeof(global::Refit.Implementation.Generated))] - public static void Initialize() - {{ - }} -#endif - }} -}} -#pragma warning restore -"; - addSource( - context, - "Generated.g.cs", - SourceText.From(generatedClassText, Encoding.UTF8) - ); - - compilation = compilation.AddSyntaxTrees( - CSharpSyntaxTree.ParseText( - SourceText.From(generatedClassText, Encoding.UTF8), - options - ) - ); - - // group the fields by interface and generate the source - foreach (var group in interfaces) - { - // each group is keyed by the Interface INamedTypeSymbol and contains the members - // with a refit attribute on them. Types may contain other members, without the attribute, which we'll - // need to check for and error out on - - var classSource = ProcessInterface( - context, - reportDiagnostic, - group.Key, - group.Value, - preserveAttributeSymbol, - disposableInterfaceSymbol, - httpMethodBaseAttributeSymbol, - supportsNullable, - interfaceToNullableEnabledMap[group.Key] - ); - - var keyName = group.Key.Name; - int value; - while (keyCount.TryGetValue(keyName, out value)) - { - keyName = $"{keyName}{++value}"; - } - keyCount[keyName] = value; - - addSource(context, $"{keyName}.g.cs", SourceText.From(classSource, Encoding.UTF8)); - } - } - - static string ProcessInterface( - TContext context, - Action reportDiagnostic, - INamedTypeSymbol interfaceSymbol, - List refitMethods, - ISymbol preserveAttributeSymbol, - ISymbol disposableInterfaceSymbol, - INamedTypeSymbol httpMethodBaseAttributeSymbol, - bool supportsNullable, - bool nullableEnabled - ) - { - // Get the class name with the type parameters, then remove the namespace - var className = interfaceSymbol.ToDisplayString(); - var lastDot = className.LastIndexOf('.'); - if (lastDot > 0) - { - className = className.Substring(lastDot + 1); - } - var classDeclaration = $"{interfaceSymbol.ContainingType?.Name}{className}"; - - // Get the class name itself - var classSuffix = $"{interfaceSymbol.ContainingType?.Name}{interfaceSymbol.Name}"; - var ns = interfaceSymbol.ContainingNamespace?.ToDisplayString(); - - // if it's the global namespace, our lookup rules say it should be the same as the class name - if ( - interfaceSymbol.ContainingNamespace != null - && interfaceSymbol.ContainingNamespace.IsGlobalNamespace - ) - { - ns = string.Empty; - } - - // Remove dots - ns = ns!.Replace(".", ""); - - // See what the nullable context is - - - var source = new StringBuilder(); - if (supportsNullable) - { - source.Append("#nullable "); - - if (nullableEnabled) - { - source.Append("enable"); - } - else - { - source.Append("disable"); - } - } - - source.Append( - $@" -#pragma warning disable -namespace Refit.Implementation -{{ - - partial class Generated - {{ - - /// - [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] - [global::System.Diagnostics.DebuggerNonUserCode] - [{preserveAttributeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}] - [global::System.Reflection.Obfuscation(Exclude=true)] - [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] - partial class {ns}{classDeclaration} - : {interfaceSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}{GenerateConstraints(interfaceSymbol.TypeParameters, false)} - - {{ - /// - public global::System.Net.Http.HttpClient Client {{ get; }} - readonly global::Refit.IRequestBuilder requestBuilder; - - /// - public {ns}{classSuffix}(global::System.Net.Http.HttpClient client, global::Refit.IRequestBuilder requestBuilder) - {{ - Client = client; - this.requestBuilder = requestBuilder; - }} -" - ); - // Get any other methods on the refit interfaces. We'll need to generate something for them and warn - var nonRefitMethods = interfaceSymbol - .GetMembers() - .OfType() - .Except(refitMethods, SymbolEqualityComparer.Default) - .Cast() - .ToList(); - - // get methods for all inherited - var derivedMethods = interfaceSymbol - .AllInterfaces.SelectMany(i => i.GetMembers().OfType()) - .ToList(); - - // Look for disposable - var disposeMethod = derivedMethods.Find( - m => - m.ContainingType?.Equals( - disposableInterfaceSymbol, - SymbolEqualityComparer.Default - ) == true - ); - if (disposeMethod != null) - { - //remove it from the derived methods list so we don't process it with the rest - derivedMethods.Remove(disposeMethod); - } - - // Pull out the refit methods from the derived types - var derivedRefitMethods = derivedMethods - .Where(m => IsRefitMethod(m, httpMethodBaseAttributeSymbol)) - .ToList(); - var derivedNonRefitMethods = derivedMethods - .Except(derivedMethods, SymbolEqualityComparer.Default) - .Cast() - .ToList(); - - var memberNames = new HashSet(interfaceSymbol.GetMembers().Select(x => x.Name)); - - // Handle Refit Methods - foreach (var method in refitMethods) - { - ProcessRefitMethod(source, method, true, memberNames); - } - - foreach (var method in refitMethods.Concat(derivedRefitMethods)) - { - ProcessRefitMethod(source, method, false, memberNames); - } - - // Handle non-refit Methods that aren't static or properties or have a method body - foreach (var method in nonRefitMethods.Concat(derivedNonRefitMethods)) - { - if ( - method.IsStatic - || method.MethodKind == MethodKind.PropertyGet - || method.MethodKind == MethodKind.PropertySet - || !method.IsAbstract - ) // If an interface method has a body, it won't be abstract - continue; - - ProcessNonRefitMethod(context, reportDiagnostic, source, method); - } - - // Handle Dispose - if (disposeMethod != null) - { - ProcessDisposableMethod(source, disposeMethod); - } - - source.Append( - @" - } - } -} - -#pragma warning restore -" - ); - return source.ToString(); - } - - /// - /// Generates the body of the Refit method - /// - /// - /// - /// True if directly from the type we're generating for, false for methods found on base interfaces - /// Contains the unique member names in the interface scope. - static void ProcessRefitMethod( - StringBuilder source, - IMethodSymbol methodSymbol, - bool isTopLevel, - HashSet memberNames - ) - { - var parameterTypesExpression = GenerateTypeParameterExpression( - source, - methodSymbol, - memberNames - ); - - var returnType = methodSymbol.ReturnType.ToDisplayString( - SymbolDisplayFormat.FullyQualifiedFormat - ); - var (isAsync, @return, configureAwait) = methodSymbol.ReturnType.MetadataName switch - { - "Task" => (true, "await (", ").ConfigureAwait(false)"), - "Task`1" or "ValueTask`1" => (true, "return await (", ").ConfigureAwait(false)"), - _ => (false, "return ", ""), - }; - - WriteMethodOpening(source, methodSymbol, !isTopLevel, isAsync); - - // Build the list of args for the array - var argList = new List(); - foreach (var param in methodSymbol.Parameters) - { - argList.Add($"@{param.MetadataName}"); - } - - // List of generic arguments - var genericList = new List(); - foreach (var typeParam in methodSymbol.TypeParameters) - { - genericList.Add( - $"typeof({typeParam.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)})" - ); - } - - var argumentsArrayString = - argList.Count == 0 - ? "global::System.Array.Empty()" - : $"new object[] {{ {string.Join(", ", argList)} }}"; - - var genericString = - genericList.Count > 0 - ? $", new global::System.Type[] {{ {string.Join(", ", genericList)} }}" - : string.Empty; - - source.Append( - @$" - var ______arguments = {argumentsArrayString}; - var ______func = requestBuilder.BuildRestResultFuncForMethod(""{methodSymbol.Name}"", {parameterTypesExpression}{genericString} ); - - {@return}({returnType})______func(this.Client, ______arguments){configureAwait}; -" - ); - - WriteMethodClosing(source); - } - - static void ProcessDisposableMethod(StringBuilder source, IMethodSymbol methodSymbol) - { - WriteMethodOpening(source, methodSymbol, true); - - source.Append( - @" - Client?.Dispose(); -" - ); - - WriteMethodClosing(source); - } - - static string GenerateConstraints( - ImmutableArray typeParameters, - bool isOverrideOrExplicitImplementation - ) - { - var source = new StringBuilder(); - // Need to loop over the constraints and create them - foreach (var typeParameter in typeParameters) - { - WriteConstraintsForTypeParameter( - source, - typeParameter, - isOverrideOrExplicitImplementation - ); - } - - return source.ToString(); - } - - static void WriteConstraintsForTypeParameter( - StringBuilder source, - ITypeParameterSymbol typeParameter, - bool isOverrideOrExplicitImplementation - ) - { - // Explicit interface implementations and overrides can only have class or struct constraints - - var parameters = new List(); - if (typeParameter.HasReferenceTypeConstraint) - { - parameters.Add("class"); - } - if (typeParameter.HasUnmanagedTypeConstraint && !isOverrideOrExplicitImplementation) - { - parameters.Add("unmanaged"); - } - - // unmanaged constraints are both structs and unmanaged so the struct constraint is redundant - if (typeParameter.HasValueTypeConstraint && !typeParameter.HasUnmanagedTypeConstraint) - { - parameters.Add("struct"); - } - if (typeParameter.HasNotNullConstraint && !isOverrideOrExplicitImplementation) - { - parameters.Add("notnull"); - } - if (!isOverrideOrExplicitImplementation) - { - foreach (var typeConstraint in typeParameter.ConstraintTypes) - { - parameters.Add( - typeConstraint.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) - ); - } - } - - // new constraint has to be last - if (typeParameter.HasConstructorConstraint && !isOverrideOrExplicitImplementation) - { - parameters.Add("new()"); - } - - if (parameters.Count > 0) - { - source.Append( - @$" - where {typeParameter.Name} : {string.Join(", ", parameters)}" - ); - } - } - static void ProcessNonRefitMethod( - TContext context, - Action reportDiagnostic, - StringBuilder source, - IMethodSymbol methodSymbol - ) - { - WriteMethodOpening(source, methodSymbol, true); - - source.Append( - @" - throw new global::System.NotImplementedException(""Either this method has no Refit HTTP method attribute or you've used something other than a string literal for the 'path' argument.""); -" - ); - - WriteMethodClosing(source); - - foreach (var location in methodSymbol.Locations) - { - var diagnostic = Diagnostic.Create( - DiagnosticDescriptors.InvalidRefitMember, - location, - methodSymbol.ContainingType.Name, - methodSymbol.Name - ); - reportDiagnostic(context, diagnostic); - } - } - - static string GenerateTypeParameterExpression( - StringBuilder source, - IMethodSymbol methodSymbol, - HashSet memberNames - ) - { - // use Array.Empty if method has no parameters. - if (methodSymbol.Parameters.Length == 0) - return "global::System.Array.Empty()"; - - // if one of the parameters is/contains a type parameter then it cannot be cached as it will change type between calls. - if (methodSymbol.Parameters.Any(x => ContainsTypeParameter(x.Type))) - { - var typeEnumerable = methodSymbol.Parameters.Select( - param => - $"typeof({param.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)})" - ); - return $"new global::System.Type[] {{ {string.Join(", ", typeEnumerable)} }}"; - } - - // find a name and generate field declaration. - var typeParameterFieldName = UniqueName(TypeParameterVariableName, memberNames); - var types = string.Join( - ", ", - methodSymbol.Parameters.Select( - x => - $"typeof({x.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)})" - ) - ); - source.Append( - $$""" - - - private static readonly global::System.Type[] {{typeParameterFieldName}} = new global::System.Type[] {{{types}} }; - """ - ); - - return typeParameterFieldName; - - static bool ContainsTypeParameter(ITypeSymbol symbol) + // Emit diagnostics + foreach (var diagnostic in parseStep.diagnostics) { - if (symbol is ITypeParameterSymbol) - return true; - - if (symbol is not INamedTypeSymbol { TypeParameters.Length: > 0 } namedType) - return false; - - foreach (var typeArg in namedType.TypeArguments) - { - if (ContainsTypeParameter(typeArg)) - return true; - } - - return false; + context.ReportDiagnostic(diagnostic); } - } - - static void WriteMethodOpening( - StringBuilder source, - IMethodSymbol methodSymbol, - bool isExplicitInterface, - bool isAsync = false - ) - { - var visibility = !isExplicitInterface ? "public " : string.Empty; - var async = isAsync ? "async " : ""; - - source.Append( - @$" - - /// - {visibility}{async}{methodSymbol.ReturnType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)} " - ); - if (isExplicitInterface) + // Emit interface stubs + foreach (var interfaceModel in parseStep.contextGenerationSpec.Interfaces) { - source.Append( - @$"{methodSymbol.ContainingType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}." + var interfaceText = Emitter.EmitInterface(interfaceModel); + context.AddSource( + interfaceModel.FileName, + SourceText.From(interfaceText, Encoding.UTF8) ); } - source.Append( - @$"{methodSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}(" - ); - - if (methodSymbol.Parameters.Length > 0) - { - var list = new List(); - foreach (var param in methodSymbol.Parameters) - { - var annotation = - !param.Type.IsValueType - && param.NullableAnnotation == NullableAnnotation.Annotated; - list.Add( - $@"{param.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}{(annotation ? '?' : string.Empty)} @{param.MetadataName}" - ); - } - - source.Append(string.Join(", ", list)); - } - - source.Append( - @$"){GenerateConstraints(methodSymbol.TypeParameters, isExplicitInterface)} - {{" + // Emit PreserveAttribute and Generated.Initialize + Emitter.EmitSharedCode( + parseStep.contextGenerationSpec, + (name, code) => context.AddSource(name, code) ); } - - static void WriteMethodClosing(StringBuilder source) => source.Append(@" }"); - - static string UniqueName(string name, HashSet methodNames) - { - var candidateName = name; - var counter = 0; - while (methodNames.Contains(candidateName)) - { - candidateName = $"{name}{counter}"; - counter++; - } - - methodNames.Add(candidateName); - return candidateName; - } - - static bool IsRefitMethod(IMethodSymbol? methodSymbol, INamedTypeSymbol httpMethodAttribute) - { - return methodSymbol - ?.GetAttributes() - .Any(ad => ad.AttributeClass?.InheritsFromOrEquals(httpMethodAttribute) == true) - == true; - } +#endif #if ROSLYN_4 @@ -831,22 +133,40 @@ out var refitInternalNamespace ) ); - context.RegisterSourceOutput( - inputs, - (context, collectedValues) => + var parseStep = inputs.Select( + (collectedValues, cancellationToken) => { - GenerateInterfaceStubs( - context, - static (context, diagnostic) => context.ReportDiagnostic(diagnostic), - static (context, hintName, sourceText) => - context.AddSource(hintName, sourceText), + return Parser.GenerateInterfaceStubs( (CSharpCompilation)collectedValues.compilation, collectedValues.refitInternalNamespace, collectedValues.candidateMethods, - collectedValues.candidateInterfaces + collectedValues.candidateInterfaces, + cancellationToken ); } ); + + // output the diagnostics + // use `ImmutableEquatableArray` to cache cases where there are no diagnostics + // otherwise the subsequent steps will always rerun. + var diagnostics = parseStep + .Select(static (x, _) => x.diagnostics.ToImmutableEquatableArray()) + .WithTrackingName(RefitGeneratorStepName.ReportDiagnostics); + context.ReportDiagnostics(diagnostics); + + var contextModel = parseStep.Select(static (x, _) => x.Item2); + var interfaceModels = contextModel + .SelectMany(static (x, _) => x.Interfaces) + .WithTrackingName(RefitGeneratorStepName.BuildRefit); + context.EmitSource(interfaceModels); + + context.RegisterImplementationSourceOutput( + contextModel, + static (spc, model) => + { + Emitter.EmitSharedCode(model, (name, code) => spc.AddSource(name, code)); + } + ); } #else diff --git a/InterfaceStubGenerator.Shared/IsExternalInit.cs b/InterfaceStubGenerator.Shared/IsExternalInit.cs new file mode 100644 index 000000000..408e7abd0 --- /dev/null +++ b/InterfaceStubGenerator.Shared/IsExternalInit.cs @@ -0,0 +1,20 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +#if NETSTANDARD2_0 || NETSTANDARD2_1 || NETCOREAPP2_0 || NETCOREAPP2_1 || NETCOREAPP2_2 || NETCOREAPP3_0 || NETCOREAPP3_1 || NET45 || NET451 || NET452 || NET6 || NET461 || NET462 || NET47 || NET471 || NET472 || NET48 + +using System.ComponentModel; + +// ReSharper disable once CheckNamespace +namespace System.Runtime.CompilerServices +{ + /// + /// Reserved to be used by the compiler for tracking metadata. + /// This class should not be used by developers in source code. + /// + [EditorBrowsable(EditorBrowsableState.Never)] + internal static class IsExternalInit { } +} + +#endif diff --git a/InterfaceStubGenerator.Shared/Models/ContextGenerationModel.cs b/InterfaceStubGenerator.Shared/Models/ContextGenerationModel.cs new file mode 100644 index 000000000..743e5b9cd --- /dev/null +++ b/InterfaceStubGenerator.Shared/Models/ContextGenerationModel.cs @@ -0,0 +1,7 @@ +namespace Refit.Generator; + +internal sealed record ContextGenerationModel( + string RefitInternalNamespace, + string PreserveAttributeDisplayName, + ImmutableEquatableArray Interfaces +); diff --git a/InterfaceStubGenerator.Shared/Models/InterfaceModel.cs b/InterfaceStubGenerator.Shared/Models/InterfaceModel.cs new file mode 100644 index 000000000..aa1a875c5 --- /dev/null +++ b/InterfaceStubGenerator.Shared/Models/InterfaceModel.cs @@ -0,0 +1,25 @@ +namespace Refit.Generator; + +internal sealed record InterfaceModel( + string PreserveAttributeDisplayName, + string FileName, + string ClassName, + string Ns, + string ClassDeclaration, + string InterfaceDisplayName, + string ClassSuffix, + ImmutableEquatableArray Constraints, + ImmutableEquatableArray MemberNames, + ImmutableEquatableArray NonRefitMethods, + ImmutableEquatableArray RefitMethods, + ImmutableEquatableArray DerivedRefitMethods, + Nullability Nullability, + bool DisposeMethod +); + +internal enum Nullability : byte +{ + Enabled, + Disabled, + None +} diff --git a/InterfaceStubGenerator.Shared/Models/MethodModel.cs b/InterfaceStubGenerator.Shared/Models/MethodModel.cs new file mode 100644 index 000000000..513a8b93b --- /dev/null +++ b/InterfaceStubGenerator.Shared/Models/MethodModel.cs @@ -0,0 +1,18 @@ +namespace Refit.Generator; + +internal sealed record MethodModel( + string Name, + string ReturnType, + string ContainingType, + string DeclaredMethod, + ReturnTypeInfo ReturnTypeMetadata, + ImmutableEquatableArray Parameters, + ImmutableEquatableArray Constraints +); + +internal enum ReturnTypeInfo : byte +{ + Return, + AsyncVoid, + AsyncResult +} diff --git a/InterfaceStubGenerator.Shared/Models/ParameterModel.cs b/InterfaceStubGenerator.Shared/Models/ParameterModel.cs new file mode 100644 index 000000000..64349b392 --- /dev/null +++ b/InterfaceStubGenerator.Shared/Models/ParameterModel.cs @@ -0,0 +1,8 @@ +namespace Refit.Generator; + +internal sealed record ParameterModel( + string MetadataName, + string Type, + bool Annotation, + bool IsGeneric +); diff --git a/InterfaceStubGenerator.Shared/Models/TypeConstraint.cs b/InterfaceStubGenerator.Shared/Models/TypeConstraint.cs new file mode 100644 index 000000000..da386ec23 --- /dev/null +++ b/InterfaceStubGenerator.Shared/Models/TypeConstraint.cs @@ -0,0 +1,19 @@ +namespace Refit.Generator; + +internal readonly record struct TypeConstraint( + string TypeName, + string DeclaredName, + KnownTypeConstraint KnownTypeConstraint, + ImmutableEquatableArray Constraints +); + +[Flags] +internal enum KnownTypeConstraint : byte +{ + None = 0, + Class = 1 << 0, + Unmanaged = 1 << 1, + Struct = 1 << 2, + NotNull = 1 << 3, + New = 1 << 4 +} diff --git a/InterfaceStubGenerator.Shared/Parser.cs b/InterfaceStubGenerator.Shared/Parser.cs new file mode 100644 index 000000000..04a5e7b97 --- /dev/null +++ b/InterfaceStubGenerator.Shared/Parser.cs @@ -0,0 +1,520 @@ +using System.Collections.Immutable; +using System.Text; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Text; + +namespace Refit.Generator; + +internal static class Parser +{ + /// + /// Generates the interface stubs. + /// + /// The compilation. + /// The refit internal namespace. + /// The candidate methods. + /// The candidate interfaces. + /// The cancellation token. + /// + public static ( + List diagnostics, + ContextGenerationModel contextGenerationSpec + ) GenerateInterfaceStubs( + CSharpCompilation compilation, + string? refitInternalNamespace, + ImmutableArray candidateMethods, + ImmutableArray candidateInterfaces, + CancellationToken cancellationToken + ) + { + if (compilation == null) + throw new ArgumentNullException(nameof(compilation)); + + refitInternalNamespace = $"{refitInternalNamespace ?? string.Empty}RefitInternalGenerated"; + + // we're going to create a new compilation that contains the attribute. + // TODO: we should allow source generators to provide source during initialize, so that this step isn't required. + var options = (CSharpParseOptions)compilation.SyntaxTrees[0].Options; + + var disposableInterfaceSymbol = compilation.GetTypeByMetadataName("System.IDisposable")!; + var httpMethodBaseAttributeSymbol = compilation.GetTypeByMetadataName( + "Refit.HttpMethodAttribute" + ); + + var diagnostics = new List(); + if (httpMethodBaseAttributeSymbol == null) + { + diagnostics.Add(Diagnostic.Create(DiagnosticDescriptors.RefitNotReferenced, null)); + return ( + diagnostics, + new ContextGenerationModel( + refitInternalNamespace, + string.Empty, + ImmutableEquatableArray.Empty() + ) + ); + } + + // Check the candidates and keep the ones we're actually interested in + +#pragma warning disable RS1024 // Compare symbols correctly + var interfaceToNullableEnabledMap = new Dictionary( + SymbolEqualityComparer.Default + ); +#pragma warning restore RS1024 // Compare symbols correctly + var methodSymbols = new List(); + foreach (var group in candidateMethods.GroupBy(m => m.SyntaxTree)) + { + var model = compilation.GetSemanticModel(group.Key); + foreach (var method in group) + { + // Get the symbol being declared by the method + var methodSymbol = model.GetDeclaredSymbol( + method, + cancellationToken: cancellationToken + ); + if (!IsRefitMethod(methodSymbol, httpMethodBaseAttributeSymbol)) + continue; + + var isAnnotated = + compilation.Options.NullableContextOptions == NullableContextOptions.Enable + || model.GetNullableContext(method.SpanStart) == NullableContext.Enabled; + interfaceToNullableEnabledMap[methodSymbol!.ContainingType] = isAnnotated; + + methodSymbols.Add(methodSymbol!); + } + } + + var interfaces = methodSymbols + .GroupBy( + m => m.ContainingType, + SymbolEqualityComparer.Default + ) + .ToDictionary< + IGrouping, + INamedTypeSymbol, + List + >(g => g.Key, v => [.. v], SymbolEqualityComparer.Default); + + // Look through the candidate interfaces + var interfaceSymbols = new List(); + foreach (var group in candidateInterfaces.GroupBy(i => i.SyntaxTree)) + { + var model = compilation.GetSemanticModel(group.Key); + foreach (var iface in group) + { + // get the symbol belonging to the interface + var ifaceSymbol = model.GetDeclaredSymbol( + iface, + cancellationToken: cancellationToken + ); + + // See if we already know about it, might be a dup + if (ifaceSymbol is null || interfaces.ContainsKey(ifaceSymbol)) + continue; + + // The interface has no refit methods, but its base interfaces might + var hasDerivedRefit = ifaceSymbol + .AllInterfaces.SelectMany(i => i.GetMembers().OfType()) + .Any(m => IsRefitMethod(m, httpMethodBaseAttributeSymbol)); + + if (hasDerivedRefit) + { + // Add the interface to the generation list with an empty set of methods + // The logic already looks for base refit methods + interfaces.Add(ifaceSymbol, []); + var isAnnotated = + model.GetNullableContext(iface.SpanStart) == NullableContext.Enabled; + + interfaceToNullableEnabledMap[ifaceSymbol] = isAnnotated; + } + } + } + + cancellationToken.ThrowIfCancellationRequested(); + + // Bail out if there aren't any interfaces to generate code for. This may be the case with transitives + if (interfaces.Count == 0) + return ( + diagnostics, + new ContextGenerationModel( + refitInternalNamespace, + string.Empty, + ImmutableEquatableArray.Empty() + ) + ); + + var supportsNullable = options.LanguageVersion >= LanguageVersion.CSharp8; + + var keyCount = new Dictionary(); + + var attributeText = + @$" +#pragma warning disable +namespace {refitInternalNamespace} +{{ + [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + [global::System.AttributeUsage (global::System.AttributeTargets.Class | global::System.AttributeTargets.Struct | global::System.AttributeTargets.Enum | global::System.AttributeTargets.Constructor | global::System.AttributeTargets.Method | global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Event | global::System.AttributeTargets.Interface | global::System.AttributeTargets.Delegate)] + sealed class PreserveAttribute : global::System.Attribute + {{ + // + // Fields + // + public bool AllMembers; + + public bool Conditional; + }} +}} +#pragma warning restore +"; + + // TODO: Delete? + // Is it necessary to add the attributes to the compilation now, does it affect the users ide experience? + // Is it needed in order to get the preserve attribute display name. + // Will the compilation ever change this. + compilation = compilation.AddSyntaxTrees( + CSharpSyntaxTree.ParseText( + SourceText.From(attributeText, Encoding.UTF8), + options, + cancellationToken: cancellationToken + ) + ); + + // get the newly bound attribute + var preserveAttributeSymbol = compilation.GetTypeByMetadataName( + $"{refitInternalNamespace}.PreserveAttribute" + )!; + + var preserveAttributeDisplayName = preserveAttributeSymbol.ToDisplayString( + SymbolDisplayFormat.FullyQualifiedFormat + ); + + var interfaceModels = new List(); + // group the fields by interface and generate the source + foreach (var group in interfaces) + { + cancellationToken.ThrowIfCancellationRequested(); + + // each group is keyed by the Interface INamedTypeSymbol and contains the members + // with a refit attribute on them. Types may contain other members, without the attribute, which we'll + // need to check for and error out on + var keyName = group.Key.Name; + int value; + while (keyCount.TryGetValue(keyName, out value)) + { + keyName = $"{keyName}{++value}"; + } + keyCount[keyName] = value; + var fileName = $"{keyName}.g.cs"; + + var interfaceModel = ProcessInterface( + fileName, + diagnostics, + group.Key, + group.Value, + preserveAttributeDisplayName, + disposableInterfaceSymbol, + httpMethodBaseAttributeSymbol, + supportsNullable, + interfaceToNullableEnabledMap[group.Key] + ); + + interfaceModels.Add(interfaceModel); + } + + var contextGenerationSpec = new ContextGenerationModel( + refitInternalNamespace, + preserveAttributeDisplayName, + interfaceModels.ToImmutableEquatableArray() + ); + return (diagnostics, contextGenerationSpec); + } + + static InterfaceModel ProcessInterface( + string fileName, + List diagnostics, + INamedTypeSymbol interfaceSymbol, + List refitMethods, + string preserveAttributeDisplayName, + ISymbol disposableInterfaceSymbol, + INamedTypeSymbol httpMethodBaseAttributeSymbol, + bool supportsNullable, + bool nullableEnabled + ) + { + // Get the class name with the type parameters, then remove the namespace + var className = interfaceSymbol.ToDisplayString(); + var lastDot = className.LastIndexOf('.'); + if (lastDot > 0) + { + className = className.Substring(lastDot + 1); + } + var classDeclaration = $"{interfaceSymbol.ContainingType?.Name}{className}"; + + // Get the class name itself + var classSuffix = $"{interfaceSymbol.ContainingType?.Name}{interfaceSymbol.Name}"; + var ns = interfaceSymbol.ContainingNamespace?.ToDisplayString(); + + // if it's the global namespace, our lookup rules say it should be the same as the class name + if (interfaceSymbol.ContainingNamespace is { IsGlobalNamespace: true }) + { + ns = string.Empty; + } + + // Remove dots + ns = ns!.Replace(".", ""); + var interfaceDisplayName = interfaceSymbol.ToDisplayString( + SymbolDisplayFormat.FullyQualifiedFormat + ); + + // Get any other methods on the refit interfaces. We'll need to generate something for them and warn + var nonRefitMethods = interfaceSymbol + .GetMembers() + .OfType() + .Except(refitMethods, SymbolEqualityComparer.Default) + .Cast() + .ToArray(); + + // get methods for all inherited + var derivedMethods = interfaceSymbol + .AllInterfaces.SelectMany(i => i.GetMembers().OfType()) + .ToList(); + + // Look for disposable + var disposeMethod = derivedMethods.Find( + m => + m.ContainingType?.Equals(disposableInterfaceSymbol, SymbolEqualityComparer.Default) + == true + ); + if (disposeMethod != null) + { + //remove it from the derived methods list so we don't process it with the rest + derivedMethods.Remove(disposeMethod); + } + + // Pull out the refit methods from the derived types + var derivedRefitMethods = derivedMethods + .Where(m => IsRefitMethod(m, httpMethodBaseAttributeSymbol)) + .ToArray(); + var derivedNonRefitMethods = derivedMethods + .Except(derivedMethods, SymbolEqualityComparer.Default) + .Cast() + .ToArray(); + + var memberNames = interfaceSymbol + .GetMembers() + .Select(x => x.Name) + .Distinct() + .ToImmutableEquatableArray(); + + // Handle Refit Methods + var refitMethodsArray = refitMethods + .Select(m => ParseMethod(m, true)) + .ToImmutableEquatableArray(); + var derivedRefitMethodsArray = refitMethods + .Concat(derivedRefitMethods) + .Select(m => ParseMethod(m, false)) + .ToImmutableEquatableArray(); + + // Handle non-refit Methods that aren't static or properties or have a method body + var nonRefitMethodModelList = new List(); + foreach (var method in nonRefitMethods.Concat(derivedNonRefitMethods)) + { + if ( + method.IsStatic + || method.MethodKind == MethodKind.PropertyGet + || method.MethodKind == MethodKind.PropertySet + || !method.IsAbstract + ) // If an interface method has a body, it won't be abstract + continue; + + nonRefitMethodModelList.Add(ParseNonRefitMethod(method, diagnostics)); + } + + var nonRefitMethodModels = nonRefitMethodModelList.ToImmutableEquatableArray(); + + var constraints = GenerateConstraints(interfaceSymbol.TypeParameters, false); + var hasDispose = disposeMethod != null; + var nullability = (supportsNullable, nullableEnabled) switch + { + (false, _) => Nullability.None, + (true, true) => Nullability.Enabled, + (true, false) => Nullability.Disabled, + }; + return new InterfaceModel( + preserveAttributeDisplayName, + fileName, + className, + ns, + classDeclaration, + interfaceDisplayName, + classSuffix, + constraints, + memberNames, + nonRefitMethodModels, + refitMethodsArray, + derivedRefitMethodsArray, + nullability, + hasDispose + ); + } + + private static MethodModel ParseNonRefitMethod( + IMethodSymbol methodSymbol, + List diagnostics + ) + { + // report invalid error diagnostic + foreach (var location in methodSymbol.Locations) + { + var diagnostic = Diagnostic.Create( + DiagnosticDescriptors.InvalidRefitMember, + location, + methodSymbol.ContainingType.Name, + methodSymbol.Name + ); + diagnostics.Add(diagnostic); + } + + return ParseMethod(methodSymbol, false); + } + + private static bool IsRefitMethod( + IMethodSymbol? methodSymbol, + INamedTypeSymbol httpMethodAttribute + ) + { + return methodSymbol + ?.GetAttributes() + .Any(ad => ad.AttributeClass?.InheritsFromOrEquals(httpMethodAttribute) == true) + == true; + } + + private static ImmutableEquatableArray GenerateConstraints( + ImmutableArray typeParameters, + bool isOverrideOrExplicitImplementation + ) + { + // Need to loop over the constraints and create them + return typeParameters + .Select( + typeParameter => + ParseConstraintsForTypeParameter( + typeParameter, + isOverrideOrExplicitImplementation + ) + ) + .ToImmutableEquatableArray(); + } + + private static TypeConstraint ParseConstraintsForTypeParameter( + ITypeParameterSymbol typeParameter, + bool isOverrideOrExplicitImplementation + ) + { + // Explicit interface implementations and overrides can only have class or struct constraints + var known = KnownTypeConstraint.None; + + if (typeParameter.HasReferenceTypeConstraint) + { + known |= KnownTypeConstraint.Class; + } + if (typeParameter.HasUnmanagedTypeConstraint && !isOverrideOrExplicitImplementation) + { + known |= KnownTypeConstraint.Unmanaged; + } + + // unmanaged constraints are both structs and unmanaged so the struct constraint is redundant + if (typeParameter.HasValueTypeConstraint && !typeParameter.HasUnmanagedTypeConstraint) + { + known |= KnownTypeConstraint.Struct; + } + if (typeParameter.HasNotNullConstraint && !isOverrideOrExplicitImplementation) + { + known |= KnownTypeConstraint.NotNull; + } + + var constraints = ImmutableEquatableArray.Empty; + if (!isOverrideOrExplicitImplementation) + { + constraints = typeParameter + .ConstraintTypes.Select( + typeConstraint => + typeConstraint.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) + ) + .ToImmutableEquatableArray(); + } + + // new constraint has to be last + if (typeParameter.HasConstructorConstraint && !isOverrideOrExplicitImplementation) + { + known |= KnownTypeConstraint.New; + } + + var declaredName = typeParameter.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + return new TypeConstraint(typeParameter.Name, declaredName, known, constraints); + } + + private static ParameterModel ParseParameter(IParameterSymbol param) + { + var annotation = + !param.Type.IsValueType && param.NullableAnnotation == NullableAnnotation.Annotated; + + var paramType = param.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + var isGeneric = ContainsTypeParameter(param.Type); + + return new ParameterModel(param.MetadataName, paramType, annotation, isGeneric); + } + + private static bool ContainsTypeParameter(ITypeSymbol symbol) + { + if (symbol is ITypeParameterSymbol) + return true; + + if (symbol is not INamedTypeSymbol { TypeParameters.Length: > 0 } namedType) + return false; + + foreach (var typeArg in namedType.TypeArguments) + { + if (ContainsTypeParameter(typeArg)) + return true; + } + + return false; + } + + private static MethodModel ParseMethod(IMethodSymbol methodSymbol, bool isImplicitInterface) + { + var returnType = methodSymbol.ReturnType.ToDisplayString( + SymbolDisplayFormat.FullyQualifiedFormat + ); + + var containingType = methodSymbol.ContainingType.ToDisplayString( + SymbolDisplayFormat.FullyQualifiedFormat + ); + var declaredMethod = methodSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + var returnTypeInfo = methodSymbol.ReturnType.MetadataName switch + { + "Task" => ReturnTypeInfo.AsyncVoid, + "Task`1" or "ValueTask`1" => ReturnTypeInfo.AsyncResult, + _ => ReturnTypeInfo.Return, + }; + + var parameters = methodSymbol.Parameters.Select(ParseParameter).ToImmutableEquatableArray(); + + var constraints = GenerateConstraints(methodSymbol.TypeParameters, !isImplicitInterface); + + return new MethodModel( + methodSymbol.Name, + returnType, + containingType, + declaredMethod, + returnTypeInfo, + parameters, + constraints + ); + } +} diff --git a/Refit.GeneratorTests/Incremental/FunctionTest.cs b/Refit.GeneratorTests/Incremental/FunctionTest.cs index 903a93fda..39771bde0 100644 --- a/Refit.GeneratorTests/Incremental/FunctionTest.cs +++ b/Refit.GeneratorTests/Incremental/FunctionTest.cs @@ -6,6 +6,7 @@ public class FunctionTest { private const string DefaultInterface = """ + #nullable enabled using System; using System.Collections.Generic; using System.Linq; @@ -24,7 +25,28 @@ public interface IGitHubApi } """; - // [Fact] + private const string ReturnValueInterface = + """ + #nullable enabled + using System; + using System.Collections.Generic; + using System.Linq; + using System.Net.Http; + using System.Text; + using System.Threading; + using System.Threading.Tasks; + using Refit; + + namespace RefitGeneratorTest; + + public interface IGitHubApi + { + [Get("/users/{user}")] + Task GetUser(string user); + } + """; + + [Fact] public void ModifyParameterNameDoesRegenerate() { var syntaxTree = CSharpSyntaxTree.ParseText(DefaultInterface, CSharpParseOptions.Default); @@ -48,7 +70,7 @@ public interface IGitHubApi TestHelper.AssertRunReasons(driver2, IncrementalGeneratorRunReasons.ModifiedSource); } - // [Fact] + [Fact] public void ModifyParameterTypeDoesRegenerate() { var syntaxTree = CSharpSyntaxTree.ParseText(DefaultInterface, CSharpParseOptions.Default); @@ -72,7 +94,7 @@ public interface IGitHubApi TestHelper.AssertRunReasons(driver2, IncrementalGeneratorRunReasons.ModifiedSource); } - // [Fact] + [Fact] public void ModifyParameterNullabilityDoesRegenerate() { var syntaxTree = CSharpSyntaxTree.ParseText(DefaultInterface, CSharpParseOptions.Default); @@ -96,7 +118,7 @@ public interface IGitHubApi TestHelper.AssertRunReasons(driver2, IncrementalGeneratorRunReasons.ModifiedSource); } - // [Fact] + [Fact] public void AddParameterDoesRegenerate() { var syntaxTree = CSharpSyntaxTree.ParseText(DefaultInterface, CSharpParseOptions.Default); @@ -120,7 +142,7 @@ public interface IGitHubApi TestHelper.AssertRunReasons(driver2, IncrementalGeneratorRunReasons.ModifiedSource); } - // [Fact] + [Fact] public void ModifyReturnTypeDoesRegenerate() { var syntaxTree = CSharpSyntaxTree.ParseText(DefaultInterface, CSharpParseOptions.Default); @@ -144,8 +166,8 @@ public interface IGitHubApi TestHelper.AssertRunReasons(driver2, IncrementalGeneratorRunReasons.ModifiedSource); } - // [Fact] - public void ModifyReturnNullabilityDoesRegenerate() + [Fact] + public void ModifyReturnObjectNullabilityDoesNotRegenerate() { var syntaxTree = CSharpSyntaxTree.ParseText(DefaultInterface, CSharpParseOptions.Default); var compilation1 = Fixture.CreateLibrary(syntaxTree); @@ -164,11 +186,35 @@ public interface IGitHubApi """; var compilation2 = TestHelper.ReplaceMemberDeclaration(compilation1, "IGitHubApi", newInterface); + var driver2 = driver1.RunGenerators(compilation2); + TestHelper.AssertRunReasons(driver2, IncrementalGeneratorRunReasons.Cached); + } + + [Fact] + public void ModifyReturnValueNullabilityDoesRegenerate() + { + var syntaxTree = CSharpSyntaxTree.ParseText(DefaultInterface, CSharpParseOptions.Default); + var compilation1 = Fixture.CreateLibrary(syntaxTree); + + var driver1 = TestHelper.GenerateTracked(compilation1); + TestHelper.AssertRunReasons(driver1, IncrementalGeneratorRunReasons.New); + + // change return nullability + var newInterface = + """ + public interface IGitHubApi + { + [Get("/users/{user}")] + Task GetUser(string user); + } + """; + var compilation2 = TestHelper.ReplaceMemberDeclaration(compilation1, "IGitHubApi", newInterface); + var driver2 = driver1.RunGenerators(compilation2); TestHelper.AssertRunReasons(driver2, IncrementalGeneratorRunReasons.ModifiedSource); } - // [Fact] + [Fact] public void AddNonRefitMethodDoesRegenerate() { var syntaxTree = CSharpSyntaxTree.ParseText(DefaultInterface, CSharpParseOptions.Default); diff --git a/Refit.GeneratorTests/Incremental/GenericTest.cs b/Refit.GeneratorTests/Incremental/GenericTest.cs index 131fb20bd..f09cf2555 100644 --- a/Refit.GeneratorTests/Incremental/GenericTest.cs +++ b/Refit.GeneratorTests/Incremental/GenericTest.cs @@ -24,7 +24,7 @@ public interface IGeneratedInterface } """; - // [Fact] + [Fact] public void RenameGenericTypeDoesRegenerate() { var syntaxTree = CSharpSyntaxTree.ParseText(GenericInterface, CSharpParseOptions.Default); @@ -49,7 +49,7 @@ public interface IGeneratedInterface TestHelper.AssertRunReasons(driver2, IncrementalGeneratorRunReasons.ModifiedSource); } - // [Fact] + [Fact] public void AddGenericConstraintDoesRegenerate() { var syntaxTree = CSharpSyntaxTree.ParseText(GenericInterface, CSharpParseOptions.Default); @@ -91,7 +91,7 @@ public interface IGeneratedInterface TestHelper.AssertRunReasons(driver3, IncrementalGeneratorRunReasons.ModifiedSource); } - // [Fact] + [Fact] public void AddObjectGenericConstraintDoesRegenerate() { var syntaxTree = CSharpSyntaxTree.ParseText(GenericInterface, CSharpParseOptions.Default); @@ -117,7 +117,7 @@ public interface IGeneratedInterface TestHelper.AssertRunReasons(driver2, IncrementalGeneratorRunReasons.ModifiedSource); } - // [Fact] + [Fact] public void AddGenericTypeDoesRegenerate() { var syntaxTree = CSharpSyntaxTree.ParseText(GenericInterface, CSharpParseOptions.Default); diff --git a/Refit.GeneratorTests/Incremental/IncrementalGeneratorRunReasons.cs b/Refit.GeneratorTests/Incremental/IncrementalGeneratorRunReasons.cs index 473762261..cc1cac5f5 100644 --- a/Refit.GeneratorTests/Incremental/IncrementalGeneratorRunReasons.cs +++ b/Refit.GeneratorTests/Incremental/IncrementalGeneratorRunReasons.cs @@ -14,7 +14,7 @@ IncrementalStepRunReason ReportDiagnosticsStep new( // compilation step should always be modified as each time a new compilation is passed IncrementalStepRunReason.Cached, - IncrementalStepRunReason.Cached + IncrementalStepRunReason.Unchanged ); public static readonly IncrementalGeneratorRunReasons Modified = Cached with diff --git a/Refit.GeneratorTests/Incremental/IncrementalTest.cs b/Refit.GeneratorTests/Incremental/IncrementalTest.cs index 88ba56c5a..c6f576974 100644 --- a/Refit.GeneratorTests/Incremental/IncrementalTest.cs +++ b/Refit.GeneratorTests/Incremental/IncrementalTest.cs @@ -6,6 +6,7 @@ public class IncrementalTest { private const string DefaultInterface = """ + #nullable enabled using System; using System.Collections.Generic; using System.Linq; @@ -24,7 +25,7 @@ public interface IGitHubApi } """; - // [Fact] + [Fact] public void AddUnrelatedTypeDoesntRegenerate() { var syntaxTree = CSharpSyntaxTree.ParseText(DefaultInterface, CSharpParseOptions.Default); @@ -38,7 +39,7 @@ public void AddUnrelatedTypeDoesntRegenerate() TestHelper.AssertRunReasons(driver2, IncrementalGeneratorRunReasons.Cached); } - // [Fact] + [Fact] public void SmallChangeDoesntRegenerate() { var syntaxTree = CSharpSyntaxTree.ParseText(DefaultInterface, CSharpParseOptions.Default); @@ -63,7 +64,7 @@ public interface IGitHubApi TestHelper.AssertRunReasons(driver2, IncrementalGeneratorRunReasons.Cached); } - // [Fact] + [Fact] public void AddNewMemberDoesRegenerate() { var syntaxTree = CSharpSyntaxTree.ParseText(DefaultInterface, CSharpParseOptions.Default); diff --git a/Refit.GeneratorTests/Incremental/InheritanceTest.cs b/Refit.GeneratorTests/Incremental/InheritanceTest.cs index 688cd7fc7..50f9de48e 100644 --- a/Refit.GeneratorTests/Incremental/InheritanceTest.cs +++ b/Refit.GeneratorTests/Incremental/InheritanceTest.cs @@ -46,7 +46,7 @@ public interface IGitHubApi public interface IBaseInterface { void NonRefitMethod(); } """; - // [Fact] + [Fact] public void InheritFromIDisposableDoesRegenerate() { var syntaxTree = CSharpSyntaxTree.ParseText(DefaultInterface, CSharpParseOptions.Default); @@ -71,9 +71,10 @@ public interface IGitHubApi : IDisposable TestHelper.AssertRunReasons(driver2, IncrementalGeneratorRunReasons.ModifiedSource); } - // [Fact] + [Fact] public void InheritFromInterfaceDoesRegenerate() { + // TODO: this currently generates invalid code see issue #1801 for more information var syntaxTree = CSharpSyntaxTree.ParseText(TwoInterface, CSharpParseOptions.Default); var compilation1 = Fixture.CreateLibrary(syntaxTree); @@ -93,6 +94,6 @@ public interface IGitHubApi : IBaseInterface """ ); var driver2 = driver1.RunGenerators(compilation2); - TestHelper.AssertRunReasons(driver2, IncrementalGeneratorRunReasons.Modified); + TestHelper.AssertRunReasons(driver2, IncrementalGeneratorRunReasons.Cached); } } diff --git a/Refit.GeneratorTests/InterfaceTests.cs b/Refit.GeneratorTests/InterfaceTests.cs index 92e462663..74a1c0195 100644 --- a/Refit.GeneratorTests/InterfaceTests.cs +++ b/Refit.GeneratorTests/InterfaceTests.cs @@ -40,7 +40,7 @@ public interface IBaseInterface [Fact] public Task RefitInterfaceDerivedFromBaseTest() { - // this currently generates invalid code see issue #1801 for more information + // TODO: this currently generates invalid code see issue #1801 for more information return Fixture.VerifyForType( """ public interface IGeneratedInterface : IBaseInterface