Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AZC0004 ensure that method signatures match #9388

Original file line number Diff line number Diff line change
Expand Up @@ -396,5 +396,60 @@ public virtual Response Get(CancellationToken cancellationToken)
await Verifier.CreateAnalyzer(code)
.RunAsync();
}

[Fact]
public async Task AZC0004ProducedForMethodsWithMismatchedReturnTypes()
{
const string code = @"
using System.Threading;
using System.Threading.Tasks;

namespace RandomNamespace
{
public class SomeClient
{
public virtual Task<string> {|AZC0004:GetAsync|}(CancellationToken cancellationToken = default)
{
return null;
}

public virtual int Get(CancellationToken cancellationToken = default)
{
return 0;
}
}
}";
await Verifier.CreateAnalyzer(code)
.WithDisabledDiagnostics("AZC0015")
.RunAsync();
}

[Fact]
public async Task AZC0004ProducedForMethotsWithMismatchedParameters()
{
const string code = @"
using System.Threading;
using System.Threading.Tasks;

namespace RandomNamespace
{
public class SomeClient
{
public virtual Task<string> {|AZC0004:GetAsync|}(CancellationToken cancellationToken = default)
{
return null;
}

public virtual string Get(string foo, CancellationToken cancellationToken = default)
{
return null;
}
}
}";
await Verifier.CreateAnalyzer(code)
.WithDisabledDiagnostics("AZC0015")
.RunAsync();
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,92 @@ private static bool IsClientMethodReturnType(ISymbolAnalysisContext context, IMe
return false;
}

private static bool CompareReturnTypesRecursively(ITypeSymbol asyncType, ITypeSymbol syncType)
{
// Unwrap Task<T> for async methods
if (asyncType is INamedTypeSymbol asyncNamedType && asyncNamedType.Name == TaskTypeName)
{
if (asyncNamedType.IsGenericType)
{
asyncType = asyncNamedType.TypeArguments.Single();
}
else
{
// async returns Task, sync should return void or non-generic Task
if (syncType.SpecialType == SpecialType.System_Void ||
(syncType is INamedTypeSymbol syncNamedType && syncNamedType.Name == TaskTypeName && !syncNamedType.IsGenericType))
{
return true;
}
return false;
}
}

// Map AsyncPageable<T> to Pageable<T> for easy comparison since they are equivalent for these purposes
if (asyncType is INamedTypeSymbol asyncTypeSymbol && asyncTypeSymbol.Name == AsyncPageableTypeName && asyncTypeSymbol.IsGenericType)
{
asyncType = asyncTypeSymbol.ContainingNamespace.GetTypeMembers(PageableTypeName).FirstOrDefault()?.Construct(asyncTypeSymbol.TypeArguments.ToArray());
}

// Compare directly if sync method return type is not a named type symbol
if (syncType is not INamedTypeSymbol syncTypeSymbol)
{
return SymbolEqualityComparer.Default.Equals(asyncType, syncType);
}

// Compare type names and namespaces
if (asyncType is INamedTypeSymbol asyncNamedTypeSymbol && syncType is INamedTypeSymbol syncNamedTypeSymbol)
{
if (asyncNamedTypeSymbol.Name != syncNamedTypeSymbol.Name ||
asyncNamedTypeSymbol.ContainingNamespace.ToDisplayString() != syncNamedTypeSymbol.ContainingNamespace.ToDisplayString())
{
return false;
}

// Compare nested types recursively
if (asyncNamedTypeSymbol.IsGenericType && syncNamedTypeSymbol.IsGenericType)
{
var asyncTypeArguments = asyncNamedTypeSymbol.TypeArguments;
var syncTypeArguments = syncNamedTypeSymbol.TypeArguments;

if (asyncTypeArguments.Length != syncTypeArguments.Length)
{
return false;
}

for (int i = 0; i < asyncTypeArguments.Length; i++)
{
if (!CompareReturnTypesRecursively(asyncTypeArguments[i], syncTypeArguments[i]))
{
return false;
}
}
return true;
}
else if ((!asyncNamedTypeSymbol.IsGenericType) && (!syncNamedTypeSymbol.IsGenericType))
{
return true;
}
else
{
// One is a generic type and the other is not
return false;
}
}

return SymbolEqualityComparer.Default.Equals(asyncType, syncType);
}

private static bool DoReturnTypesMatch(ITypeSymbol asyncReturnType, ITypeSymbol syncReturnType)
{
if (asyncReturnType == null || syncReturnType == null)
{
return false;
}

return CompareReturnTypesRecursively(asyncReturnType, syncReturnType);
}

public override void AnalyzeCore(ISymbolAnalysisContext context)
{
INamedTypeSymbol type = (INamedTypeSymbol)context.Symbol;
Expand All @@ -304,6 +390,16 @@ public override void AnalyzeCore(ISymbolAnalysisContext context)
}
else
{
if (!DoReturnTypesMatch(methodSymbol.ReturnType, syncMember.ReturnType))
{
context.ReportDiagnostic(Diagnostic.Create(Descriptors.AZC0004, member.Locations.First()), member);
}

if (!methodSymbol.Parameters.SequenceEqual(syncMember.Parameters, ParameterEquivalenceComparer.Default))
{
context.ReportDiagnostic(Diagnostic.Create(Descriptors.AZC0005, member.Locations.First()), member);
}

CheckClientMethod(context, syncMember);
}
}
Expand Down
Loading