diff --git a/osu.Framework.SourceGeneration/Generators/Dependencies/DependencyInjectionSourceGenerator.cs b/osu.Framework.SourceGeneration/Generators/Dependencies/DependencyInjectionSourceGenerator.cs index 266933c700..be94dee330 100644 --- a/osu.Framework.SourceGeneration/Generators/Dependencies/DependencyInjectionSourceGenerator.cs +++ b/osu.Framework.SourceGeneration/Generators/Dependencies/DependencyInjectionSourceGenerator.cs @@ -7,7 +7,6 @@ namespace osu.Framework.SourceGeneration.Generators.Dependencies { - [Generator] public class DependencyInjectionSourceGenerator : AbstractIncrementalGenerator { protected override IncrementalSemanticTarget CreateSemanticTarget(ClassDeclarationSyntax node, SemanticModel semanticModel) diff --git a/osu.Framework.SourceGeneration/Generators/Dependencies/NewDependencyInjectionSourceGenerator.cs b/osu.Framework.SourceGeneration/Generators/Dependencies/NewDependencyInjectionSourceGenerator.cs new file mode 100644 index 0000000000..58bff05328 --- /dev/null +++ b/osu.Framework.SourceGeneration/Generators/Dependencies/NewDependencyInjectionSourceGenerator.cs @@ -0,0 +1,182 @@ +// Copyright (c) ppy Pty Ltd . Licensed under the MIT Licence. +// See the LICENCE file in the repository root for full licence text. + +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using osu.Framework.SourceGeneration.Generators.Dependencies.Emitters; + +namespace osu.Framework.SourceGeneration.Generators.Dependencies +{ + [Generator] + public class NewDependencyInjectionSourceGenerator : IIncrementalGenerator + { + public void Initialize(IncrementalGeneratorInitializationContext context) + { + // All interfaces that have a [Cached] attribute. + // IncrementalValuesProvider cachedInterfaces = + // context.SyntaxProvider + // .CreateSyntaxProvider( + // (n, _) => n.IsKind(SyntaxKind.InterfaceDeclaration), + // (ctx, _) => new DependencyInjectionSemanticTarget(ctx)) + // .Where(target => target.AnyAttributes()); + + // Classes containing [Cached], [Resolved], or [BackgroundDependencyLoader] attributes. + IncrementalValuesProvider candidates = + context.SyntaxProvider + .CreateSyntaxProvider( + (n, _) => n.IsKind(SyntaxKind.ClassDeclaration), + (ctx, _) => new ActivatorCandidate(ctx)) + .Where(candidate => candidate.HasAnyAttributes); + + // Classes with semantic information. + IncrementalValuesProvider semanticCandidates = + candidates.Select((candidate, _) => new DependenciesClassCandidate((ClassDeclarationSyntax)candidate.Context.Node, candidate.Context.SemanticModel)); + + context.RegisterImplementationSourceOutput( + semanticCandidates, + (ctx, candidate) => new DependenciesFileEmitter(candidate).Emit(ctx.AddSource)); + } + + private readonly struct ActivatorCandidate : IEquatable + { + public readonly GeneratorSyntaxContext Context; + + private readonly ClassDeclarationSyntax classSyntax; + private readonly List cachedMembers = new List(); + private readonly List resolvedMembers = new List(); + private readonly List loaderMembers = new List(); + + public ActivatorCandidate(GeneratorSyntaxContext context) + { + Context = context; + classSyntax = (ClassDeclarationSyntax)context.Node; + + foreach (var attrib in NewSyntaxHelpers.EnumerateNamedAttributes(classSyntax, "Cached")) + cachedMembers.Add(new CachedMemberInfo(attrib, classSyntax)); + + foreach (var member in classSyntax.Members) + { + switch (member) + { + case PropertyDeclarationSyntax property: + foreach (var attrib in NewSyntaxHelpers.EnumerateNamedAttributes(property, "Cached")) + cachedMembers.Add(new CachedMemberInfo(attrib, property)); + foreach (var attrib in NewSyntaxHelpers.EnumerateNamedAttributes(property, "Resolved")) + resolvedMembers.Add(new ResolvedMemberInfo(attrib, property)); + break; + + case FieldDeclarationSyntax field: + foreach (var attrib in NewSyntaxHelpers.EnumerateNamedAttributes(field, "Cached")) + cachedMembers.Add(new CachedMemberInfo(attrib, field)); + break; + + case MethodDeclarationSyntax method: + foreach (var attrib in NewSyntaxHelpers.EnumerateNamedAttributes(method, "BackgroundDependencyLoader")) + loaderMembers.Add(new LoaderMemberInfo(attrib, method)); + break; + } + } + } + + public bool HasAnyAttributes => cachedMembers.Any() || resolvedMembers.Any() || loaderMembers.Any(); + + public bool Equals(ActivatorCandidate other) + => classSyntax.Identifier.IsEquivalentTo(other.classSyntax.Identifier) + && cachedMembers.SequenceEqual(other.cachedMembers) + && resolvedMembers.SequenceEqual(other.resolvedMembers) + && loaderMembers.SequenceEqual(other.loaderMembers); + } + + private readonly struct CachedMemberInfo : IEquatable + { + private readonly AttributeSyntax attributeSyntax; + private readonly MemberDeclarationSyntax memberSyntax; + + public CachedMemberInfo(AttributeSyntax attributeSyntax, MemberDeclarationSyntax memberSyntax) + { + this.attributeSyntax = attributeSyntax; + this.memberSyntax = memberSyntax; + } + + public bool Equals(CachedMemberInfo other) + => attributeSyntax.IsEquivalentTo(other.attributeSyntax) + && memberSyntax switch + { + ClassDeclarationSyntax c => c.Identifier.IsEquivalentTo(((ClassDeclarationSyntax)other.memberSyntax).Identifier), + InterfaceDeclarationSyntax i => i.Identifier.IsEquivalentTo(((InterfaceDeclarationSyntax)other.memberSyntax).Identifier), + PropertyDeclarationSyntax p => p.Identifier.IsEquivalentTo(((PropertyDeclarationSyntax)other.memberSyntax).Identifier) + && p.Type.IsEquivalentTo(((PropertyDeclarationSyntax)other.memberSyntax).Type), + FieldDeclarationSyntax f => f.Declaration.IsEquivalentTo(((FieldDeclarationSyntax)other.memberSyntax).Declaration), + _ => false + }; + } + + private readonly struct ResolvedMemberInfo : IEquatable + { + private readonly AttributeSyntax attributeSyntax; + private readonly PropertyDeclarationSyntax propertySyntax; + + public ResolvedMemberInfo(AttributeSyntax attributeSyntax, PropertyDeclarationSyntax propertySyntax) + { + this.attributeSyntax = attributeSyntax; + this.propertySyntax = propertySyntax; + } + + public bool Equals(ResolvedMemberInfo other) + => attributeSyntax.IsEquivalentTo(other.attributeSyntax) + && propertySyntax.IsEquivalentTo(other.propertySyntax); + } + + private readonly struct LoaderMemberInfo : IEquatable + { + private readonly AttributeSyntax attributeSyntax; + private readonly MethodDeclarationSyntax methodSyntax; + + public LoaderMemberInfo(AttributeSyntax attributeSyntax, MethodDeclarationSyntax methodSyntax) + { + this.attributeSyntax = attributeSyntax; + this.methodSyntax = methodSyntax; + } + + public bool Equals(LoaderMemberInfo other) + => attributeSyntax.IsEquivalentTo(other.attributeSyntax) + && methodSyntax.Identifier.IsEquivalentTo(other.methodSyntax.Identifier) + && methodSyntax.ParameterList.IsEquivalentTo(other.methodSyntax.ParameterList); + } + } + + public static class NewSyntaxHelpers + { + public static string GetUnqualifiedName(NameSyntax name) + { + return name switch + { + IdentifierNameSyntax identifier => identifier.Identifier.ValueText, + AliasQualifiedNameSyntax alias => alias.Name.Identifier.ValueText, + QualifiedNameSyntax qualified => qualified.Right.Identifier.ValueText, + SimpleNameSyntax simple => simple.Identifier.ValueText, + _ => throw new ArgumentException("Unexpected name syntax.", nameof(name)) + }; + } + + public static IEnumerable EnumerateNamedAttributes(MemberDeclarationSyntax syntax, string name) + { + foreach (var list in syntax.AttributeLists) + { + foreach (var attribute in list.Attributes) + { + string attribName = GetUnqualifiedName(attribute.Name); + + // Note that this is somewhat "wide" for brevity, because any time we see "Cached", "Resolved", or "BackgroundDependencyLoader", + // it's generally going to be one of our own attributes (which otherwise cover 95% of classes anyway). + if (attribName.StartsWith(name, StringComparison.Ordinal)) + yield return attribute; + } + } + } + } +} diff --git a/osu.Framework.SourceGeneration/Generators/HandleInput/HandleInputSourceGenerator.cs b/osu.Framework.SourceGeneration/Generators/HandleInput/HandleInputSourceGenerator.cs index c6bf6145b5..239760d9e7 100644 --- a/osu.Framework.SourceGeneration/Generators/HandleInput/HandleInputSourceGenerator.cs +++ b/osu.Framework.SourceGeneration/Generators/HandleInput/HandleInputSourceGenerator.cs @@ -6,7 +6,6 @@ namespace osu.Framework.SourceGeneration.Generators.HandleInput { - [Generator] public class HandleInputSourceGenerator : AbstractIncrementalGenerator { protected override IncrementalSemanticTarget CreateSemanticTarget(ClassDeclarationSyntax node, SemanticModel semanticModel) diff --git a/osu.Framework.SourceGeneration/Generators/IncrementalSemanticTarget.cs b/osu.Framework.SourceGeneration/Generators/IncrementalSemanticTarget.cs index f4197579de..ab25504968 100644 --- a/osu.Framework.SourceGeneration/Generators/IncrementalSemanticTarget.cs +++ b/osu.Framework.SourceGeneration/Generators/IncrementalSemanticTarget.cs @@ -10,8 +10,6 @@ namespace osu.Framework.SourceGeneration.Generators { public abstract class IncrementalSemanticTarget { - public readonly ClassDeclarationSyntax ClassSyntax; - public readonly string FullyQualifiedTypeName = string.Empty; public readonly string GlobalPrefixedTypeName = string.Empty; public readonly bool NeedsOverride; @@ -22,9 +20,7 @@ public abstract class IncrementalSemanticTarget protected IncrementalSemanticTarget(ClassDeclarationSyntax classSyntax, SemanticModel semanticModel) { - ClassSyntax = classSyntax; - - INamedTypeSymbol symbol = semanticModel.GetDeclaredSymbol(ClassSyntax)!; + INamedTypeSymbol symbol = semanticModel.GetDeclaredSymbol(classSyntax)!; IsValid = CheckValid(symbol); diff --git a/osu.Framework.SourceGeneration/Generators/LongRunningLoad/LongRunningLoadSourceGenerator.cs b/osu.Framework.SourceGeneration/Generators/LongRunningLoad/LongRunningLoadSourceGenerator.cs index 62ebb782ce..cce9dd8e17 100644 --- a/osu.Framework.SourceGeneration/Generators/LongRunningLoad/LongRunningLoadSourceGenerator.cs +++ b/osu.Framework.SourceGeneration/Generators/LongRunningLoad/LongRunningLoadSourceGenerator.cs @@ -6,7 +6,6 @@ namespace osu.Framework.SourceGeneration.Generators.LongRunningLoad { - [Generator] public class LongRunningLoadSourceGenerator : AbstractIncrementalGenerator { protected override IncrementalSemanticTarget CreateSemanticTarget(ClassDeclarationSyntax node, SemanticModel semanticModel) diff --git a/osu.Framework.SourceGeneration/HashCode.cs b/osu.Framework.SourceGeneration/HashCode.cs new file mode 100644 index 0000000000..58b7b0abb7 --- /dev/null +++ b/osu.Framework.SourceGeneration/HashCode.cs @@ -0,0 +1,190 @@ +// ReSharper disable All + +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.ComponentModel; +using System.Runtime.CompilerServices; +using System.Security.Cryptography; + +#pragma warning disable CS0809 + +namespace System; + +/// +/// A polyfill type that mirrors some methods from on .NET 6. +/// +internal struct HashCode +{ + private const uint Prime1 = 2654435761U; + private const uint Prime2 = 2246822519U; + private const uint Prime3 = 3266489917U; + private const uint Prime4 = 668265263U; + private const uint Prime5 = 374761393U; + + private static readonly uint seed = GenerateGlobalSeed(); + + private uint v1, v2, v3, v4; + private uint queue1, queue2, queue3; + private uint length; + + /// + /// Initializes the default seed. + /// + /// A random seed. + private static unsafe uint GenerateGlobalSeed() + { + byte[] bytes = new byte[4]; + + using (RandomNumberGenerator generator = RandomNumberGenerator.Create()) + { + generator.GetBytes(bytes); + } + + return BitConverter.ToUInt32(bytes, 0); + } + + /// + /// Adds a single value to the current hash. + /// + /// The type of the value to add into the hash code. + /// The value to add into the hash code. + public void Add(T value) + { + Add(value?.GetHashCode() ?? 0); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static void Initialize(out uint v1, out uint v2, out uint v3, out uint v4) + { + v1 = seed + Prime1 + Prime2; + v2 = seed + Prime2; + v3 = seed; + v4 = seed - Prime1; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static uint Round(uint hash, uint input) + { + return RotateLeft(hash + input * Prime2, 13) * Prime1; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static uint QueueRound(uint hash, uint queuedValue) + { + return RotateLeft(hash + queuedValue * Prime3, 17) * Prime4; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static uint MixState(uint v1, uint v2, uint v3, uint v4) + { + return RotateLeft(v1, 1) + RotateLeft(v2, 7) + RotateLeft(v3, 12) + RotateLeft(v4, 18); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static uint MixEmptyState() + { + return seed + Prime5; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static uint MixFinal(uint hash) + { + hash ^= hash >> 15; + hash *= Prime2; + hash ^= hash >> 13; + hash *= Prime3; + hash ^= hash >> 16; + + return hash; + } + + private void Add(int value) + { + uint val = (uint)value; + uint previousLength = this.length++; + uint position = previousLength % 4; + + if (position == 0) + { + this.queue1 = val; + } + else if (position == 1) + { + this.queue2 = val; + } + else if (position == 2) + { + this.queue3 = val; + } + else + { + if (previousLength == 3) + { + Initialize(out this.v1, out this.v2, out this.v3, out this.v4); + } + + this.v1 = Round(this.v1, this.queue1); + this.v2 = Round(this.v2, this.queue2); + this.v3 = Round(this.v3, this.queue3); + this.v4 = Round(this.v4, val); + } + } + + /// + /// Gets the resulting hashcode from the current instance. + /// + /// The resulting hashcode from the current instance. + public int ToHashCode() + { + uint length = this.length; + uint position = length % 4; + uint hash = length < 4 ? MixEmptyState() : MixState(this.v1, this.v2, this.v3, this.v4); + + hash += length * 4; + + if (position > 0) + { + hash = QueueRound(hash, this.queue1); + + if (position > 1) + { + hash = QueueRound(hash, this.queue2); + + if (position > 2) + { + hash = QueueRound(hash, this.queue3); + } + } + } + + hash = MixFinal(hash); + + return (int)hash; + } + + /// + [Obsolete("HashCode is a mutable struct and should not be compared with other HashCodes. Use ToHashCode to retrieve the computed hash code.", error: true)] + [EditorBrowsable(EditorBrowsableState.Never)] + public override int GetHashCode() => throw new NotSupportedException(); + + /// + [Obsolete("HashCode is a mutable struct and should not be compared with other HashCodes.", error: true)] + [EditorBrowsable(EditorBrowsableState.Never)] + public override bool Equals(object? obj) => throw new NotSupportedException(); + + /// + /// Rotates the specified value left by the specified number of bits. + /// Similar in behavior to the x86 instruction ROL. + /// + /// The value to rotate. + /// The number of bits to rotate by. + /// Any value outside the range [0..31] is treated as congruent mod 32. + /// The rotated value. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static uint RotateLeft(uint value, int offset) + { + return (value << offset) | (value >> (32 - offset)); + } +} diff --git a/osu.Framework.SourceGeneration/osu.Framework.SourceGeneration.csproj b/osu.Framework.SourceGeneration/osu.Framework.SourceGeneration.csproj index 78337270a2..d33615d539 100644 --- a/osu.Framework.SourceGeneration/osu.Framework.SourceGeneration.csproj +++ b/osu.Framework.SourceGeneration/osu.Framework.SourceGeneration.csproj @@ -1,7 +1,6 @@  netstandard2.0 - 9.0 true true true @@ -9,6 +8,7 @@ false false 1591 + true osu!framework Source Generators diff --git a/osu.Framework/osu.Framework.csproj b/osu.Framework/osu.Framework.csproj index d8ca3abe08..422d9de9c8 100644 --- a/osu.Framework/osu.Framework.csproj +++ b/osu.Framework/osu.Framework.csproj @@ -39,7 +39,7 @@ - +