Skip to content

Commit

Permalink
Improve refactoring "Deconstruct foreach variable" (RR0217)
Browse files Browse the repository at this point in the history
  • Loading branch information
josefpihrt committed Mar 27, 2022
1 parent d44b21d commit db53df2
Show file tree
Hide file tree
Showing 2 changed files with 187 additions and 43 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Josef Pihrt and Contributors. Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Threading;
Expand All @@ -21,105 +22,136 @@ public static void ComputeRefactoring(
{
ITypeSymbol typeSymbol = semanticModel.GetTypeSymbol(forEachStatement.Type);

IMethodSymbol deconstructSymbol = typeSymbol.FindMember<IMethodSymbol>(
"Deconstruct",
symbol =>
{
if (symbol.DeclaredAccessibility == Accessibility.Public)
IEnumerable<ISymbol> parameters = null;

if (typeSymbol.IsTupleType)
{
var tupleType = (INamedTypeSymbol)typeSymbol;
parameters = tupleType.TupleElements;
}
else
{
IMethodSymbol deconstructSymbol = typeSymbol.FindMember<IMethodSymbol>(
"Deconstruct",
symbol =>
{
ImmutableArray<IParameterSymbol> parameters = symbol.Parameters;
if (symbol.DeclaredAccessibility == Accessibility.Public)
{
ImmutableArray<IParameterSymbol> parameters = symbol.Parameters;

return parameters.Any()
&& parameters.All(f => f.RefKind == RefKind.Out);
}
return parameters.Any()
&& parameters.All(f => f.RefKind == RefKind.Out);
}

return false;
});
return false;
});

if (deconstructSymbol is null)
return;
if (deconstructSymbol is null)
return;

parameters = deconstructSymbol.Parameters;
}

ISymbol foreachSymbol = semanticModel.GetDeclaredSymbol(forEachStatement, context.CancellationToken);

if (foreachSymbol?.IsKind(SymbolKind.Local) != true)
return;

var walker = new DeconstructForeachVariableWalker(
deconstructSymbol,
parameters,
foreachSymbol,
forEachStatement.Identifier.ValueText,
semanticModel,
context.CancellationToken);

walker.Visit(forEachStatement.Statement);

if (!walker.Success)
return;

context.RegisterRefactoring(
"Deconstruct foreach variable",
ct => RefactorAsync(context.Document, forEachStatement, deconstructSymbol, foreachSymbol, semanticModel, ct),
RefactoringDescriptors.DeconstructForeachVariable);
if (walker.Success)
{
context.RegisterRefactoring(
"Deconstruct foreach variable",
ct => RefactorAsync(context.Document, forEachStatement, parameters, foreachSymbol, semanticModel, ct),
RefactoringDescriptors.DeconstructForeachVariable);
}
}

