diff --git a/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers/BannedTypesAnalyzer.cs b/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers/BannedTypesAnalyzer.cs index 6e989a07850..2c981a34a23 100644 --- a/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers/BannedTypesAnalyzer.cs +++ b/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers/BannedTypesAnalyzer.cs @@ -1,3 +1,4 @@ +using System; using System.Collections.Generic; using System.Collections.Immutable; using System.Linq; @@ -9,7 +10,7 @@ namespace Azure.ClientSdk.Analyzers { [DiagnosticAnalyzer(LanguageNames.CSharp)] - public sealed class BannedTypesAnalyzer : SymbolAnalyzerBase + public sealed class BannedTypesAnalyzer : DiagnosticAnalyzer { private static HashSet BannedTypes = new HashSet() { @@ -23,30 +24,25 @@ public sealed class BannedTypesAnalyzer : SymbolAnalyzerBase public override ImmutableArray SupportedDiagnostics { get; } = ImmutableArray.Create(Descriptors.AZC0020); - public override SymbolKind[] SymbolKinds { get; } = new[] + public SymbolKind[] SymbolKinds { get; } = new[] { SymbolKind.Event, SymbolKind.Field, - SymbolKind.Local, SymbolKind.Method, SymbolKind.NamedType, SymbolKind.Parameter, SymbolKind.Property, }; - // Note: suppressing warnings because they are handled in base.Initialize(). -#pragma warning disable RS1025 // Configure generated code analysis -#pragma warning disable RS1026 // Enable concurrent execution public override void Initialize(AnalysisContext context) -#pragma warning restore RS1026 // Enable concurrent execution -#pragma warning restore RS1025 // Configure generated code analysis { - base.Initialize(context); - + context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.Analyze | GeneratedCodeAnalysisFlags.ReportDiagnostics); + context.EnableConcurrentExecution(); + context.RegisterSymbolAction(c => Analyze(c), SymbolKinds); context.RegisterSyntaxNodeAction(c => AnalyzeNode(c), SyntaxKind.LocalDeclarationStatement); } - public override void Analyze(ISymbolAnalysisContext context) + public void Analyze(SymbolAnalysisContext context) { if (IsAzureCore(context.Symbol.ContainingAssembly)) { @@ -56,30 +52,25 @@ public override void Analyze(ISymbolAnalysisContext context) switch (context.Symbol) { case IParameterSymbol parameterSymbol: - CheckType(context, parameterSymbol.Type, parameterSymbol); + CheckType(parameterSymbol.Type, parameterSymbol, context.ReportDiagnostic); break; case IMethodSymbol methodSymbol: - CheckType(context, methodSymbol.ReturnType, methodSymbol); - - //foreach (var typeSymbol in methodSymbol.) + CheckType(methodSymbol.ReturnType, methodSymbol, context.ReportDiagnostic); break; case IEventSymbol eventSymbol: - CheckType(context, eventSymbol.Type, eventSymbol); + CheckType(eventSymbol.Type, eventSymbol, context.ReportDiagnostic); break; case IPropertySymbol propertySymbol: - CheckType(context, propertySymbol.Type, propertySymbol); + CheckType(propertySymbol.Type, propertySymbol, context.ReportDiagnostic); break; case IFieldSymbol fieldSymbol: - CheckType(context, fieldSymbol.Type, fieldSymbol); - break; - case ILocalSymbol localSymbol: - CheckType(context, localSymbol.Type, localSymbol); + CheckType(fieldSymbol.Type, fieldSymbol, context.ReportDiagnostic); break; case INamedTypeSymbol namedTypeSymbol: - CheckType(context, namedTypeSymbol.BaseType, namedTypeSymbol); + CheckType(namedTypeSymbol.BaseType, namedTypeSymbol, context.ReportDiagnostic); foreach (var iface in namedTypeSymbol.Interfaces) { - CheckType(context, iface, namedTypeSymbol); + CheckType(iface, namedTypeSymbol, context.ReportDiagnostic); } break; } @@ -96,38 +87,29 @@ public void AnalyzeNode(SyntaxNodeAnalysisContext context) { ITypeSymbol type = context.SemanticModel.GetTypeInfo(declaration.Declaration.Type).Type; - if (type is INamedTypeSymbol namedTypeSymbol) - { - if (IsBannedType(namedTypeSymbol)) - { - context.ReportDiagnostic(Diagnostic.Create(Descriptors.AZC0020, context.Node.GetLocation(), BannedTypesMessageArgs)); - } - } + CheckType(type, type, context.ReportDiagnostic, context.Node.GetLocation()); } } - private static void CheckType(ISymbolAnalysisContext context, ITypeSymbol type, ISymbol symbol) + private static Diagnostic CheckType(ITypeSymbol type, ISymbol symbol, Action reportDiagnostic, Location location = default) { if (type is INamedTypeSymbol namedTypeSymbol) { if (IsBannedType(namedTypeSymbol)) { - context.ReportDiagnostic(Diagnostic.Create(Descriptors.AZC0020, symbol.Locations.First(), BannedTypesMessageArgs), symbol); + reportDiagnostic(Diagnostic.Create(Descriptors.AZC0020, location ?? symbol.Locations.First(), BannedTypesMessageArgs)); } if (namedTypeSymbol.IsGenericType) { foreach (var typeArgument in namedTypeSymbol.TypeArguments) { - CheckType(context, typeArgument, symbol); + CheckType(typeArgument, symbol, reportDiagnostic); } } } - } - private static bool IsBannedType(INamedTypeSymbol namedTypeSymbol) - { - return BannedTypes.Contains($"{namedTypeSymbol.ContainingNamespace}.{namedTypeSymbol.Name}"); + return null; } private static bool IsAzureCore(IAssemblySymbol assembly) @@ -136,5 +118,10 @@ private static bool IsAzureCore(IAssemblySymbol assembly) assembly.Name.Equals("Azure.Core") || assembly.Name.Equals("Azure.Core.Experimental"); } + + private static bool IsBannedType(INamedTypeSymbol namedTypeSymbol) + { + return BannedTypes.Contains($"{namedTypeSymbol.ContainingNamespace}.{namedTypeSymbol.Name}"); + } } } diff --git a/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers/SymbolAnalyzerBase.cs b/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers/SymbolAnalyzerBase.cs index 4062c38c03b..c3030310f61 100644 --- a/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers/SymbolAnalyzerBase.cs +++ b/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers/SymbolAnalyzerBase.cs @@ -16,7 +16,7 @@ public abstract class SymbolAnalyzerBase : DiagnosticAnalyzer protected INamedTypeSymbol ClientOptionsType { get; private set; } - public override void Initialize(AnalysisContext context) + public sealed override void Initialize(AnalysisContext context) { context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.Analyze | GeneratedCodeAnalysisFlags.ReportDiagnostics); context.EnableConcurrentExecution(); @@ -49,9 +49,9 @@ protected bool IsClientOptionsType(ITypeSymbol typeSymbol) { return false; } - + ITypeSymbol baseType = typeSymbol.BaseType; - while (baseType != null) + while (baseType != null) { if (SymbolEqualityComparer.Default.Equals(baseType, ClientOptionsType)) { @@ -81,4 +81,4 @@ public void ReportDiagnostic(Diagnostic diagnostic, ISymbol symbol) } } } -} +} \ No newline at end of file