Skip to content

Commit

Permalink
SemanticallyEqual should allow for different lambda parameter names (
Browse files Browse the repository at this point in the history
…#4494)

* Fix minor test bug

Caused a minor test attribution error.

* `SemanticallyEqual` should allow for different lambda parameter names

While comparing two lambdas for semantical equality, the lambda parameter names should be ignored. This actually also applies to other variables (e.g. variables declared in a block such as a lambda body), but that is out of scope for this PR.

This will allow Refaster recipes to match a lambda against other lambdas using different parameter names.
  • Loading branch information
knutwannheden authored Sep 15, 2024
1 parent 0b6914b commit 0535297
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,28 @@ class T {
);
}

@Test
void lambdaParameterNames() {
assertExpressionsEqual(
"""
import java.util.Comparator;
class T {
Comparator<Integer> a = (x1, y1) -> x1 - y1;
Comparator<Integer> b = (x2, y2) -> x2 - y2;
}
"""
);
assertExpressionsNotEqual(
"""
import java.util.Comparator;
class T {
Comparator<Integer> a = (x1, y1) -> x1 - y1;
Comparator<Integer> b = (x2, y2) -> y2 - x2;
}
"""
);
}

@Nested
class Generics {
@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ void classAnnotations() {
java(
"""
import javax.annotation.processing.Generated;
@SuppressWarnings("all")
public @Generated("foo") class T {}
""",
Expand Down Expand Up @@ -68,7 +68,7 @@ void annotatedType() {
import java.lang.annotation.*;
import static java.lang.annotation.ElementType.*;
class T {
public @A1 Integer @A2 [] arg;
}
Expand Down Expand Up @@ -159,9 +159,9 @@ void fieldAccessAnnotations() {
import java.lang.annotation.*;
import static java.lang.annotation.ElementType.*;
class T {
java. lang. @Ann Map arg;
java. lang. @Ann Integer arg;
}
@Retention(RetentionPolicy.RUNTIME)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,18 +73,22 @@ private static J[] createTemplateParameters(String code) {
substituted = propertyPlaceholderHelper.replacePlaceholders(substituted, key -> {
String s;
if (!key.isEmpty()) {
TemplateParameterParser parser = new TemplateParameterParser(new CommonTokenStream(new TemplateParameterLexer(
CharStreams.fromString(key))));

parser.removeErrorListeners();
parser.addErrorListener(new BaseErrorListener() {
BaseErrorListener errorListener = new BaseErrorListener() {
@Override
public void syntaxError(Recognizer<?, ?> recognizer, Object offendingSymbol,
int line, int charPositionInLine, String msg, RecognitionException e) {
throw new IllegalArgumentException(
String.format("Syntax error at line %d:%d %s.", line, charPositionInLine, msg), e);
}
});
};

TemplateParameterLexer lexer = new TemplateParameterLexer(CharStreams.fromString(key));
lexer.removeErrorListeners();
lexer.addErrorListener(errorListener);

TemplateParameterParser parser = new TemplateParameterParser(new CommonTokenStream(lexer));
parser.removeErrorListeners();
parser.addErrorListener(errorListener);

TemplateParameterParser.MatcherPatternContext ctx = parser.matcherPattern();
if (ctx.typedPattern() == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@
import org.openrewrite.java.JavaIsoVisitor;
import org.openrewrite.java.tree.*;

import java.util.EnumSet;
import java.util.List;
import java.util.Objects;
import java.util.*;
import java.util.concurrent.atomic.AtomicBoolean;

/**
Expand Down Expand Up @@ -59,6 +57,8 @@ protected static class SemanticallyEqualVisitor extends JavaIsoVisitor<J> {
private final boolean compareMethodArguments;

protected final AtomicBoolean isEqual = new AtomicBoolean(true);
private final Deque<Map<String, String>> variableScope = new ArrayDeque<>();
private final Set<JavaType> seen = new HashSet<>();

public SemanticallyEqualVisitor(boolean compareMethodArguments) {
this.compareMethodArguments = compareMethodArguments;
Expand Down Expand Up @@ -120,6 +120,29 @@ protected void visitList(@Nullable List<? extends J> list1, @Nullable List<? ext
return (J) tree;
}

@Override
public @Nullable J preVisit(J tree, J j) {
if (declaresVariableScope(tree)) {
variableScope.push(new HashMap<>());
}
return tree;
}

@Override
public @Nullable J postVisit(J tree, J j) {
if (declaresVariableScope(tree)) {
variableScope.pop();
}
return tree;
}

protected boolean declaresVariableScope(J tree) {
if (tree instanceof J.Lambda) {
return true;
}
return false;
}

@Override
public Expression visitExpression(Expression expression, J j) {
if (isEqual.get()) {
Expand Down Expand Up @@ -683,6 +706,16 @@ public J.Identifier visitIdentifier(J.Identifier identifier, J j) {
}

J.Identifier compareTo = (J.Identifier) j;
if (identifier.getFieldType() != null) {
Map<String, String> scope = variableScope.peek();
if (scope != null && scope.containsKey(identifier.getSimpleName()) && scope.get(identifier.getSimpleName()).equals(compareTo.getSimpleName())) {
return identifier;
}
}
if (TypeUtils.isWellFormedType(identifier.getType(), seen) && !TypeUtils.isOfType(identifier.getType(), compareTo.getType())) {
isEqual.set(false);
return identifier;
}
if (!identifier.getSimpleName().equals(compareTo.getSimpleName())) {
isEqual.set(false);
return identifier;
Expand Down Expand Up @@ -784,8 +817,8 @@ public J.Lambda visitLambda(J.Lambda lambda, J j) {
isEqual.set(false);
return lambda;
}
visitList(lambda.getParameters().getParameters(), compareTo.getParameters().getParameters());
visit(lambda.getBody(), compareTo.getBody());
this.visitList(lambda.getParameters().getParameters(), compareTo.getParameters().getParameters());
}
return lambda;
}
Expand Down Expand Up @@ -911,7 +944,7 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, J j)
JavaType.FullyQualified methodDeclaringType = method.getMethodType().getDeclaringType();
JavaType.FullyQualified compareToDeclaringType = compareTo.getMethodType().getDeclaringType();
if (!TypeUtils.isAssignableTo(methodDeclaringType instanceof JavaType.Parameterized ?
((JavaType.Parameterized) methodDeclaringType).getType() : methodDeclaringType,
((JavaType.Parameterized) methodDeclaringType).getType() : methodDeclaringType,
compareToDeclaringType instanceof JavaType.Parameterized ?
((JavaType.Parameterized) compareToDeclaringType).getType() : compareToDeclaringType)) {
isEqual.set(false);
Expand Down Expand Up @@ -1360,8 +1393,14 @@ public J.VariableDeclarations.NamedVariable visitVariable(J.VariableDeclarations
}

J.VariableDeclarations.NamedVariable compareTo = (J.VariableDeclarations.NamedVariable) j;
if (!variable.getSimpleName().equals(compareTo.getSimpleName()) ||
!TypeUtils.isOfType(variable.getType(), compareTo.getType()) ||
Map<String, String> scope = variableScope.peek();
if (scope != null) {
scope.put(variable.getSimpleName(), compareTo.getSimpleName());
} else if (!variable.getSimpleName().equals(compareTo.getSimpleName())) {
isEqual.set(false);
return variable;
}
if (!TypeUtils.isOfType(variable.getType(), compareTo.getType()) ||
nullMissMatch(variable.getInitializer(), compareTo.getInitializer())) {
isEqual.set(false);
return variable;
Expand Down

0 comments on commit 0535297

Please sign in to comment.