diff --git a/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers.Tests/AZC0002Tests.cs b/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers.Tests/AZC0002Tests.cs index 3c5b56b99ed..8ebdcb03bfb 100644 --- a/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers.Tests/AZC0002Tests.cs +++ b/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers.Tests/AZC0002Tests.cs @@ -295,5 +295,43 @@ await Verifier.CreateAnalyzer(code) .WithDisabledDiagnostics("AZC0018") .RunAsync(); } + + [Fact] + public async Task AZC0002NotProducedIfThereIsAnOverloadWithCancellationToken() + { + const string code = @" +using Azure; +using System.Threading; +using System.Threading.Tasks; + +namespace RandomNamespace +{ + public class SomeClient + { + public virtual Response GetAsync(string s) + { + return null; + } + + public virtual Response Get(string s) + { + return null; + } + + public virtual Response GetAsync(string s, CancellationToken cancellationToken) + { + return null; + } + + public virtual Response Get(string s, CancellationToken cancellationToken) + { + return null; + } + } +}"; + await Verifier.CreateAnalyzer(code) + .WithDisabledDiagnostics("AZC0015") + .RunAsync(); + } } } diff --git a/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers.Tests/AZC0018Tests.cs b/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers.Tests/AZC0018Tests.cs index 956592f0350..b9c39094ae9 100644 --- a/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers.Tests/AZC0018Tests.cs +++ b/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers.Tests/AZC0018Tests.cs @@ -25,6 +25,16 @@ public class AssetConversion {} public class SomeClient { + public virtual Task> GetHeadAsBooleanAsync(string s, RequestContext context) + { + return null; + } + + public virtual Response GetHeadAsBoolean(string s, RequestContext context) + { + return null; + } + public virtual Task GetResponseAsync(string s, RequestContext context) { return null; diff --git a/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers/ClientMethodsAnalyzer.cs b/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers/ClientMethodsAnalyzer.cs index d3be35e253c..2080173bdb1 100644 --- a/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers/ClientMethodsAnalyzer.cs +++ b/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers/ClientMethodsAnalyzer.cs @@ -21,6 +21,7 @@ public class ClientMethodsAnalyzer : ClientAnalyzerBase private const string NullableResponseTypeName = "NullableResponse"; private const string OperationTypeName = "Operation"; private const string TaskTypeName = "Task"; + private const string BooleanTypeName = "Boolean"; public override ImmutableArray SupportedDiagnostics { get; } = ImmutableArray.Create(new[] { @@ -45,7 +46,7 @@ private static bool IsRequestContext(IParameterSymbol parameterSymbol) private static bool IsCancellationToken(IParameterSymbol parameterSymbol) { - return parameterSymbol.Name == "cancellationToken" && parameterSymbol.Type.Name == "CancellationToken" && parameterSymbol.IsOptional; + return parameterSymbol.Name == "cancellationToken" && parameterSymbol.Type.Name == "CancellationToken"; } private static void CheckClientMethod(ISymbolAnalysisContext context, IMethodSymbol member) @@ -67,10 +68,31 @@ static bool IsCancellationOrRequestContext(IParameterSymbol parameterSymbol) if (!isCancellationOrRequestContext) { - context.ReportDiagnostic(Diagnostic.Create(Descriptors.AZC0002, member.Locations.FirstOrDefault()), member); + var overloadSupportsCancellations = FindMethod( + member.ContainingType.GetMembers(member.Name).OfType(), + member.TypeParameters, + member.Parameters, + p => IsCancellationToken(p)); + + if (overloadSupportsCancellations == null) + { + context.ReportDiagnostic(Diagnostic.Create(Descriptors.AZC0002, member.Locations.FirstOrDefault()), member); + } } else if (IsCancellationToken(lastArgument)) { + if (!lastArgument.IsOptional) + { + var overloadWithCancellationToken = FindMethod( + member.ContainingType.GetMembers(member.Name).OfType(), + member.TypeParameters, + member.Parameters.RemoveAt(member.Parameters.Length - 1)); + + if (overloadWithCancellationToken == null) + { + context.ReportDiagnostic(Diagnostic.Create(Descriptors.AZC0002, member.Locations.FirstOrDefault()), member); + } + } // A convenience method should not have RequestContent as parameter if (member.Parameters.FirstOrDefault(IsRequestContent) != null) { @@ -129,7 +151,7 @@ private static void CheckProtocolMethodParameters(ISymbolAnalysisContext context } } - // A protocol method should not have model as type. Accepted return type: Response, Task, Pageable, AsyncPageable, Operation, Task>, Operation, Task, Operation>, Task>> + // A protocol method should not have model as type. Accepted return type: Response, Task, Response, Task>, Pageable, AsyncPageable, Operation, Task>, Operation, Task, Operation>, Task>> private static void CheckProtocolMethodReturnType(ISymbolAnalysisContext context, IMethodSymbol method) { bool IsValidPageable(ITypeSymbol typeSymbol) @@ -168,7 +190,11 @@ bool IsValidPageable(ITypeSymbol typeSymbol) { if (unwrappedType is INamedTypeSymbol responseTypeSymbol && responseTypeSymbol.IsGenericType) { - context.ReportDiagnostic(Diagnostic.Create(Descriptors.AZC0018, method.Locations.FirstOrDefault()), method); + var responseReturn = responseTypeSymbol.TypeArguments.Single(); + if (responseReturn.Name != BooleanTypeName) + { + context.ReportDiagnostic(Diagnostic.Create(Descriptors.AZC0018, method.Locations.FirstOrDefault()), method); + } } return; } @@ -264,32 +290,45 @@ private bool IsCheckExempt(ISymbolAnalysisContext context, IMethodSymbol method) public override void AnalyzeCore(ISymbolAnalysisContext context) { + void CheckSyncAsyncPair(INamedTypeSymbol type, IMethodSymbol method, string methodName) + { + var lastArgument = method.Parameters.LastOrDefault(); + IMethodSymbol methodSymbol = null; + if (lastArgument != null && IsRequestContext(lastArgument)) + { + methodSymbol = FindMethod(type.GetMembers(methodName).OfType(), method.TypeParameters, method.Parameters); + } + else if (lastArgument != null && IsCancellationToken(lastArgument)) + { + methodSymbol = FindMethod(type.GetMembers(methodName).OfType(), method.TypeParameters, method.Parameters.RemoveAt(method.Parameters.Length - 1), p => IsCancellationToken(p)); + } + else + { + return; + } + + if (methodSymbol == null) + { + context.ReportDiagnostic(Diagnostic.Create(Descriptors.AZC0004, method.Locations.First()), method); + } + } + INamedTypeSymbol type = (INamedTypeSymbol)context.Symbol; foreach (var member in type.GetMembers()) { if (member is IMethodSymbol asyncMethodSymbol && !IsCheckExempt(context, asyncMethodSymbol) && asyncMethodSymbol.Name.EndsWith(AsyncSuffix)) { - var syncMemberName = member.Name.Substring(0, member.Name.Length - AsyncSuffix.Length); - var syncMember = FindMethod(type.GetMembers(syncMemberName).OfType(), asyncMethodSymbol.TypeParameters, asyncMethodSymbol.Parameters); - - if (syncMember == null) - { - context.ReportDiagnostic(Diagnostic.Create(Descriptors.AZC0004, member.Locations.First()), member); - } - CheckClientMethod(context, asyncMethodSymbol); + + var syncMemberName = member.Name.Substring(0, member.Name.Length - AsyncSuffix.Length); + CheckSyncAsyncPair(type, asyncMethodSymbol, syncMemberName); } else if (member is IMethodSymbol syncMethodSymbol && !IsCheckExempt(context, syncMethodSymbol) && !syncMethodSymbol.Name.EndsWith(AsyncSuffix)) { - var asyncMemberName = member.Name + AsyncSuffix; - var asyncMember = FindMethod(type.GetMembers(asyncMemberName).OfType(), syncMethodSymbol.TypeParameters, syncMethodSymbol.Parameters); - - if (asyncMember == null) - { - context.ReportDiagnostic(Diagnostic.Create(Descriptors.AZC0004, member.Locations.First()), member); - } - CheckClientMethod(context, syncMethodSymbol); + + var asyncMemberName = member.Name + AsyncSuffix; + CheckSyncAsyncPair(type, syncMethodSymbol, asyncMemberName); } } } diff --git a/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers/Descriptors.cs b/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers/Descriptors.cs index 338e29b27b7..447615e75f0 100644 --- a/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers/Descriptors.cs +++ b/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers/Descriptors.cs @@ -17,7 +17,7 @@ internal class Descriptors public static DiagnosticDescriptor AZC0002 = new DiagnosticDescriptor( nameof(AZC0002), "DO ensure all service methods, both asynchronous and synchronous, take an optional CancellationToken parameter called cancellationToken or a RequestContext parameter called context.", - "Client method should have an optional CancellationToken (both name and it being optional matters) or a RequestContext as the last parameter.", + "Client method should have an optional CancellationToken called cancellationToken (both name and it being optional matters) or a RequestContext called context as the last parameter.", "Usage", DiagnosticSeverity.Warning, isEnabledByDefault: true, description: null, "https://azure.github.io/azure-sdk/dotnet_introduction.html#dotnet-service-methods-cancellation" );