diff --git a/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers.Tests/AZC0004Tests.cs b/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers.Tests/AZC0004Tests.cs index d048a9e7f62..f95409213b4 100644 --- a/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers.Tests/AZC0004Tests.cs +++ b/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers.Tests/AZC0004Tests.cs @@ -396,5 +396,60 @@ public virtual Response Get(CancellationToken cancellationToken) await Verifier.CreateAnalyzer(code) .RunAsync(); } + + [Fact] + public async Task AZC0004ProducedForMethodsWithMismatchedReturnTypes() + { + const string code = @" +using System.Threading; +using System.Threading.Tasks; + +namespace RandomNamespace +{ + public class SomeClient + { + public virtual Task {|AZC0004:GetAsync|}(CancellationToken cancellationToken = default) + { + return null; + } + + public virtual int Get(CancellationToken cancellationToken = default) + { + return 0; + } + } +}"; + await Verifier.CreateAnalyzer(code) + .WithDisabledDiagnostics("AZC0015") + .RunAsync(); + } + + [Fact] + public async Task AZC0004ProducedForMethotsWithMismatchedParameters() + { + const string code = @" +using System.Threading; +using System.Threading.Tasks; + +namespace RandomNamespace +{ + public class SomeClient + { + public virtual Task {|AZC0004:GetAsync|}(CancellationToken cancellationToken = default) + { + return null; + } + + public virtual string Get(string foo, CancellationToken cancellationToken = default) + { + return null; + } + } +}"; + await Verifier.CreateAnalyzer(code) + .WithDisabledDiagnostics("AZC0015") + .RunAsync(); + } + } } diff --git a/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers/ClientMethodsAnalyzer.cs b/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers/ClientMethodsAnalyzer.cs index 8f51a40f6af..78e2674375c 100644 --- a/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers/ClientMethodsAnalyzer.cs +++ b/src/dotnet/Azure.ClientSdk.Analyzers/Azure.ClientSdk.Analyzers/ClientMethodsAnalyzer.cs @@ -279,6 +279,92 @@ private static bool IsClientMethodReturnType(ISymbolAnalysisContext context, IMe return false; } + private static bool CompareReturnTypesRecursively(ITypeSymbol asyncType, ITypeSymbol syncType) + { + // Unwrap Task for async methods + if (asyncType is INamedTypeSymbol asyncNamedType && asyncNamedType.Name == TaskTypeName) + { + if (asyncNamedType.IsGenericType) + { + asyncType = asyncNamedType.TypeArguments.Single(); + } + else + { + // async returns Task, sync should return void or non-generic Task + if (syncType.SpecialType == SpecialType.System_Void || + (syncType is INamedTypeSymbol syncNamedType && syncNamedType.Name == TaskTypeName && !syncNamedType.IsGenericType)) + { + return true; + } + return false; + } + } + + // Map AsyncPageable to Pageable for easy comparison since they are equivalent for these purposes + if (asyncType is INamedTypeSymbol asyncTypeSymbol && asyncTypeSymbol.Name == AsyncPageableTypeName && asyncTypeSymbol.IsGenericType) + { + asyncType = asyncTypeSymbol.ContainingNamespace.GetTypeMembers(PageableTypeName).FirstOrDefault()?.Construct(asyncTypeSymbol.TypeArguments.ToArray()); + } + + // Compare directly if sync method return type is not a named type symbol + if (syncType is not INamedTypeSymbol syncTypeSymbol) + { + return SymbolEqualityComparer.Default.Equals(asyncType, syncType); + } + + // Compare type names and namespaces + if (asyncType is INamedTypeSymbol asyncNamedTypeSymbol && syncType is INamedTypeSymbol syncNamedTypeSymbol) + { + if (asyncNamedTypeSymbol.Name != syncNamedTypeSymbol.Name || + asyncNamedTypeSymbol.ContainingNamespace.ToDisplayString() != syncNamedTypeSymbol.ContainingNamespace.ToDisplayString()) + { + return false; + } + + // Compare nested types recursively + if (asyncNamedTypeSymbol.IsGenericType && syncNamedTypeSymbol.IsGenericType) + { + var asyncTypeArguments = asyncNamedTypeSymbol.TypeArguments; + var syncTypeArguments = syncNamedTypeSymbol.TypeArguments; + + if (asyncTypeArguments.Length != syncTypeArguments.Length) + { + return false; + } + + for (int i = 0; i < asyncTypeArguments.Length; i++) + { + if (!CompareReturnTypesRecursively(asyncTypeArguments[i], syncTypeArguments[i])) + { + return false; + } + } + return true; + } + else if ((!asyncNamedTypeSymbol.IsGenericType) && (!syncNamedTypeSymbol.IsGenericType)) + { + return true; + } + else + { + // One is a generic type and the other is not + return false; + } + } + + return SymbolEqualityComparer.Default.Equals(asyncType, syncType); + } + + private static bool DoReturnTypesMatch(ITypeSymbol asyncReturnType, ITypeSymbol syncReturnType) + { + if (asyncReturnType == null || syncReturnType == null) + { + return false; + } + + return CompareReturnTypesRecursively(asyncReturnType, syncReturnType); + } + public override void AnalyzeCore(ISymbolAnalysisContext context) { INamedTypeSymbol type = (INamedTypeSymbol)context.Symbol; @@ -304,6 +390,16 @@ public override void AnalyzeCore(ISymbolAnalysisContext context) } else { + if (!DoReturnTypesMatch(methodSymbol.ReturnType, syncMember.ReturnType)) + { + context.ReportDiagnostic(Diagnostic.Create(Descriptors.AZC0004, member.Locations.First()), member); + } + + if (!methodSymbol.Parameters.SequenceEqual(syncMember.Parameters, ParameterEquivalenceComparer.Default)) + { + context.ReportDiagnostic(Diagnostic.Create(Descriptors.AZC0005, member.Locations.First()), member); + } + CheckClientMethod(context, syncMember); } }