private static async Task<Document> RefactorAsync(
Document document,
ForEachStatementSyntax forEachStatement,
IMethodSymbol deconstructSymbol,
IEnumerable<ISymbol> deconstructSymbols,
ISymbol identifierSymbol,
SemanticModel semanticModel,
CancellationToken cancellationToken)
{
int position = forEachStatement.SpanStart;
ITypeSymbol elementType = semanticModel.GetForEachStatementInfo(forEachStatement).ElementType;
SyntaxNode enclosingSymbolSyntax = semanticModel.GetEnclosingSymbolSyntax(position, cancellationToken);

ImmutableArray<ISymbol> declaredSymbols = semanticModel.GetDeclaredSymbols(enclosingSymbolSyntax, excludeAnonymousTypeProperty: true, cancellationToken);

ImmutableArray<ISymbol> symbols = declaredSymbols
.Concat(semanticModel.LookupSymbols(position))
.Distinct()
.Except(deconstructSymbols)
.ToImmutableArray();

Dictionary<string, string> newNames = deconstructSymbols
.Select(parameter =>
{
string name = StringUtility.FirstCharToLower(parameter.Name);
string newName = NameGenerator.Default.EnsureUniqueName(name, symbols);

return (name: parameter.Name, newName);
})
.ToDictionary(f => f.name, f => f.newName);

var rewriter = new DeconstructForeachVariableRewriter(identifierSymbol, newNames, semanticModel, cancellationToken);

var newStatement = (StatementSyntax)rewriter.Visit(forEachStatement.Statement);

DeclarationExpressionSyntax variableExpression = DeclarationExpression(
CSharpFactory.VarType().WithTriviaFrom(forEachStatement.Type),
ParenthesizedVariableDesignation(
deconstructSymbol.Parameters.Select(parameter =>
deconstructSymbols.Select(parameter =>
{
return (VariableDesignationSyntax)SingleVariableDesignation(
Identifier(
SyntaxTriviaList.Empty,
parameter.Name,
SyntaxTriviaList.Empty));
return SingleVariableDesignation(
Identifier(SyntaxTriviaList.Empty, newNames[parameter.Name], SyntaxTriviaList.Empty));
})
.ToSeparatedSyntaxList())
.ToSeparatedSyntaxList<VariableDesignationSyntax>())
.WithTriviaFrom(forEachStatement.Identifier))
.WithFormatterAnnotation();

var rewriter = new DeconstructForeachVariableRewriter(identifierSymbol, semanticModel, cancellationToken);

var newStatement = (StatementSyntax)rewriter.Visit(forEachStatement.Statement);

ForEachVariableStatementSyntax newForEachStatement = ForEachVariableStatement(
ForEachVariableStatementSyntax forEachVariableStatement = ForEachVariableStatement(
forEachStatement.AttributeLists,
forEachStatement.AwaitKeyword,
forEachStatement.ForEachKeyword,
forEachStatement.OpenParenToken,
variableExpression.WithFormatterAnnotation(),
variableExpression,
forEachStatement.InKeyword,
forEachStatement.Expression,
forEachStatement.CloseParenToken,
newStatement);

return await document.ReplaceNodeAsync(forEachStatement, newForEachStatement, cancellationToken).ConfigureAwait(false);
return await document.ReplaceNodeAsync(forEachStatement, forEachVariableStatement, cancellationToken).ConfigureAwait(false);
}

private class DeconstructForeachVariableWalker : CSharpSyntaxWalker
{
public DeconstructForeachVariableWalker(
IMethodSymbol deconstructMethod,
IEnumerable<ISymbol> parameters,
ISymbol identifierSymbol,
string identifier,
SemanticModel semanticModel,
CancellationToken cancellationToken)
{
DeconstructMethod = deconstructMethod;
Parameters = parameters;
IdentifierSymbol = identifierSymbol;
Identifier = identifier;
SemanticModel = semanticModel;
CancellationToken = cancellationToken;
}

public IMethodSymbol DeconstructMethod { get; }
public IEnumerable<ISymbol> Parameters { get; }

public ISymbol IdentifierSymbol { get; }

Expand Down Expand Up @@ -155,7 +187,7 @@ bool IsFixable(IdentifierNameSyntax node)
var memberAccess = (MemberAccessExpressionSyntax)node.Parent;
if (object.ReferenceEquals(memberAccess.Expression, node))
{
foreach (IParameterSymbol parameter in DeconstructMethod.Parameters)
foreach (ISymbol parameter in Parameters)
{
if (string.Equals(parameter.Name, memberAccess.Name.Identifier.ValueText, StringComparison.OrdinalIgnoreCase))
return true;
Expand All @@ -172,16 +204,20 @@ private class DeconstructForeachVariableRewriter : CSharpSyntaxRewriter
{
public DeconstructForeachVariableRewriter(
ISymbol identifierSymbol,
Dictionary<string, string> names,
SemanticModel semanticModel,
CancellationToken cancellationToken)
{
IdentifierSymbol = identifierSymbol;
Names = names;
SemanticModel = semanticModel;
CancellationToken = cancellationToken;
}

public ISymbol IdentifierSymbol { get; }

public Dictionary<string, string> Names { get; }

public SemanticModel SemanticModel { get; }

public CancellationToken CancellationToken { get; }
Expand All @@ -193,8 +229,12 @@ public override SyntaxNode VisitMemberAccessExpression(MemberAccessExpressionSyn
&& identifierName.Identifier.ValueText == IdentifierSymbol.Name
&& SymbolEqualityComparer.Default.Equals(SemanticModel.GetSymbol(identifierName, CancellationToken), IdentifierSymbol))
{
return IdentifierName(StringUtility.FirstCharToLower(node.Name.Identifier.ValueText))
.WithTriviaFrom(identifierName);
string name = node.Name.Identifier.ValueText;

if (!Names.TryGetValue(name, out string newName))
newName = StringUtility.FirstCharToLower(name);

return IdentifierName(newName).WithTriviaFrom(identifierName);
}

return base.VisitMemberAccessExpression(node);
Expand Down
106 changes: 105 additions & 1 deletion src/Tests/Refactorings.Tests/RR0217DeconstructForeachVariableTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ public class RR0217DeconstructForeachVariableTests : AbstractCSharpRefactoringVe
public override string RefactoringId { get; } = RefactoringIdentifiers.DeconstructForeachVariable;

[Fact, Trait(Traits.Refactoring, RefactoringIdentifiers.DeconstructForeachVariable)]
public async Task Test_EmptyObjectInitializer()
public async Task Test_Dictionary()
{
await VerifyRefactoringAsync(@"
using System.Collections.Generic;
Expand Down Expand Up @@ -45,6 +45,110 @@ void M()
}
}
}
", equivalenceKey: EquivalenceKey.Create(RefactoringId));
}

[Fact, Trait(Traits.Refactoring, RefactoringIdentifiers.DeconstructForeachVariable)]
public async Task Test_Dictionary_TopLevelStatement()
{
await VerifyRefactoringAsync(@"
using System.Collections.Generic;
var dic = new Dictionary<object, object>();
foreach ([||]var kvp in dic)
{
var k = kvp.Key;
var v = kvp.Value.ToString();
}
", @"
using System.Collections.Generic;
var dic = new Dictionary<object, object>();
foreach (var (key, value) in dic)
{
var k = key;
var v = value.ToString();
}
", equivalenceKey: EquivalenceKey.Create(RefactoringId));
}

[Fact, Trait(Traits.Refactoring, RefactoringIdentifiers.DeconstructForeachVariable)]
public async Task Test_Tuple()
{
await VerifyRefactoringAsync(@"
using System.Collections.Generic;
class C
{
void M()
{
var items = new List<(object, string)>();
foreach ([||]var item in items)
{
var k = item.Item1;
var v = item.Item2.ToString();
}
}
}
", @"
using System.Collections.Generic;
class C
{
void M()
{
var items = new List<(object, string)>();
foreach (var (item1, item2) in items)
{
var k = item1;
var v = item2.ToString();
}
}
}
", equivalenceKey: EquivalenceKey.Create(RefactoringId));
}

[Fact, Trait(Traits.Refactoring, RefactoringIdentifiers.DeconstructForeachVariable)]
public async Task Test_TupleWithNamedFields()
{
await VerifyRefactoringAsync(@"
using System.Collections.Generic;
class C
{
void M()
{
var p1 = false;
var items = new List<(object p1, string p2)>();
foreach ([||]var item in items)
{
var k = item.p1;
var v = item.p2.ToString();
}
}
}
", @"
using System.Collections.Generic;
class C
{
void M()
{
var p1 = false;
var items = new List<(object p1, string p2)>();
foreach (var (p12, p2) in items)
{
var k = p12;
var v = p2.ToString();
}
}
}
", equivalenceKey: EquivalenceKey.Create(RefactoringId));
}
}
Expand Down

0 comments on commit db53df2

Please sign in to comment.