Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
annelo-msft committed Aug 11, 2023
1 parent 1453cbe commit 1b7eb5a
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 41 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
Expand All @@ -9,7 +10,7 @@
namespace Azure.ClientSdk.Analyzers
{
[DiagnosticAnalyzer(LanguageNames.CSharp)]
public sealed class BannedTypesAnalyzer : SymbolAnalyzerBase
public sealed class BannedTypesAnalyzer : DiagnosticAnalyzer
{
private static HashSet<string> BannedTypes = new HashSet<string>()
{
Expand All @@ -23,30 +24,25 @@ public sealed class BannedTypesAnalyzer : SymbolAnalyzerBase

public override ImmutableArray<DiagnosticDescriptor> SupportedDiagnostics { get; } = ImmutableArray.Create(Descriptors.AZC0020);

public override SymbolKind[] SymbolKinds { get; } = new[]
public SymbolKind[] SymbolKinds { get; } = new[]
{
SymbolKind.Event,
SymbolKind.Field,
SymbolKind.Local,
SymbolKind.Method,
SymbolKind.NamedType,
SymbolKind.Parameter,
SymbolKind.Property,
};

// Note: suppressing warnings because they are handled in base.Initialize().
#pragma warning disable RS1025 // Configure generated code analysis
#pragma warning disable RS1026 // Enable concurrent execution
public override void Initialize(AnalysisContext context)
#pragma warning restore RS1026 // Enable concurrent execution
#pragma warning restore RS1025 // Configure generated code analysis
{
base.Initialize(context);

context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.Analyze | GeneratedCodeAnalysisFlags.ReportDiagnostics);
context.EnableConcurrentExecution();
context.RegisterSymbolAction(c => Analyze(c), SymbolKinds);
context.RegisterSyntaxNodeAction(c => AnalyzeNode(c), SyntaxKind.LocalDeclarationStatement);
}

public override void Analyze(ISymbolAnalysisContext context)
public void Analyze(SymbolAnalysisContext context)
{
if (IsAzureCore(context.Symbol.ContainingAssembly))
{
Expand All @@ -56,30 +52,25 @@ public override void Analyze(ISymbolAnalysisContext context)
switch (context.Symbol)
{
case IParameterSymbol parameterSymbol:
CheckType(context, parameterSymbol.Type, parameterSymbol);
CheckType(parameterSymbol.Type, parameterSymbol, context.ReportDiagnostic);
break;
case IMethodSymbol methodSymbol:
CheckType(context, methodSymbol.ReturnType, methodSymbol);

//foreach (var typeSymbol in methodSymbol.)
CheckType(methodSymbol.ReturnType, methodSymbol, context.ReportDiagnostic);
break;
case IEventSymbol eventSymbol:
CheckType(context, eventSymbol.Type, eventSymbol);
CheckType(eventSymbol.Type, eventSymbol, context.ReportDiagnostic);
break;
case IPropertySymbol propertySymbol:
CheckType(context, propertySymbol.Type, propertySymbol);
CheckType(propertySymbol.Type, propertySymbol, context.ReportDiagnostic);
break;
case IFieldSymbol fieldSymbol:
CheckType(context, fieldSymbol.Type, fieldSymbol);
break;
case ILocalSymbol localSymbol:
CheckType(context, localSymbol.Type, localSymbol);
CheckType(fieldSymbol.Type, fieldSymbol, context.ReportDiagnostic);
break;
case INamedTypeSymbol namedTypeSymbol:
CheckType(context, namedTypeSymbol.BaseType, namedTypeSymbol);
CheckType(namedTypeSymbol.BaseType, namedTypeSymbol, context.ReportDiagnostic);
foreach (var iface in namedTypeSymbol.Interfaces)
{
CheckType(context, iface, namedTypeSymbol);
CheckType(iface, namedTypeSymbol, context.ReportDiagnostic);
}
break;
}
Expand All @@ -96,38 +87,29 @@ public void AnalyzeNode(SyntaxNodeAnalysisContext context)
{
ITypeSymbol type = context.SemanticModel.GetTypeInfo(declaration.Declaration.Type).Type;

if (type is INamedTypeSymbol namedTypeSymbol)
{
if (IsBannedType(namedTypeSymbol))
{
context.ReportDiagnostic(Diagnostic.Create(Descriptors.AZC0020, context.Node.GetLocation(), BannedTypesMessageArgs));
}
}
CheckType(type, type, context.ReportDiagnostic, context.Node.GetLocation());
}
}

private static void CheckType(ISymbolAnalysisContext context, ITypeSymbol type, ISymbol symbol)
private static Diagnostic CheckType(ITypeSymbol type, ISymbol symbol, Action<Diagnostic> reportDiagnostic, Location location = default)
{
if (type is INamedTypeSymbol namedTypeSymbol)
{
if (IsBannedType(namedTypeSymbol))
{
context.ReportDiagnostic(Diagnostic.Create(Descriptors.AZC0020, symbol.Locations.First(), BannedTypesMessageArgs), symbol);
reportDiagnostic(Diagnostic.Create(Descriptors.AZC0020, location ?? symbol.Locations.First(), BannedTypesMessageArgs));
}

if (namedTypeSymbol.IsGenericType)
{
foreach (var typeArgument in namedTypeSymbol.TypeArguments)
{
CheckType(context, typeArgument, symbol);
CheckType(typeArgument, symbol, reportDiagnostic);
}
}
}
}

private static bool IsBannedType(INamedTypeSymbol namedTypeSymbol)
{
return BannedTypes.Contains($"{namedTypeSymbol.ContainingNamespace}.{namedTypeSymbol.Name}");
return null;
}

private static bool IsAzureCore(IAssemblySymbol assembly)
Expand All @@ -136,5 +118,10 @@ private static bool IsAzureCore(IAssemblySymbol assembly)
assembly.Name.Equals("Azure.Core") ||
assembly.Name.Equals("Azure.Core.Experimental");
}

private static bool IsBannedType(INamedTypeSymbol namedTypeSymbol)
{
return BannedTypes.Contains($"{namedTypeSymbol.ContainingNamespace}.{namedTypeSymbol.Name}");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public abstract class SymbolAnalyzerBase : DiagnosticAnalyzer

protected INamedTypeSymbol ClientOptionsType { get; private set; }

public override void Initialize(AnalysisContext context)
public sealed override void Initialize(AnalysisContext context)
{
context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.Analyze | GeneratedCodeAnalysisFlags.ReportDiagnostics);
context.EnableConcurrentExecution();
Expand Down Expand Up @@ -49,9 +49,9 @@ protected bool IsClientOptionsType(ITypeSymbol typeSymbol)
{
return false;
}

ITypeSymbol baseType = typeSymbol.BaseType;
while (baseType != null)
while (baseType != null)
{
if (SymbolEqualityComparer.Default.Equals(baseType, ClientOptionsType))
{
Expand Down Expand Up @@ -81,4 +81,4 @@ public void ReportDiagnostic(Diagnostic diagnostic, ISymbol symbol)
}
}
}
}
}

0 comments on commit 1b7eb5a

Please sign in to comment.