diff --git a/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers.Tests/AZC0002Tests.cs b/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers.Tests/AZC0002Tests.cs index 582a4f00c6c..b63e6dbe8f0 100644 --- a/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers.Tests/AZC0002Tests.cs +++ b/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers.Tests/AZC0002Tests.cs @@ -10,7 +10,7 @@ namespace Azure.ClientSdk.Analyzers.Tests public class AZC0002Tests { [Fact] - public async Task AZC0002ProducedForMethodsWithoutCancellationToken() + public async Task AZC0002ProducedForMethodsWithoutCancellationTokenOrRequestContext() { const string code = @" using System.Threading.Tasks; @@ -35,7 +35,7 @@ await Verifier.CreateAnalyzer(code) } [Fact] - public async Task AZC0002ProducedForMethodsWithNonOptionalCancellationToken() + public async Task AZC0002ProducedForMethodsWithWrongNameCancellationToken() { const string code = @" using System.Threading; @@ -45,12 +45,12 @@ namespace RandomNamespace { public class SomeClient { - public virtual Task {|AZC0002:GetAsync|}(CancellationToken cancellationToken) + public virtual Task {|AZC0002:GetAsync|}(CancellationToken cancellation = default) { return null; } - public virtual void {|AZC0002:Get|}(CancellationToken cancellationToken) + public virtual void {|AZC0002:Get|}(CancellationToken cancellation = default) { } } @@ -61,9 +61,10 @@ await Verifier.CreateAnalyzer(code) } [Fact] - public async Task AZC0002ProducedForMethodsWithWrongNameParameter() + public async Task AZC0002ProducedForMethodsWithWrongNameRequestContext() { const string code = @" +using Azure; using System.Threading; using System.Threading.Tasks; @@ -71,12 +72,12 @@ namespace RandomNamespace { public class SomeClient { - public virtual Task {|AZC0002:GetAsync|}(CancellationToken cancellation = default) + public virtual Task {|AZC0002:GetAsync|}(RequestContext cancellation = default) { return null; } - public virtual void {|AZC0002:Get|}(CancellationToken cancellation = default) + public virtual void {|AZC0002:Get|}(RequestContext cancellation = default) { } } @@ -85,12 +86,11 @@ await Verifier.CreateAnalyzer(code) .WithDisabledDiagnostics("AZC0015") .RunAsync(); } - + [Fact] - public async Task AZC0002ProducedForMethodsWhereRequestContextIsNotLast() + public async Task AZC0002ProducedForMethodsWithNonOptionalCancellationToken() { const string code = @" -using Azure; using System.Threading; using System.Threading.Tasks; @@ -98,12 +98,12 @@ namespace RandomNamespace { public class SomeClient { - public virtual Task {|AZC0002:GetAsync|}(RequestContext context = default, string text = default) + public virtual Task {|AZC0002:GetAsync|}(CancellationToken cancellationToken) { return null; } - public virtual void {|AZC0002:Get|}(RequestContext context = default, string text = default) + public virtual void {|AZC0002:Get|}(CancellationToken cancellationToken) { } } @@ -112,11 +112,12 @@ await Verifier.CreateAnalyzer(code) .WithDisabledDiagnostics("AZC0015") .RunAsync(); } - + [Fact] - public async Task AZC0002DoesntFireIfThereIsAnOverloadWithCancellationToken() + public async Task AZC0002ProducedForMethodsWhereRequestContextIsNotLast() { const string code = @" +using Azure; using System.Threading; using System.Threading.Tasks; @@ -124,21 +125,12 @@ namespace RandomNamespace { public class SomeClient { - public virtual Task GetAsync(string s) - { - return null; - } - - public virtual void Get(string s) - { - } - - public virtual Task GetAsync(string s, CancellationToken cancellationToken) + public virtual Task {|AZC0002:GetAsync|}(RequestContext context = default, string text = default) { return null; } - public virtual void Get(string s, CancellationToken cancellationToken) + public virtual void {|AZC0002:Get|}(RequestContext context = default, string text = default) { } } @@ -146,10 +138,10 @@ public virtual void Get(string s, CancellationToken cancellationToken) await Verifier.CreateAnalyzer(code) .WithDisabledDiagnostics("AZC0015") .RunAsync(); - } - + } + [Fact] - public async Task AZC0002DoesntFireIfThereIsAnOverloadWithRequestContext() + public async Task AZC0002ProducedForMethodsWhereCancellationTokenIsNotLast() { const string code = @" using Azure; @@ -160,21 +152,38 @@ namespace RandomNamespace { public class SomeClient { - public virtual Task GetAsync(string s) + public virtual Task {|AZC0002:GetAsync|}(CancellationToken cancellationToken = default, string text = default) { return null; } - public virtual void Get(string s) + public virtual void {|AZC0002:Get|}(CancellationToken cancellationToken = default, string text = default) { } + } +}"; + await Verifier.CreateAnalyzer(code) + .WithDisabledDiagnostics("AZC0015") + .RunAsync(); + } - public virtual Task GetAsync(string s, RequestContext context = default) + [Fact] + public async Task AZC0002NotProducedForMethodsWithCancellationToken() + { + const string code = @" +using System.Threading; +using System.Threading.Tasks; + +namespace RandomNamespace +{ + public class SomeClient + { + public virtual Task GetAsync(CancellationToken cancellationToken = default) { return null; } - public virtual void Get(string s, RequestContext context = default) + public virtual void Get(CancellationToken cancellationToken = default) { } } @@ -182,12 +191,13 @@ public virtual void Get(string s, RequestContext context = default) await Verifier.CreateAnalyzer(code) .WithDisabledDiagnostics("AZC0015") .RunAsync(); - } - + } + [Fact] - public async Task AZC0002ProducedWhenCancellationTokenOverloadsDontMatch() + public async Task AZC0002NotProducedForMethodsWithRequestContextAndCancellationToken() { const string code = @" +using Azure; using System.Threading; using System.Threading.Tasks; @@ -195,56 +205,50 @@ namespace RandomNamespace { public class SomeClient { - public virtual Task {|AZC0002:GetAsync|}(string s) + public virtual Task Get1Async(string s, CancellationToken cancellationToken = default) { return null; } - public virtual void {|AZC0002:Get|}(string s) + public virtual void Get1(string s, CancellationToken cancellationToken = default) { } - public virtual Task {|AZC0002:GetAsync|}(CancellationToken cancellationToken) + public virtual Task Get1Async(string s, RequestContext context) { return null; } - public virtual void {|AZC0002:Get|}(CancellationToken cancellationToken) + public virtual void Get1(string s, RequestContext context) { } - } -}"; - await Verifier.CreateAnalyzer(code) - .WithDisabledDiagnostics("AZC0015") - .RunAsync(); + + public virtual Task Get2Async(CancellationToken cancellationToken = default) + { + return null; } - [Fact] - public async Task AZC0002NotProducedForMethodsWithCancellationToken() + public virtual void Get2(CancellationToken cancellationToken = default) { - const string code = @" -using System.Threading; -using System.Threading.Tasks; + } -namespace RandomNamespace -{ - public class SomeClient - { - public virtual Task GetAsync(CancellationToken cancellationToken = default) + public virtual Task Get2Async(RequestContext context) { return null; } - public virtual void Get(CancellationToken cancellationToken = default) + public virtual void Get2(RequestContext context) { } } }"; await Verifier.CreateAnalyzer(code) .WithDisabledDiagnostics("AZC0015") + .WithDisabledDiagnostics("AZC0018") + .WithDisabledDiagnostics("AD0001") .RunAsync(); - } - + } + [Fact] public async Task AZC0002NotProducedForMethodsWithRequestContext() { @@ -265,10 +269,20 @@ public virtual Task GetAsync(RequestContext context = default) public virtual void Get(RequestContext context = default) { } + + public virtual Task Get2Async(RequestContext context) + { + return null; + } + + public virtual void Get2(RequestContext context) + { + } } }"; await Verifier.CreateAnalyzer(code) .WithDisabledDiagnostics("AZC0015") + .WithDisabledDiagnostics("AZC0018") .RunAsync(); } } diff --git a/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers.Tests/AZC0004Tests.cs b/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers.Tests/AZC0004Tests.cs index ca9cddf34a5..3d35c522ea7 100644 --- a/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers.Tests/AZC0004Tests.cs +++ b/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers.Tests/AZC0004Tests.cs @@ -1,6 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. - + using System.Threading.Tasks; using Xunit; using Verifier = Azure.ClientSdk.Analyzers.Tests.AzureAnalyzerVerifier; @@ -32,7 +32,29 @@ await Verifier.CreateAnalyzer(code) } [Fact] - public async Task AZC0004NotProducedForMethodsWithoutSyncAlternative() + public async Task AZC0004ProducedForMethodsWithoutAsyncAlternative() + { + const string code = @" +using System.Threading; +using System.Threading.Tasks; + +namespace RandomNamespace +{ + public class SomeClient + { + public virtual Task {|AZC0004:Get|}(CancellationToken cancellationToken = default) + { + return null; + } + } +}"; + await Verifier.CreateAnalyzer(code) + .WithDisabledDiagnostics("AZC0015") + .RunAsync(); + } + + [Fact] + public async Task AZC0004NotProducedForMethodsWithCancellationToken() { const string code = @" using System.Threading; @@ -57,6 +79,116 @@ await Verifier.CreateAnalyzer(code) .RunAsync(); } + [Fact] + public async Task AZC0004NotProducedForMethodsWithOptionalRequestContext() + { + const string code = @" +using Azure; +using System.Threading; +using System.Threading.Tasks; + +namespace RandomNamespace +{ + public class SomeClient + { + public virtual Task GetAsync(RequestContext context = null) + { + return null; + } + public virtual Task Get(RequestContext context = null) + { + return null; + } + } +}"; + await Verifier.CreateAnalyzer(code) + .WithDisabledDiagnostics("AZC0015") + .WithDisabledDiagnostics("AZC0018") + .RunAsync(); + } + + [Fact] + public async Task AZC0004NotProducedForMethodsWithRequiredRequestContext() + { + const string code = @" +using Azure; +using System.Threading; +using System.Threading.Tasks; + +namespace RandomNamespace +{ + public class SomeClient + { + public virtual Task GetAsync(RequestContext context) + { + return null; + } + public virtual Response Get(RequestContext context) + { + return null; + } + } +}"; + await Verifier.CreateAnalyzer(code) + .WithDisabledDiagnostics("AZC0015") + .RunAsync(); + } + + [Fact] + public async Task AZC0004ProducedForMethodsNotMatch() + { + const string code = @" +using Azure; +using System.Threading; +using System.Threading.Tasks; + +namespace RandomNamespace +{ + public class SomeClient + { + public virtual Task {|AZC0004:GetAsync|}(string a, CancellationToken cancellationToken = default) + { + return null; + } + public virtual Task {|AZC0004:Get|}(string a, RequestContext context) + { + return null; + } + } +}"; + await Verifier.CreateAnalyzer(code) + .WithDisabledDiagnostics("AZC0015") + .RunAsync(); + } + + [Fact] + public async Task AZC0004ProducedForMethodsWithNotMatchedRequestContext() + { + const string code = @" +using Azure; +using System.Threading; +using System.Threading.Tasks; + +namespace RandomNamespace +{ + public class SomeClient + { + public virtual Task {|AZC0004:GetAsync|}(RequestContext context = null) + { + return null; + } + public virtual Response {|AZC0004:Get|}(RequestContext context) + { + return null; + } + } +}"; + await Verifier.CreateAnalyzer(code) + .WithDisabledDiagnostics("AZC0015") + .WithDisabledDiagnostics("AZC0018") + .RunAsync(); + } + [Fact] public async Task AZC0004ProducedForGenericMethodsWithSyncAlternative() { @@ -207,7 +339,7 @@ public class SomeClient return null; } - public virtual void Get(string sameNameDifferentType, CancellationToken cancellationToken = default) + public virtual void {|AZC0004:Get|}(string sameNameDifferentType, CancellationToken cancellationToken = default) { } } @@ -263,7 +395,7 @@ public class SomeClient return null; } - public virtual void Query(Expression> filter, CancellationToken cancellationToken = default) + public virtual void {|AZC0004:Query|}(Expression> filter, CancellationToken cancellationToken = default) { } } @@ -327,7 +459,7 @@ public class SomeClient } - public virtual void Append( + public virtual void {|AZC0004:Append|}( string[] arr, CancellationToken cancellationToken = default) { @@ -355,7 +487,7 @@ public class SomeClient return null; } - public virtual void Get(int differentName, CancellationToken cancellationToken = default) + public virtual void {|AZC0004:Get|}(int differentName, CancellationToken cancellationToken = default) { } } @@ -365,4 +497,4 @@ await Verifier.CreateAnalyzer(code) .RunAsync(); } } -} \ No newline at end of file +} diff --git a/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers.Tests/AZC0017Tests.cs b/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers.Tests/AZC0017Tests.cs new file mode 100644 index 00000000000..41edbc10aee --- /dev/null +++ b/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers.Tests/AZC0017Tests.cs @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Threading.Tasks; +using Xunit; +using Verifier = Azure.ClientSdk.Analyzers.Tests.AzureAnalyzerVerifier; + +namespace Azure.ClientSdk.Analyzers.Tests +{ + public class AZC0017Tests + { + [Fact] + public async Task AZC0017ProducedForMethodsWithRequestContentParameter() + { + const string code = @" +using Azure.Core; +using System.Threading; +using System.Threading.Tasks; + +namespace RandomNamespace +{ + public class SomeClient + { + public virtual Task {|AZC0017:GetAsync|}(RequestContent content, CancellationToken cancellationToken = default) + { + return null; + } + + public virtual void {|AZC0017:Get|}(RequestContent content, CancellationToken cancellationToken = default) + { + } + } +}"; + await Verifier.CreateAnalyzer(code) + .WithDisabledDiagnostics("AZC0015") + .RunAsync(); + } + + [Fact] + public async Task AZC0017NotProducedForMethodsWithCancellationToken() + { + const string code = @" +using Azure.Core; +using System.Threading; +using System.Threading.Tasks; + +namespace RandomNamespace +{ + public class SomeClient + { + public virtual Task GetAsync(string s, CancellationToken cancellationToken = default) + { + return null; + } + + public virtual void Get(string s, CancellationToken cancellationToken = default) + { + } + } +}"; + await Verifier.CreateAnalyzer(code) + .WithDisabledDiagnostics("AZC0015") + .RunAsync(); + } + } +} diff --git a/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers.Tests/AZC0018Tests.cs b/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers.Tests/AZC0018Tests.cs new file mode 100644 index 00000000000..cc55370a9fc --- /dev/null +++ b/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers.Tests/AZC0018Tests.cs @@ -0,0 +1,433 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Threading.Tasks; +using Xunit; +using Verifier = Azure.ClientSdk.Analyzers.Tests.AzureAnalyzerVerifier; + +namespace Azure.ClientSdk.Analyzers.Tests +{ + public class AZC0018Tests + { + [Fact] + public async Task AZC0018NotProducedForCorrectReturnType() + { + const string code = @" +using Azure; +using Azure.Core; +using System; +using System.Threading; +using System.Threading.Tasks; + +namespace RandomNamespace +{ + public class SomeClient + { + public virtual Task GetResponseAsync(string s, RequestContext context) + { + return null; + } + + public virtual Response GetResponse(string s, RequestContext context) + { + return null; + } + + public virtual AsyncPageable GetPageableAsync(string s, RequestContext context) + { + return null; + } + + public virtual Pageable GetPageable(string s, RequestContext context) + { + return null; + } + + public virtual Task GetOperationAsync(string s, RequestContext context) + { + return null; + } + + public virtual Operation GetOperation(string s, RequestContext context) + { + return null; + } + + public virtual Task> GetOperationOfTAsync(string s, RequestContext context) + { + return null; + } + + public virtual Operation GetOperationOfT(string s, RequestContext context) + { + return null; + } + } +}"; + await Verifier.CreateAnalyzer(code) + .RunAsync(); + } + + [Fact] + public async Task AZC0018ProducedForMethodsWithGenericResponseOfPrimitive() + { + const string code = @" +using Azure; +using Azure.Core; +using System.Threading; +using System.Threading.Tasks; + +namespace RandomNamespace +{ + public class SomeClient + { + public virtual Task> {|AZC0018:GetAsync|}(string s, RequestContext context) + { + return null; + } + + public virtual Response {|AZC0018:Get|}(string s, RequestContext context) + { + return null; + } + } +}"; + await Verifier.CreateAnalyzer(code) + .RunAsync(); + } + + [Fact] + public async Task AZC0018ProducedForMethodsWithGenericResponseOfModel() + { + const string code = @" +using Azure; +using Azure.Core; +using System.Threading; +using System.Threading.Tasks; + +namespace RandomNamespace +{ + public class Model + { + string a; + } + public class SomeClient + { + public virtual Task> {|AZC0018:GetAsync|}(string s, RequestContext context) + { + return null; + } + + public virtual Response {|AZC0018:Get|}(string s, RequestContext context) + { + return null; + } + } +}"; + await Verifier.CreateAnalyzer(code) + .RunAsync(); + } + + [Fact] + public async Task AZC0018ProducedForMethodsWithPageableOfModel() + { + const string code = @" +using Azure; +using Azure.Core; +using System.Threading; +using System.Threading.Tasks; + +namespace RandomNamespace +{ + public class Model + { + string a; + } + public class SomeClient + { + public virtual AsyncPageable {|AZC0018:GetAsync|}(string s, RequestContext context) + { + return null; + } + + public virtual Pageable {|AZC0018:Get|}(string s, RequestContext context) + { + return null; + } + } +}"; + await Verifier.CreateAnalyzer(code) + .RunAsync(); + } + + [Fact] + public async Task AZC0018ProducedForMethodsWithOperationOfModel() + { + const string code = @" +using Azure; +using Azure.Core; +using System.Threading; +using System.Threading.Tasks; + +namespace RandomNamespace +{ + public class Model + { + string a; + } + public class SomeClient + { + public virtual Task> {|AZC0018:GetAsync|}(string s, RequestContext context) + { + return null; + } + + public virtual Operation {|AZC0018:Get|}(string s, RequestContext context) + { + return null; + } + } +}"; + await Verifier.CreateAnalyzer(code) + .RunAsync(); + } + + [Fact] + public async Task AZC0018ProducedForMethodsWithParameterModel() + { + const string code = @" +using Azure; +using Azure.Core; +using System.Threading; +using System.Threading.Tasks; +using System.Collections.Generic; + +namespace RandomNamespace +{ + public struct Model + { + string a; + } + public class SomeClient + { + public virtual Task {|AZC0018:GetAsync|}(Model model, Azure.RequestContext context) + { + return null; + } + + public virtual Response {|AZC0018:Get|}(Model model, Azure.RequestContext context) + { + return null; + } + } +}"; + await Verifier.CreateAnalyzer(code) + .RunAsync(); + } + + [Fact] + public async Task AZC0018ProducedForMethodsWithNoRequestContentAndOptionalRequestContext() + { + const string code = @" +using Azure; +using Azure.Core; +using System.Threading; +using System.Threading.Tasks; +using System.Collections.Generic; + +namespace RandomNamespace +{ + public class SomeClient + { + public virtual Task {|AZC0018:GetAsync|}(string a, Azure.RequestContext context = null) + { + return null; + } + + public virtual Response {|AZC0018:Get|}(string a, Azure.RequestContext context = null) + { + return null; + } + } +}"; + await Verifier.CreateAnalyzer(code) + .RunAsync(); + } + + [Fact] + public async Task AZC0018ProducedForMethodsWithRequiredRequestContentAndRequiredRequestContext() + { + const string code = @" +using Azure; +using Azure.Core; +using System.Threading; +using System.Threading.Tasks; +using System.Collections.Generic; + +namespace Azure.Core +{ + internal static partial class Argument + { + public static void AssertNotNull(T value, string name) + { + if (value is null) + { + throw new System.ArgumentNullException(name); + } + } + } +} + +namespace RandomNamespace +{ + public class SomeClient + { + public virtual Task {|AZC0018:GetAsync|}(RequestContent content, RequestContext context) + { + Argument.AssertNotNull(content, nameof(content)); + return null; + } + + public virtual Response {|AZC0018:Get|}(RequestContent content, RequestContext context) + { + Argument.AssertNotNull(content, nameof(content)); + return null; + } + } +}"; + await Verifier.CreateAnalyzer(code) + .RunAsync(); + } + + [Fact] + public async Task AZC0018ProducedForMethodsWithOptionalRequestContentAndOptionalRequestContext() + { + const string code = @" +using Azure; +using Azure.Core; +using System.Threading; +using System.Threading.Tasks; +using System.Collections.Generic; + +namespace Azure.Core +{ + internal static partial class Argument + { + public static void AssertNotNull(T value, string name) + { + if (value is null) + { + throw new System.ArgumentNullException(name); + } + } + } +} + +namespace RandomNamespace +{ + public class SomeClient + { + public virtual Task {|AZC0018:GetAsync|}(RequestContent content, RequestContext context = null) + { + return null; + } + + public virtual Response {|AZC0018:Get|}(RequestContent content, RequestContext context = null) + { + return null; + } + } +}"; + await Verifier.CreateAnalyzer(code) + .RunAsync(); + } + + [Fact] + public async Task AZC0018NotProducedForMethodsWithOptionalRequestContentAndRequiredRequestContext() + { + const string code = @" +using Azure; +using Azure.Core; +using System.Threading; +using System.Threading.Tasks; +using System.Collections.Generic; + +namespace Azure.Core +{ + internal static partial class Argument + { + public static void AssertNotNull(T value, string name) + { + if (value is null) + { + throw new System.ArgumentNullException(name); + } + } + } +} + +namespace RandomNamespace +{ + public class SomeClient + { + public virtual Task GetAsync(RequestContent content, RequestContext context) + { + return null; + } + + public virtual Response Get(RequestContent content, RequestContext context) + { + return null; + } + } +}"; + await Verifier.CreateAnalyzer(code) + .RunAsync(); + } + + [Fact] + public async Task AZC0018NotProducedForMethodsWithRequiredRequestContentAndOptionalRequestContext() + { + const string code = @" +using Azure; +using Azure.Core; +using System.Threading; +using System.Threading.Tasks; +using System.Collections.Generic; + +namespace Azure.Core +{ + internal static partial class Argument + { + public static void AssertNotNull(T value, string name) + { + if (value is null) + { + throw new System.ArgumentNullException(name); + } + } + } +} + +namespace RandomNamespace +{ + public class SomeClient + { + public virtual Task GetAsync(RequestContent content, RequestContext context = null) + { + Argument.AssertNotNull(content, nameof(content)); + return null; + } + + public virtual Response Get(RequestContent content, RequestContext context = null) + { + Argument.AssertNotNull(content, nameof(content)); + return null; + } + } +}"; + await Verifier.CreateAnalyzer(code) + .RunAsync(); + } + } +} diff --git a/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers/ClientAnalyzerBase.cs b/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers/ClientAnalyzerBase.cs index 25b20d345fc..2e3480b9910 100644 --- a/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers/ClientAnalyzerBase.cs +++ b/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers/ClientAnalyzerBase.cs @@ -32,7 +32,7 @@ protected class ParameterEquivalenceComparer : IEqualityComparer methodSymbo public abstract void AnalyzeCore(ISymbolAnalysisContext context); } -} \ No newline at end of file +} diff --git a/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers/ClientMethodsAnalyzer.cs b/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers/ClientMethodsAnalyzer.cs index ab20919b072..db3ce3d066f 100644 --- a/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers/ClientMethodsAnalyzer.cs +++ b/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers/ClientMethodsAnalyzer.cs @@ -1,9 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +using System.Collections.Generic; using System.Collections.Immutable; using System.Linq; using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Diagnostics; namespace Azure.ClientSdk.Analyzers @@ -18,15 +20,107 @@ public class ClientMethodsAnalyzer : ClientAnalyzerBase Descriptors.AZC0002, Descriptors.AZC0003, Descriptors.AZC0004, - Descriptors.AZC0015 + Descriptors.AZC0015, + Descriptors.AZC0017, + Descriptors.AZC0018 }); + public override void Initialize(AnalysisContext context) + { + context.EnableConcurrentExecution(); + context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.Analyze | GeneratedCodeAnalysisFlags.ReportDiagnostics); + base.Initialize(context); + context.RegisterCodeBlockAction(c => AnalyzeCodeBlock(c)); + } + + private void AnalyzeCodeBlock(CodeBlockAnalysisContext codeBlock) + { + var symbol = codeBlock.OwningSymbol; + if (symbol is IMethodSymbol methodSymbol) + { + var lastParameter = methodSymbol.Parameters.LastOrDefault(); + if (lastParameter != null && IsRequestContext(lastParameter)) + { + var requestContent = methodSymbol.Parameters.FirstOrDefault(p => IsRequestContent(p)); + if (requestContent != null) + { + bool isRequired = ContainsAssertNotNull(codeBlock, requestContent.Name); + if (isRequired && !lastParameter.IsOptional) + { + codeBlock.ReportDiagnostic(Diagnostic.Create(Descriptors.AZC0018, symbol.Locations.FirstOrDefault())); + } + if (!isRequired && lastParameter.IsOptional) + { + codeBlock.ReportDiagnostic(Diagnostic.Create(Descriptors.AZC0018, symbol.Locations.FirstOrDefault())); + } + } + } + } + } + + private static bool ContainsAssertNotNull(CodeBlockAnalysisContext codeBlock, string variableName) + { + // Check Argument.AssertNotNull(variableName, nameof(variableName)); + foreach (var invocation in codeBlock.CodeBlock.DescendantNodes().OfType()) + { + if (invocation.Expression is MemberAccessExpressionSyntax assertNotNull && assertNotNull.Name.Identifier.Text == "AssertNotNull") + { + if (assertNotNull.Expression is IdentifierNameSyntax identifierName && identifierName.Identifier.Text == "Argument" || + assertNotNull.Expression is MemberAccessExpressionSyntax memberAccessExpression && memberAccessExpression.Name.Identifier.Text == "Argument") + { + var argumentsList = invocation.ArgumentList.Arguments; + if (argumentsList.Count != 2) + { + continue; + } + if (argumentsList.First().Expression is IdentifierNameSyntax first) + { + if (first.Identifier.Text != variableName) + { + continue; + } + if (argumentsList.Last().Expression is InvocationExpressionSyntax second) + { + if (second.Expression is IdentifierNameSyntax nameof && nameof.Identifier.Text == "nameof") + { + if (second.ArgumentList.Arguments.Count != 1) + { + continue; + } + if (second.ArgumentList.Arguments.First().Expression is IdentifierNameSyntax contentName && contentName.Identifier.Text == variableName) + { + return true; + } + } + } + } + } + } + } + + return false; + } + + private static bool IsRequestContent(IParameterSymbol parameterSymbol) + { + return parameterSymbol.Type.Name == "RequestContent"; + } + + private static bool IsRequestContext(IParameterSymbol parameterSymbol) + { + return parameterSymbol.Name == "context" && parameterSymbol.Type.Name == "RequestContext"; + } + private static void CheckClientMethod(ISymbolAnalysisContext context, IMethodSymbol member) { - static bool SupportsCancellationsParameter(IParameterSymbol parameterSymbol) + static bool IsCancellationOrRequestContext(IParameterSymbol parameterSymbol) + { + return IsCancellationToken(parameterSymbol) || IsRequestContext(parameterSymbol); + } + + static bool IsCancellationToken(IParameterSymbol parameterSymbol) { - return (parameterSymbol.Name == "cancellationToken" && parameterSymbol.Type.Name == "CancellationToken") || - (parameterSymbol.Name == "context" && parameterSymbol.Type.Name == "RequestContext"); + return parameterSymbol.Name == "cancellationToken" && parameterSymbol.Type.Name == "CancellationToken" && parameterSymbol.IsOptional; } CheckClientMethodReturnType(context, member); @@ -37,39 +131,123 @@ static bool SupportsCancellationsParameter(IParameterSymbol parameterSymbol) } var lastArgument = member.Parameters.LastOrDefault(); - var supportsCancellations = lastArgument != null && SupportsCancellationsParameter(lastArgument); + var isCancellationOrRequestContext = lastArgument != null && IsCancellationOrRequestContext(lastArgument); - if (!supportsCancellations) + if (!isCancellationOrRequestContext) { - var overloadSupportsCancellations = FindMethod( - member.ContainingType.GetMembers(member.Name).OfType(), - member.TypeParameters, - member.Parameters, - p => SupportsCancellationsParameter(p)); - - if (overloadSupportsCancellations != null) + context.ReportDiagnostic(Diagnostic.Create(Descriptors.AZC0002, member.Locations.FirstOrDefault()), member); + } + else if (IsCancellationToken(lastArgument)) + { + // A convenience method should not have RequestContent as parameter + var requestContent = member.Parameters.FirstOrDefault(parameter => parameter.Type.Name == "RequestContent"); + if (requestContent != null) { - // Skip methods that have overloads with cancellation tokens - return; + context.ReportDiagnostic(Diagnostic.Create(Descriptors.AZC0017, member.Locations.FirstOrDefault()), member); } + } + else if (IsRequestContext(lastArgument)) + { + CheckProtocolMethodReturnType(context, member); + CheckProtocolMethodParameters(context, member); + } + } - context.ReportDiagnostic(Diagnostic.Create(Descriptors.AZC0002, member.Locations.FirstOrDefault()), member); + private static string GetFullNamespaceName(IParameterSymbol parameter) + { + var currentNamespace = parameter.Type.ContainingNamespace; + string currentName = currentNamespace.Name; + string fullNamespace = ""; + while (!string.IsNullOrEmpty(currentName)) + { + fullNamespace = fullNamespace == "" ? currentName : $"{currentName}.{fullNamespace}"; + currentNamespace = currentNamespace.ContainingNamespace; + currentName = currentNamespace.Name; + } + return fullNamespace; + } + + // A protocol method should not have model as parameter. If it has ambiguity with convenience method, it should have required RequestContext. + // Ambiguity: no RequestContent or has optional RequestContent. + // No ambiguity: has required RequestContent. + private static void CheckProtocolMethodParameters(ISymbolAnalysisContext context, IMethodSymbol method) + { + var containsModel = method.Parameters.Any(p => + { + var fullNamespace = GetFullNamespaceName(p); + return !fullNamespace.StartsWith("System") && !fullNamespace.StartsWith("Azure"); + }); + + if (containsModel) + { + context.ReportDiagnostic(Diagnostic.Create(Descriptors.AZC0018, method.Locations.FirstOrDefault()), method); + return; } - else if (!lastArgument.IsOptional) + + var requestContent = method.Parameters.FirstOrDefault(p => p.Type.Name == "RequestContent"); + if (requestContent == null) { - var overloadWithCancellationToken = FindMethod( - member.ContainingType.GetMembers(member.Name).OfType(), - member.TypeParameters, - member.Parameters.RemoveAt(member.Parameters.Length - 1)); + if (method.Parameters.LastOrDefault().IsOptional) + { + context.ReportDiagnostic(Diagnostic.Create(Descriptors.AZC0018, method.Locations.FirstOrDefault()), method); + } + } + // Optional RequestContent or required RequestContent is checked in AnalyzeCodeBlock. + } - if (overloadWithCancellationToken != null) + // A protocol method should not have model as type. Accepted return type: Response, Task, Pageable, AsyncPageable, Operation, Task>, Operation, Task + private static void CheckProtocolMethodReturnType(ISymbolAnalysisContext context, IMethodSymbol method) + { + ITypeSymbol originalType = method.ReturnType; + ITypeSymbol unwrappedType = method.ReturnType; + + if (method.ReturnType is INamedTypeSymbol namedTypeSymbol && + namedTypeSymbol.IsGenericType && + namedTypeSymbol.Name == "Task") + { + unwrappedType = namedTypeSymbol.TypeArguments.Single(); + } + + if (unwrappedType.Name == "Response") + { + if (unwrappedType is INamedTypeSymbol responseTypeSymbol && responseTypeSymbol.IsGenericType) + { + context.ReportDiagnostic(Diagnostic.Create(Descriptors.AZC0018, method.Locations.FirstOrDefault()), method); + } + return; + } + else if (unwrappedType.Name == "Operation") + { + if (unwrappedType is INamedTypeSymbol operationTypeSymbol && operationTypeSymbol.IsGenericType) { - // Skip methods that have non-optional cancellation token if overload exists without one - return; + var operationReturn = operationTypeSymbol.TypeArguments.Single(); + if (operationReturn.Name != "BinaryData") + { + context.ReportDiagnostic(Diagnostic.Create(Descriptors.AZC0018, method.Locations.FirstOrDefault()), method); + } } + return; + } + else if (originalType.Name == "Pageable" || originalType.Name == "AsyncPageable") + { + if (originalType is INamedTypeSymbol pageableTypeSymbol) + { + if (!pageableTypeSymbol.IsGenericType) + { + context.ReportDiagnostic(Diagnostic.Create(Descriptors.AZC0018, method.Locations.FirstOrDefault()), method); + } - context.ReportDiagnostic(Diagnostic.Create(Descriptors.AZC0002, member.Locations.FirstOrDefault()), member); + var pageableReturn = pageableTypeSymbol.TypeArguments.Single(); + if (pageableReturn.Name != "BinaryData") + { + context.ReportDiagnostic(Diagnostic.Create(Descriptors.AZC0018, method.Locations.FirstOrDefault()), method); + } + } + + return; } + + context.ReportDiagnostic(Diagnostic.Create(Descriptors.AZC0018, method.Locations.FirstOrDefault()), method); } private static void CheckClientMethodReturnType(ISymbolAnalysisContext context, IMethodSymbol method) @@ -116,15 +294,16 @@ bool IsOrImplements(ITypeSymbol typeSymbol, string typeName) public override void AnalyzeCore(ISymbolAnalysisContext context) { INamedTypeSymbol type = (INamedTypeSymbol)context.Symbol; + List visitedSyncMember = new List(); foreach (var member in type.GetMembers()) { - if (member is IMethodSymbol methodSymbol && methodSymbol.Name.EndsWith(AsyncSuffix) && member.DeclaredAccessibility == Accessibility.Public) + if (member is IMethodSymbol asyncMethodSymbol && asyncMethodSymbol.Name.EndsWith(AsyncSuffix) && member.DeclaredAccessibility == Accessibility.Public) { - CheckClientMethod(context, methodSymbol); + CheckClientMethod(context, asyncMethodSymbol); var syncMemberName = member.Name.Substring(0, member.Name.Length - AsyncSuffix.Length); - var syncMember = FindMethod(type.GetMembers(syncMemberName).OfType(), methodSymbol.TypeParameters, methodSymbol.Parameters); + var syncMember = FindMethod(type.GetMembers(syncMemberName).OfType(), asyncMethodSymbol.TypeParameters, asyncMethodSymbol.Parameters); if (syncMember == null) { @@ -132,9 +311,19 @@ public override void AnalyzeCore(ISymbolAnalysisContext context) } else { + visitedSyncMember.Add(syncMember); CheckClientMethod(context, syncMember); } } + else if (member is IMethodSymbol syncMethodSymbol && !member.IsImplicitlyDeclared && !syncMethodSymbol.Name.EndsWith(AsyncSuffix) && member.DeclaredAccessibility == Accessibility.Public && !visitedSyncMember.Contains(syncMethodSymbol)) + { + var asyncMemberName = member.Name + AsyncSuffix; + var asyncMember = FindMethod(type.GetMembers(asyncMemberName).OfType(), syncMethodSymbol.TypeParameters, syncMethodSymbol.Parameters); + if (asyncMember == null) + { + context.ReportDiagnostic(Diagnostic.Create(Descriptors.AZC0004, member.Locations.First()), member); + } + } } } } diff --git a/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers/Descriptors.cs b/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers/Descriptors.cs index ebcd98325ba..42161acffca 100644 --- a/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers/Descriptors.cs +++ b/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers/Descriptors.cs @@ -16,8 +16,8 @@ internal class Descriptors public static DiagnosticDescriptor AZC0002 = new DiagnosticDescriptor( nameof(AZC0002), - "DO ensure all service methods, both asynchronous and synchronous, take an optional CancellationToken parameter called cancellationToken.", - "Client method should have cancellationToken as the last optional parameter (both name and it being optional matters)", + "DO ensure all service methods, both asynchronous and synchronous, take an optional CancellationToken parameter called cancellationToken or a RequestContext parameter called context.", + "Client method should have an optional CancellationToken (both name and it being optional matters) or a RequestContext as the last parameter.", "Usage", DiagnosticSeverity.Warning, isEnabledByDefault: true, description: null, "https://azure.github.io/azure-sdk/dotnet_introduction.html#dotnet-service-methods-cancellation" ); @@ -110,7 +110,19 @@ internal class Descriptors "Invalid ServiceVersion member name.", "All parts of ServiceVersion members' names must begin with a number or uppercase letter and cannot have consecutive underscores.", "Usage", - DiagnosticSeverity.Warning, true); + DiagnosticSeverity.Warning, true); + + public static DiagnosticDescriptor AZC0017 = new DiagnosticDescriptor( + nameof(AZC0017), + "Do ensure convenience method not take RequestContent as parameter type.", + "Convenience method shouldn't have prameters with type RequestContent.", + "Usage", DiagnosticSeverity.Warning, isEnabledByDefault: true, description: null); + + public static DiagnosticDescriptor AZC0018 = new DiagnosticDescriptor( + nameof(AZC0018), + "Do ensure protocol method take a RequestContext parameter called context and not take models as parameter type or return type.", + "Protocol method should have requestContext as the last parameter and don't have model as parameter type or return type.", + "Usage", DiagnosticSeverity.Warning, isEnabledByDefault: true, description: null); #endregion #region General diff --git a/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers/SymbolAnalyzerBase.cs b/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers/SymbolAnalyzerBase.cs index c3030310f61..3e5f7880953 100644 --- a/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers/SymbolAnalyzerBase.cs +++ b/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers/SymbolAnalyzerBase.cs @@ -16,7 +16,7 @@ public abstract class SymbolAnalyzerBase : DiagnosticAnalyzer protected INamedTypeSymbol ClientOptionsType { get; private set; } - public sealed override void Initialize(AnalysisContext context) + public override void Initialize(AnalysisContext context) { context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.Analyze | GeneratedCodeAnalysisFlags.ReportDiagnostics); context.EnableConcurrentExecution();