Skip to content

Commit

Permalink
EqualsAvoidsNull should flip arguments for constants (#398)
Browse files Browse the repository at this point in the history
* replaceMethodArgs

* replaceMethodArgs

* replaceMethodArgs

* replaceMethodArgs

* replaceMethodArg

* undo

* undo

* add String foo, String bar

* multiple

* add replaceMethodArg

* Also place field accesses first

* Check flags on fieldAccess.name.fieldType

* Add test showing no change when not static & final

* Also support static imports

* Remove unused import

---------

Co-authored-by: Vincent Potucek <[email protected]>
Co-authored-by: Tim te Beek <[email protected]>
  • Loading branch information
3 people authored Nov 29, 2024
1 parent 0b00944 commit 0d8d873
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.jspecify.annotations.Nullable;
import org.openrewrite.*;
import org.openrewrite.java.JavaIsoVisitor;
import org.openrewrite.java.search.UsesMethod;
import org.openrewrite.java.style.Checkstyle;
import org.openrewrite.java.style.EqualsAvoidsNullStyle;
import org.openrewrite.java.tree.J;
Expand Down Expand Up @@ -53,7 +54,7 @@ public Duration getEstimatedEffortPerOccurrence() {

@Override
public TreeVisitor<?, ExecutionContext> getVisitor() {
return new JavaIsoVisitor<ExecutionContext>() {
JavaIsoVisitor<ExecutionContext> replacementVisitor = new JavaIsoVisitor<ExecutionContext>() {
@Override
public J visit(@Nullable Tree tree, ExecutionContext ctx) {
if (tree instanceof JavaSourceFile) {
Expand All @@ -68,5 +69,12 @@ public J visit(@Nullable Tree tree, ExecutionContext ctx) {
return (J) tree;
}
};
return Preconditions.check(
Preconditions.or(
new UsesMethod<>("java.lang.String equals*(..)"),
new UsesMethod<>("java.lang.String co*(..)")
),
replacementVisitor
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,37 @@ public class EqualsAvoidsNullVisitor<P> extends JavaVisitor<P> {
@Override
public J visitMethodInvocation(J.MethodInvocation method, P p) {
J.MethodInvocation m = (J.MethodInvocation) super.visitMethodInvocation(method, p);
if (m.getSelect() != null &&
!(m.getSelect() instanceof J.Literal) &&
!m.getArguments().isEmpty() &&
m.getArguments().get(0) instanceof J.Literal &&
isStringComparisonMethod(m)) {
return literalsFirstInComparisonsBinaryCheck(m, getCursor().getParentTreeCursor().getValue());
if (m.getSelect() != null && !(m.getSelect() instanceof J.Literal) &&
isStringComparisonMethod(m) && hasCompatibleArgument(m)) {

maybeHandleParentBinary(m);

Expression firstArgument = m.getArguments().get(0);
return firstArgument.getType() == JavaType.Primitive.Null ?
literalsFirstInComparisonsNull(m, firstArgument) :
literalsFirstInComparisons(m, firstArgument);
}
return m;
}

private boolean hasCompatibleArgument(J.MethodInvocation m) {
if (m.getArguments().isEmpty()) {
return false;
}
Expression firstArgument = m.getArguments().get(0);
if (firstArgument instanceof J.Literal) {
return true;
}
if (firstArgument instanceof J.FieldAccess) {
firstArgument = ((J.FieldAccess) firstArgument).getName();
}
if (firstArgument instanceof J.Identifier) {
JavaType.Variable fieldType = ((J.Identifier) firstArgument).getFieldType();
return fieldType != null && fieldType.hasFlags(Flag.Static, Flag.Final);
}
return false;
}

private boolean isStringComparisonMethod(J.MethodInvocation methodInvocation) {
return EQUALS.matches(methodInvocation) ||
!style.getIgnoreEqualsIgnoreCase() &&
Expand All @@ -76,17 +97,26 @@ private boolean isStringComparisonMethod(J.MethodInvocation methodInvocation) {
CONTENT_EQUALS.matches(methodInvocation);
}

private Expression literalsFirstInComparisonsBinaryCheck(J.MethodInvocation m, P parent) {
private void maybeHandleParentBinary(J.MethodInvocation m) {
P parent = getCursor().getParentTreeCursor().getValue();
if (parent instanceof J.Binary) {
handleBinaryExpression(m, (J.Binary) parent);
if (((J.Binary) parent).getOperator() == J.Binary.Type.And && ((J.Binary) parent).getLeft() instanceof J.Binary) {
J.Binary potentialNullCheck = (J.Binary) ((J.Binary) parent).getLeft();
if (isNullLiteral(potentialNullCheck.getLeft()) && matchesSelect(potentialNullCheck.getRight(), requireNonNull(m.getSelect())) ||
isNullLiteral(potentialNullCheck.getRight()) && matchesSelect(potentialNullCheck.getLeft(), requireNonNull(m.getSelect()))) {
doAfterVisit(new RemoveUnnecessaryNullCheck<>((J.Binary) parent));
}
}
}
return getExpression(m, m.getArguments().get(0));
}

private static Expression getExpression(J.MethodInvocation m, Expression firstArgument) {
return firstArgument.getType() == JavaType.Primitive.Null ?
literalsFirstInComparisonsNull(m, firstArgument) :
literalsFirstInComparisons(m, firstArgument);
private boolean isNullLiteral(Expression expression) {
return expression instanceof J.Literal && ((J.Literal) expression).getType() == JavaType.Primitive.Null;
}

private boolean matchesSelect(Expression expression, Expression select) {
return expression.printTrimmed(getCursor()).replaceAll("\\s", "")
.equals(select.printTrimmed(getCursor()).replaceAll("\\s", ""));
}

private static J.Binary literalsFirstInComparisonsNull(J.MethodInvocation m, Expression firstArgument) {
Expand All @@ -104,25 +134,6 @@ private static J.MethodInvocation literalsFirstInComparisons(J.MethodInvocation
.withArguments(singletonList(m.getSelect().withPrefix(Space.EMPTY)));
}

private void handleBinaryExpression(J.MethodInvocation m, J.Binary binary) {
if (binary.getOperator() == J.Binary.Type.And && binary.getLeft() instanceof J.Binary) {
J.Binary potentialNullCheck = (J.Binary) binary.getLeft();
if (isNullLiteral(potentialNullCheck.getLeft()) && matchesSelect(potentialNullCheck.getRight(), requireNonNull(m.getSelect())) ||
isNullLiteral(potentialNullCheck.getRight()) && matchesSelect(potentialNullCheck.getLeft(), requireNonNull(m.getSelect()))) {
doAfterVisit(new RemoveUnnecessaryNullCheck<>(binary));
}
}
}

private boolean isNullLiteral(Expression expression) {
return expression instanceof J.Literal && ((J.Literal) expression).getType() == JavaType.Primitive.Null;
}

private boolean matchesSelect(Expression expression, Expression select) {
return expression.printTrimmed(getCursor()).replaceAll("\\s", "")
.equals(select.printTrimmed(getCursor()).replaceAll("\\s", ""));
}

private static class RemoveUnnecessaryNullCheck<P> extends JavaVisitor<P> {

private final J.Binary scope;
Expand Down
131 changes: 127 additions & 4 deletions src/test/java/org/openrewrite/staticanalysis/EqualsAvoidsNullTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
*/
package org.openrewrite.staticanalysis;

import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.openrewrite.DocumentExample;
import org.openrewrite.Issue;
import org.openrewrite.test.RecipeSpec;
import org.openrewrite.test.RewriteTest;

Expand Down Expand Up @@ -94,17 +96,16 @@ public class A {
@Test
void nullLiteral() {
rewriteRun(
//language=java
java("""
//language=java
java("""
public class A {
void foo(String s) {
if(s.equals(null)) {
}
}
}
""",
"""
"""
public class A {
void foo(String s) {
if(s == null) {
Expand All @@ -114,4 +115,126 @@ void foo(String s) {
""")
);
}

@Nested
class ReplaceConstantMethodArg {

@Issue("https://github.com/openrewrite/rewrite-static-analysis/pull/398")
@Test
void one() {
rewriteRun(
// language=java
java(
"""
public class Constants {
public static final String FOO = "FOO";
}
class A {
private boolean isFoo(String foo) {
return foo.contentEquals(Constants.FOO);
}
}
""",
"""
public class Constants {
public static final String FOO = "FOO";
}
class A {
private boolean isFoo(String foo) {
return Constants.FOO.contentEquals(foo);
}
}
"""
)
);
}

@Test
void staticImport() {
rewriteRun(
// language=java
java(
"""
package c;
public class Constants {
public static final String FOO = "FOO";
}
"""
),
// language=java
java(
"""
import static c.Constants.FOO;
class A {
private boolean isFoo(String foo) {
return foo.contentEquals(FOO);
}
}
""",
"""
import static c.Constants.FOO;
class A {
private boolean isFoo(String foo) {
return FOO.contentEquals(foo);
}
}
"""
)
);
}

@Test
void multiple() {
rewriteRun(
//language=java
java(
"""
public class Constants {
public static final String FOO = "FOO";
}
class A {
private boolean isFoo(String foo, String bar) {
return foo.contentEquals(Constants.FOO)
|| bar.compareToIgnoreCase(Constants.FOO);
}
}
""",
"""
public class Constants {
public static final String FOO = "FOO";
}
class A {
private boolean isFoo(String foo, String bar) {
return Constants.FOO.contentEquals(foo)
|| Constants.FOO.compareToIgnoreCase(bar);
}
}
"""
)
);
}

@Test
void nonStaticNonFinalNoChange() {
rewriteRun(
// language=java
java(
"""
public class Constants {
public final String FOO = "FOO";
public static String BAR = "BAR";
}
class A {
private boolean isFoo(String foo) {
return foo.contentEquals(new Constants().FOO);
}
private boolean isBar(String bar) {
return bar.contentEquals(Constants.BAR);
}
}
"""
)
);
}
}
}

0 comments on commit 0d8d873

Please sign in to comment.