diff --git a/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers.Tests/AZC0002Tests.cs b/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers.Tests/AZC0002Tests.cs index 582a4f00c6c..b9a6b5a2060 100644 --- a/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers.Tests/AZC0002Tests.cs +++ b/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers.Tests/AZC0002Tests.cs @@ -10,34 +10,36 @@ namespace Azure.ClientSdk.Analyzers.Tests public class AZC0002Tests { [Fact] - public async Task AZC0002ProducedForMethodsWithoutCancellationToken() + public async Task AZC0002ProducedForMethodsWithoutCancellationTokenOrRequestContext() { const string code = @" +using Azure; using System.Threading.Tasks; namespace RandomNamespace { public class SomeClient { - public virtual Task {|AZC0002:GetAsync|}() + public virtual Task {|AZC0002:GetAsync|}() { return null; } - public virtual void {|AZC0002:Get|}() + public virtual Response {|AZC0002:Get|}() { + return null; } } }"; await Verifier.CreateAnalyzer(code) - .WithDisabledDiagnostics("AZC0015") .RunAsync(); } [Fact] - public async Task AZC0002ProducedForMethodsWithNonOptionalCancellationToken() + public async Task AZC0002ProducedForMethodsWithWrongNameCancellationToken() { const string code = @" +using Azure; using System.Threading; using System.Threading.Tasks; @@ -45,25 +47,27 @@ namespace RandomNamespace { public class SomeClient { - public virtual Task {|AZC0002:GetAsync|}(CancellationToken cancellationToken) + public virtual Task {|AZC0002:GetAsync|}(CancellationToken cancellation = default) { return null; } - public virtual void {|AZC0002:Get|}(CancellationToken cancellationToken) + public virtual Response {|AZC0002:Get|}(CancellationToken cancellation = default) { + return null; } } }"; await Verifier.CreateAnalyzer(code) - .WithDisabledDiagnostics("AZC0015") .RunAsync(); } [Fact] - public async Task AZC0002ProducedForMethodsWithWrongNameParameter() + public async Task AZC0002ProducedForMethodsWithWrongNameRequestContext() { const string code = @" +using Azure; +using Azure.Core; using System.Threading; using System.Threading.Tasks; @@ -71,23 +75,23 @@ namespace RandomNamespace { public class SomeClient { - public virtual Task {|AZC0002:GetAsync|}(CancellationToken cancellation = default) + public virtual Task {|AZC0002:GetAsync|}(RequestContext cancellation = default) { return null; } - public virtual void {|AZC0002:Get|}(CancellationToken cancellation = default) + public virtual Response {|AZC0002:Get|}(RequestContext cancellation = default) { + return null; } } }"; await Verifier.CreateAnalyzer(code) - .WithDisabledDiagnostics("AZC0015") .RunAsync(); - } - + } + [Fact] - public async Task AZC0002ProducedForMethodsWhereRequestContextIsNotLast() + public async Task AZC0002ProducedForMethodsWithNonOptionalCancellationToken() { const string code = @" using Azure; @@ -98,25 +102,26 @@ namespace RandomNamespace { public class SomeClient { - public virtual Task {|AZC0002:GetAsync|}(RequestContext context = default, string text = default) + public virtual Task {|AZC0002:GetAsync|}(CancellationToken cancellationToken) { return null; } - public virtual void {|AZC0002:Get|}(RequestContext context = default, string text = default) + public virtual Response {|AZC0002:Get|}(CancellationToken cancellationToken) { + return null; } } }"; await Verifier.CreateAnalyzer(code) - .WithDisabledDiagnostics("AZC0015") .RunAsync(); } [Fact] - public async Task AZC0002DoesntFireIfThereIsAnOverloadWithCancellationToken() + public async Task AZC0002ProducedForMethodsWhereRequestContextIsNotLast() { const string code = @" +using Azure; using System.Threading; using System.Threading.Tasks; @@ -124,32 +129,23 @@ namespace RandomNamespace { public class SomeClient { - public virtual Task GetAsync(string s) + public virtual Task {|AZC0002:GetAsync|}(RequestContext context = default, string text = default) { return null; } - public virtual void Get(string s) - { - } - - public virtual Task GetAsync(string s, CancellationToken cancellationToken) + public virtual Response {|AZC0002:Get|}(RequestContext context = default, string text = default) { return null; } - - public virtual void Get(string s, CancellationToken cancellationToken) - { - } } }"; await Verifier.CreateAnalyzer(code) - .WithDisabledDiagnostics("AZC0015") .RunAsync(); - } - + } + [Fact] - public async Task AZC0002DoesntFireIfThereIsAnOverloadWithRequestContext() + public async Task AZC0002ProducedForMethodsWhereCancellationTokenIsNotLast() { const string code = @" using Azure; @@ -160,34 +156,53 @@ namespace RandomNamespace { public class SomeClient { - public virtual Task GetAsync(string s) + public virtual Task {|AZC0002:GetAsync|}(CancellationToken cancellationToken = default, string text = default) { return null; } - public virtual void Get(string s) + public virtual Response {|AZC0002:Get|}(CancellationToken cancellationToken = default, string text = default) { + return null; } + } +}"; + await Verifier.CreateAnalyzer(code) + .RunAsync(); + } + + [Fact] + public async Task AZC0002NotProducedForMethodsWithCancellationToken() + { + const string code = @" +using Azure; +using System.Threading; +using System.Threading.Tasks; - public virtual Task GetAsync(string s, RequestContext context = default) +namespace RandomNamespace +{ + public class SomeClient + { + public virtual Task GetAsync(CancellationToken cancellationToken = default) { return null; } - public virtual void Get(string s, RequestContext context = default) + public virtual Response Get(CancellationToken cancellationToken = default) { + return null; } } }"; await Verifier.CreateAnalyzer(code) - .WithDisabledDiagnostics("AZC0015") .RunAsync(); } [Fact] - public async Task AZC0002ProducedWhenCancellationTokenOverloadsDontMatch() + public async Task AZC0002NotProducedForMethodsWithRequestContextAndCancellationToken() { const string code = @" +using Azure; using System.Threading; using System.Threading.Tasks; @@ -195,34 +210,36 @@ namespace RandomNamespace { public class SomeClient { - public virtual Task {|AZC0002:GetAsync|}(string s) + public virtual Task Get1Async(string s, CancellationToken cancellationToken = default) { return null; } - public virtual void {|AZC0002:Get|}(string s) + public virtual Response Get1(string s, CancellationToken cancellationToken = default) { + return null; } - public virtual Task {|AZC0002:GetAsync|}(CancellationToken cancellationToken) + public virtual Task Get1Async(string s, RequestContext context) { return null; } - public virtual void {|AZC0002:Get|}(CancellationToken cancellationToken) + public virtual Response Get1(string s, RequestContext context) { + return null; } } }"; await Verifier.CreateAnalyzer(code) - .WithDisabledDiagnostics("AZC0015") .RunAsync(); } [Fact] - public async Task AZC0002NotProducedForMethodsWithCancellationToken() + public async Task AZC0002NotProducedForMethodsWithRequestContext() { const string code = @" +using Azure; using System.Threading; using System.Threading.Tasks; @@ -230,23 +247,23 @@ namespace RandomNamespace { public class SomeClient { - public virtual Task GetAsync(CancellationToken cancellationToken = default) + public virtual Task Get2Async(RequestContext context) { return null; } - public virtual void Get(CancellationToken cancellationToken = default) + public virtual Response Get2(RequestContext context) { + return null; } } }"; await Verifier.CreateAnalyzer(code) - .WithDisabledDiagnostics("AZC0015") .RunAsync(); - } - + } + [Fact] - public async Task AZC0002NotProducedForMethodsWithRequestContext() + public async Task AZC0002NotProducedIfThereIsAnOverloadWithCancellationToken() { const string code = @" using Azure; @@ -257,18 +274,28 @@ namespace RandomNamespace { public class SomeClient { - public virtual Task GetAsync(RequestContext context = default) + public virtual Task GetAsync(string s) { return null; } - public virtual void Get(RequestContext context = default) + public virtual Response Get(string s) { + return null; + } + + public virtual Task GetAsync(string s, CancellationToken cancellationToken) + { + return null; + } + + public virtual Response Get(string s, CancellationToken cancellationToken) + { + return null; } } }"; await Verifier.CreateAnalyzer(code) - .WithDisabledDiagnostics("AZC0015") .RunAsync(); } } diff --git a/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers/ClientMethodsAnalyzer.cs b/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers/ClientMethodsAnalyzer.cs index ab20919b072..d771f2404cd 100644 --- a/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers/ClientMethodsAnalyzer.cs +++ b/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers/ClientMethodsAnalyzer.cs @@ -13,6 +13,13 @@ public class ClientMethodsAnalyzer : ClientAnalyzerBase { private const string AsyncSuffix = "Async"; + private const string PageableTypeName = "Pageable"; + private const string AsyncPageableTypeName = "AsyncPageable"; + private const string ResponseTypeName = "Response"; + private const string NullableResponseTypeName = "NullableResponse"; + private const string OperationTypeName = "Operation"; + private const string TaskTypeName = "Task"; + public override ImmutableArray SupportedDiagnostics { get; } = ImmutableArray.Create(new[] { Descriptors.AZC0002, @@ -21,96 +28,109 @@ public class ClientMethodsAnalyzer : ClientAnalyzerBase Descriptors.AZC0015 }); - private static void CheckClientMethod(ISymbolAnalysisContext context, IMethodSymbol member) + private static bool IsRequestContext(IParameterSymbol parameterSymbol) { - static bool SupportsCancellationsParameter(IParameterSymbol parameterSymbol) - { - return (parameterSymbol.Name == "cancellationToken" && parameterSymbol.Type.Name == "CancellationToken") || - (parameterSymbol.Name == "context" && parameterSymbol.Type.Name == "RequestContext"); - } + return parameterSymbol.Name == "context" && parameterSymbol.Type.Name == "RequestContext"; + } - CheckClientMethodReturnType(context, member); + private static bool IsCancellationToken(IParameterSymbol parameterSymbol) + { + return parameterSymbol.Name == "cancellationToken" && parameterSymbol.Type.Name == "CancellationToken"; + } - if (!member.IsVirtual && !member.IsOverride) - { - context.ReportDiagnostic(Diagnostic.Create(Descriptors.AZC0003, member.Locations.First()), member); - } + private static bool IsCancellationOrRequestContext(IParameterSymbol parameterSymbol) + { + return IsCancellationToken(parameterSymbol) || IsRequestContext(parameterSymbol); + } + private static void CheckIsLastArgumentCancellationTokenOrRequestContext(ISymbolAnalysisContext context, IMethodSymbol member) + { var lastArgument = member.Parameters.LastOrDefault(); - var supportsCancellations = lastArgument != null && SupportsCancellationsParameter(lastArgument); + var isLastArgumentCancellationOrRequestContext = lastArgument != null && IsCancellationOrRequestContext(lastArgument); - if (!supportsCancellations) + if (!isLastArgumentCancellationOrRequestContext) { var overloadSupportsCancellations = FindMethod( member.ContainingType.GetMembers(member.Name).OfType(), member.TypeParameters, member.Parameters, - p => SupportsCancellationsParameter(p)); + p => IsCancellationToken(p)); - if (overloadSupportsCancellations != null) + if (overloadSupportsCancellations == null) { - // Skip methods that have overloads with cancellation tokens - return; + context.ReportDiagnostic(Diagnostic.Create(Descriptors.AZC0002, member.Locations.FirstOrDefault()), member); } - - context.ReportDiagnostic(Diagnostic.Create(Descriptors.AZC0002, member.Locations.FirstOrDefault()), member); } - else if (!lastArgument.IsOptional) + else if (IsCancellationToken(lastArgument) && !lastArgument.IsOptional) { var overloadWithCancellationToken = FindMethod( member.ContainingType.GetMembers(member.Name).OfType(), member.TypeParameters, member.Parameters.RemoveAt(member.Parameters.Length - 1)); - if (overloadWithCancellationToken != null) + if (overloadWithCancellationToken == null) { - // Skip methods that have non-optional cancellation token if overload exists without one - return; + context.ReportDiagnostic(Diagnostic.Create(Descriptors.AZC0002, member.Locations.FirstOrDefault()), member); } - - context.ReportDiagnostic(Diagnostic.Create(Descriptors.AZC0002, member.Locations.FirstOrDefault()), member); } } - private static void CheckClientMethodReturnType(ISymbolAnalysisContext context, IMethodSymbol method) + private static void CheckClientMethod(ISymbolAnalysisContext context, IMethodSymbol member) { - bool IsOrImplements(ITypeSymbol typeSymbol, string typeName) + CheckClientMethodReturnType(context, member); + + if (!member.IsVirtual && !member.IsOverride) { - if (typeSymbol.Name == typeName) - { - return true; - } + context.ReportDiagnostic(Diagnostic.Create(Descriptors.AZC0003, member.Locations.First()), member); + } + } - if (typeSymbol.BaseType != null) - { - return IsOrImplements(typeSymbol.BaseType, typeName); - } + private static bool IsOrImplements(ITypeSymbol typeSymbol, string typeName) + { + if (typeSymbol.Name == typeName) + { + return true; + } - return false; + if (typeSymbol.BaseType != null) + { + return IsOrImplements(typeSymbol.BaseType, typeName); } + return false; + } + + private static void CheckClientMethodReturnType(ISymbolAnalysisContext context, IMethodSymbol method) + { + IsClientMethodReturnType(context, method, true); + } + + private static bool IsClientMethodReturnType(ISymbolAnalysisContext context, IMethodSymbol method, bool throwError = false) + { ITypeSymbol originalType = method.ReturnType; ITypeSymbol unwrappedType = method.ReturnType; if (method.ReturnType is INamedTypeSymbol namedTypeSymbol && namedTypeSymbol.IsGenericType && - namedTypeSymbol.Name == "Task") + namedTypeSymbol.Name == TaskTypeName) { unwrappedType = namedTypeSymbol.TypeArguments.Single(); } - if (IsOrImplements(unwrappedType, "Response") || - IsOrImplements(unwrappedType, "NullableResponse") || - IsOrImplements(unwrappedType, "Operation") || - IsOrImplements(originalType, "Pageable") || - IsOrImplements(originalType, "AsyncPageable") || - originalType.Name.EndsWith(ClientSuffix)) + if (IsOrImplements(unwrappedType, ResponseTypeName) || + IsOrImplements(unwrappedType, NullableResponseTypeName) || + IsOrImplements(unwrappedType, OperationTypeName) || + IsOrImplements(originalType, PageableTypeName) || + IsOrImplements(originalType, AsyncPageableTypeName)) { - return; + return true; } - context.ReportDiagnostic(Diagnostic.Create(Descriptors.AZC0015, method.Locations.FirstOrDefault(), originalType.ToDisplayString()), method); - + if (throwError) + { + context.ReportDiagnostic(Diagnostic.Create(Descriptors.AZC0015, method.Locations.FirstOrDefault(), originalType.ToDisplayString()), method); + } + return false; } public override void AnalyzeCore(ISymbolAnalysisContext context) @@ -118,7 +138,13 @@ public override void AnalyzeCore(ISymbolAnalysisContext context) INamedTypeSymbol type = (INamedTypeSymbol)context.Symbol; foreach (var member in type.GetMembers()) { - if (member is IMethodSymbol methodSymbol && methodSymbol.Name.EndsWith(AsyncSuffix) && member.DeclaredAccessibility == Accessibility.Public) + var methodSymbol = member as IMethodSymbol; + if (methodSymbol == null || methodSymbol.DeclaredAccessibility != Accessibility.Public) + { + continue; + } + + if (methodSymbol.Name.EndsWith(AsyncSuffix)) { CheckClientMethod(context, methodSymbol); @@ -135,6 +161,11 @@ public override void AnalyzeCore(ISymbolAnalysisContext context) CheckClientMethod(context, syncMember); } } + + if (IsClientMethodReturnType(context, methodSymbol, false)) + { + CheckIsLastArgumentCancellationTokenOrRequestContext(context, methodSymbol); + } } } } diff --git a/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers/Descriptors.cs b/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers/Descriptors.cs index 86d9e4d8072..0b292863d5b 100644 --- a/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers/Descriptors.cs +++ b/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers/Descriptors.cs @@ -16,8 +16,8 @@ internal class Descriptors public static DiagnosticDescriptor AZC0002 = new DiagnosticDescriptor( nameof(AZC0002), - "DO ensure all service methods, both asynchronous and synchronous, take an optional CancellationToken parameter called cancellationToken.", - "Client method should have cancellationToken as the last optional parameter (both name and it being optional matters)", + "DO ensure all service methods, both asynchronous and synchronous, take an optional CancellationToken parameter called 'cancellationToken' or a RequestContext parameter called 'context'.", + "Client method should have an optional CancellationToken called cancellationToken (both name and it being optional matters) or a RequestContext called context as the last parameter.", "Usage", DiagnosticSeverity.Warning, isEnabledByDefault: true, description: null, "https://azure.github.io/azure-sdk/dotnet_introduction.html#dotnet-service-methods-cancellation" );