Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify resolution of dependency injection candidates #5548

Merged
merged 8 commits into from
Nov 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions osu.Framework.Benchmarks/BenchmarkDependencyInjection.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) ppy Pty Ltd <[email protected]>. Licensed under the MIT Licence.
// See the LICENCE file in the repository root for full licence text.

using System.Diagnostics.CodeAnalysis;
using BenchmarkDotNet.Attributes;
using osu.Framework.Allocation;

Expand Down Expand Up @@ -40,7 +41,8 @@ public void TestWithSourceGenerator()
}
}

public class ClassInjectedWithReflection
[SuppressMessage("Performance", "OFSG001:Class contributes to dependency injection and should be partial")]
public class ClassInjectedWithReflection : IDependencyInjectionCandidate
{
// ReSharper disable once UnusedAutoPropertyAccessor.Local
[Resolved]
Expand All @@ -49,7 +51,7 @@ public class ClassInjectedWithReflection

// This inspection can be removed once the source generator is merged in/referenced as a package.
// ReSharper disable once PartialTypeWithSinglePart
public partial class ClassInjectedWithSourceGenerator
public partial class ClassInjectedWithSourceGenerator : IDependencyInjectionCandidate
{
// ReSharper disable once UnusedAutoPropertyAccessor.Local
[Resolved]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ public class DependencyInjectionSourceGeneratorTests : AbstractGeneratorTests
[InlineData("NestedCachedClass")]
[InlineData("MultipleCachedMember")]
[InlineData("CachedInheritedInterface")]
// Todo: Fix this.
// [InlineData("CachedBaseType")]
[InlineData("CachedBaseType")]
public async Task Check(string name) => await RunTest(name).ConfigureAwait(false);

protected override Task Verify((string filename, string content)[] sources, (string filename, string content)[] generated)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@

partial class DerivedType : osu.Framework.Allocation.ISourceGeneratedDependencyActivator
{
public virtual void RegisterForDependencyActivation(osu.Framework.Allocation.IDependencyActivatorRegistry registry)
public override void RegisterForDependencyActivation(osu.Framework.Allocation.IDependencyActivatorRegistry registry)
{
if (registry.IsRegistered(typeof(DerivedType)))
return;
base.RegisterForDependencyActivation(registry);
registry.Register(typeof(DerivedType), null, null);
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
[osu.Framework.Allocation.Cached]
public partial class BaseType
public partial class BaseType : osu.Framework.Allocation.IDependencyInjectionCandidate
{
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,6 @@ namespace osu.Framework.Allocation

public object Get(Type type, CacheInfo info) => default;

public void Inject<T>(T instance) where T : class { }
public void Inject<T>(T instance) where T : class, IDependencyInjectionCandidate { }
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
namespace osu.Framework.Graphics
{
public partial class Drawable : IDrawable
public partial class Drawable : osu.Framework.Allocation.IDependencyInjectionCandidate
{
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
namespace osu.Framework.Allocation
{
public interface IDependencyInjectionCandidate
{
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ namespace osu.Framework.Allocation
{
object Get(Type type);
object Get(Type type, CacheInfo info);
void Inject<T>(T instance) where T : class;
void Inject<T>(T instance) where T : class, IDependencyInjectionCandidate;
}

public static class ReadOnlyDependencyContainerExtensions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ namespace osu.Framework.Utils
{
public static class SourceGeneratorUtils
{
public static void CacheDependency(DependencyContainer container, Type callerType, object obj, CacheInfo info, Type? asType, string? cachedName, string? propertyName)
public static void CacheDependency(DependencyContainer container, Type callerType, object? obj, CacheInfo info, Type? asType, string? cachedName, string? propertyName)
{
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
[osu.Framework.Allocation.Cached]
public partial class A
public partial class A : osu.Framework.Allocation.IDependencyInjectionCandidate
{
}
6 changes: 3 additions & 3 deletions osu.Framework.SourceGeneration/Analysers/DiagnosticRules.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ public class DiagnosticRules

public static readonly DiagnosticDescriptor MAKE_DI_CLASS_PARTIAL = new DiagnosticDescriptor(
"OFSG001",
"Class contributes to dependency injection and should be partial",
"Class contributes to dependency injection and should be partial",
"This class is a candidate for dependency injection and should be partial",
"This class is a candidate for dependency injection and should be partial",
"Performance",
DiagnosticSeverity.Warning,
true,
"Classes contributing to dependency injection should be made partial to be subject to compile-time optimisations. This includes usages of `DependencyActivator` and `CachedModelDependencyContainer`.");
"Classes that are candidates for dependency injection should be made partial to benefit from compile-time optimisations.");

#pragma warning restore RS2008
}
Expand Down
102 changes: 1 addition & 101 deletions osu.Framework.SourceGeneration/Analysers/DrawableAnalyser.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,88 +21,6 @@ public override void Initialize(AnalysisContext context)
context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.None);
context.EnableConcurrentExecution();
context.RegisterSyntaxNodeAction(analyseClass, SyntaxKind.ClassDeclaration);
context.RegisterSyntaxNodeAction(analyseInvocation, SyntaxKind.InvocationExpression);
context.RegisterSyntaxNodeAction(analyseObjectCreation, SyntaxKind.ObjectCreationExpression);
}

/// <summary>
/// Analyses construction of CachedModelDependencyContainer{T}.
/// </summary>
private void analyseObjectCreation(SyntaxNodeAnalysisContext context)
{
var objectCreationSyntax = (ObjectCreationExpressionSyntax)context.Node;

GenericNameSyntax? genericName = objectCreationSyntax.Type as GenericNameSyntax;

if (objectCreationSyntax.Type is QualifiedNameSyntax qualified)
genericName = qualified.Right as GenericNameSyntax;

if (genericName == null)
return;

if (genericName.Identifier.ValueText != "CachedModelDependencyContainer")
return;

TypeSyntax? typeSyntax = genericName.TypeArgumentList.Arguments.FirstOrDefault();

if (typeSyntax == null)
return;

ITypeSymbol? argumentType = context.SemanticModel.GetTypeInfo(typeSyntax).Type;
SyntaxTree? argumentSyntaxTree = argumentType?.DeclaringSyntaxReferences.FirstOrDefault()?.SyntaxTree;
ClassDeclarationSyntax? argumentClassSyntax = argumentSyntaxTree?.GetRoot().DescendantNodesAndSelf()
.OfType<ClassDeclarationSyntax>()
.FirstOrDefault(c => c.Identifier.ValueText == argumentType?.Name);

if (argumentClassSyntax == null)
return;

if (argumentClassSyntax.Modifiers.Any(SyntaxKind.PartialKeyword))
return;

// Todo: Why doesn't this work for nested class? It _is_ getting here...
context.ReportDiagnostic(Diagnostic.Create(DiagnosticRules.MAKE_DI_CLASS_PARTIAL, typeSyntax.GetLocation(), typeSyntax));
}

/// <summary>
/// Analyses invocations of DependencyContainer.Inject{T}(T obj).
/// </summary>
private void analyseInvocation(SyntaxNodeAnalysisContext context)
{
var invocationSyntax = (InvocationExpressionSyntax)context.Node;

if (invocationSyntax.ArgumentList.Arguments.Count == 0)
return;

if (invocationSyntax.Expression is not MemberAccessExpressionSyntax memberAccessSyntax)
return;

if (memberAccessSyntax.Name.Identifier.ValueText != "Inject")
return;

ITypeSymbol? expressionType = context.SemanticModel.GetTypeInfo(memberAccessSyntax.Expression).Type;

if (expressionType == null)
return;

if (!SyntaxHelpers.IsIReadOnlyDependencyContainerInterface(expressionType) && !expressionType.AllInterfaces.Any(SyntaxHelpers.IsIReadOnlyDependencyContainerInterface))
return;

ExpressionSyntax argumentExpression = invocationSyntax.ArgumentList.Arguments[0].Expression;
ITypeSymbol? argumentType = context.SemanticModel.GetTypeInfo(argumentExpression).Type;
SyntaxTree? argumentSyntaxTree = argumentType?.DeclaringSyntaxReferences.FirstOrDefault()?.SyntaxTree;
ClassDeclarationSyntax? argumentClassSyntax = argumentSyntaxTree?.GetRoot().DescendantNodesAndSelf()
.OfType<ClassDeclarationSyntax>()
.FirstOrDefault(c => c.Identifier.ValueText == argumentType?.Name);

if (argumentClassSyntax == null)
return;

if (argumentClassSyntax.Modifiers.Any(SyntaxKind.PartialKeyword))
return;

// Todo: Why doesn't this work for nested class? It _is_ getting here...
context.ReportDiagnostic(Diagnostic.Create(DiagnosticRules.MAKE_DI_CLASS_PARTIAL, argumentExpression.GetLocation(), argumentExpression));
}

/// <summary>
Expand All @@ -117,26 +35,8 @@ private void analyseClass(SyntaxNodeAnalysisContext context)

INamedTypeSymbol? type = context.SemanticModel.GetDeclaredSymbol(classSyntax);

if (type != null && requiresPartialClass(type))
if (type?.AllInterfaces.Any(SyntaxHelpers.IsIDependencyInjectionCandidateInterface) == true)
context.ReportDiagnostic(Diagnostic.Create(DiagnosticRules.MAKE_DI_CLASS_PARTIAL, context.Node.GetLocation(), context.Node));
}

private bool requiresPartialClass(ITypeSymbol type)
{
// "Transformable" is a special class below "Drawable" but still a part of the drawable hierarchy.
// It's the base-most type of all drawable objects, and so needs to be partial (see below).
if (SyntaxHelpers.IsTransformableType(type))
return true;

// "IDrawable" classes need to be partial since dependency injection happens implicitly through the drawable hierarchy.
if (type.AllInterfaces.Any(SyntaxHelpers.IsIDrawableInterface))
return true;

// "ISourceGeneratedDependencyActivatorInterface" classes need to be partial since their base type is used in dependency injection.
if (type.AllInterfaces.Any(SyntaxHelpers.IsISourceGeneratedDependencyActivatorInterface))
return true;

return false;
}
}
}
25 changes: 9 additions & 16 deletions osu.Framework.SourceGeneration/SyntaxContextReceiver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,31 +35,24 @@ public void OnVisitSyntaxNode(GeneratorSyntaxContext context)
return;

// Determine if the class is a candidate for the source generator.
// Classes may be candidates even if they don't resolve/cache anything themselves, but a base type does.
foreach (var iFace in symbol.AllInterfaces)
{
// All classes that derive from IDrawable need to use the source generator.
// This is conservative for all other (i.e. non-Drawable) classes to avoid polluting irrelevant classes.
if (SyntaxHelpers.IsIDrawableInterface(iFace) || SyntaxHelpers.IsITransformableInterface(iFace) || SyntaxHelpers.IsISourceGeneratedDependencyActivatorInterface(iFace))
{
addCandidate(context, classSyntax);
break;
}
}
if (!symbol.AllInterfaces.Any(SyntaxHelpers.IsIDependencyInjectionCandidateInterface))
return;

GeneratorClassCandidate candidate = addCandidate(context, classSyntax);

// Process any [Cached] attributes on any interface on the class excluding base types.
foreach (var iFace in SyntaxHelpers.GetDeclaredInterfacesOnType(symbol))
{
// Add an entry if this interface has a cached attribute.
if (iFace.GetAttributes().Any(attrib => SyntaxHelpers.IsCachedAttribute(attrib.AttributeClass)))
addCandidate(context, classSyntax).CachedInterfaces.Add(iFace);
candidate.CachedInterfaces.Add(iFace);
}

// Process any [Cached] attributes on the class.
foreach (var attrib in enumerateAttributes(context.SemanticModel, classSyntax))
{
if (SyntaxHelpers.IsCachedAttribute(context.SemanticModel, attrib))
addCandidate(context, classSyntax).CachedClasses.Add(new SyntaxWithSymbol(context, classSyntax));
candidate.CachedClasses.Add(new SyntaxWithSymbol(context, classSyntax));
}

// Process any attributes of members of the class.
Expand All @@ -68,16 +61,16 @@ public void OnVisitSyntaxNode(GeneratorSyntaxContext context)
foreach (var attrib in enumerateAttributes(context.SemanticModel, member))
{
if (SyntaxHelpers.IsBackgroundDependencyLoaderAttribute(context.SemanticModel, attrib))
addCandidate(context, classSyntax).DependencyLoaderMemebers.Add(new SyntaxWithSymbol(context, member));
candidate.DependencyLoaderMemebers.Add(new SyntaxWithSymbol(context, member));

if (member is not PropertyDeclarationSyntax && member is not FieldDeclarationSyntax)
continue;

if (SyntaxHelpers.IsResolvedAttribute(context.SemanticModel, attrib))
addCandidate(context, classSyntax).ResolvedMembers.Add(new SyntaxWithSymbol(context, member));
candidate.ResolvedMembers.Add(new SyntaxWithSymbol(context, member));

if (SyntaxHelpers.IsCachedAttribute(context.SemanticModel, attrib))
addCandidate(context, classSyntax).CachedMembers.Add(new SyntaxWithSymbol(context, member));
candidate.CachedMembers.Add(new SyntaxWithSymbol(context, member));
}
}
}
Expand Down
10 changes: 2 additions & 8 deletions osu.Framework.SourceGeneration/SyntaxHelpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -131,20 +131,14 @@ public static bool IsResolvedAttribute(ITypeSymbol? type)
public static bool IsCachedAttribute(ITypeSymbol? type)
=> type?.Name == "CachedAttribute";

public static bool IsIDrawableInterface(ITypeSymbol? type)
=> type?.Name == "IDrawable";

public static bool IsITransformableInterface(ITypeSymbol? type)
=> type?.Name == "ITransformable";

public static bool IsISourceGeneratedDependencyActivatorInterface(ITypeSymbol? type)
=> type?.Name == "ISourceGeneratedDependencyActivator";

public static bool IsIReadOnlyDependencyContainerInterface(ITypeSymbol? type)
=> type?.Name == "IReadOnlyDependencyContainer";

public static bool IsTransformableType(ITypeSymbol? type)
=> type?.Name == "Transformable";
public static bool IsIDependencyInjectionCandidateInterface(ITypeSymbol? type)
=> type?.Name == "IDependencyInjectionCandidate";

public static IEnumerable<ITypeSymbol> EnumerateBaseTypes(ITypeSymbol type)
{
Expand Down
Loading