diff --git a/src/main/java/org/openrewrite/staticanalysis/EqualsAvoidsNullVisitor.java b/src/main/java/org/openrewrite/staticanalysis/EqualsAvoidsNullVisitor.java index f6703c82b..f2d6bfb2b 100644 --- a/src/main/java/org/openrewrite/staticanalysis/EqualsAvoidsNullVisitor.java +++ b/src/main/java/org/openrewrite/staticanalysis/EqualsAvoidsNullVisitor.java @@ -19,36 +19,35 @@ import lombok.Value; import org.openrewrite.Tree; import org.openrewrite.internal.lang.Nullable; -import org.openrewrite.java.JavaIsoVisitor; import org.openrewrite.java.JavaVisitor; import org.openrewrite.java.MethodMatcher; import org.openrewrite.java.style.EqualsAvoidsNullStyle; -import org.openrewrite.java.tree.Expression; -import org.openrewrite.java.tree.J; -import org.openrewrite.java.tree.JavaType; -import org.openrewrite.java.tree.Space; +import org.openrewrite.java.tree.*; +import org.openrewrite.marker.Markers; import static java.util.Collections.singletonList; @Value @EqualsAndHashCode(callSuper = false) -public class EqualsAvoidsNullVisitor
extends JavaIsoVisitor
{ +public class EqualsAvoidsNullVisitor
extends JavaVisitor
{ private static final MethodMatcher STRING_EQUALS = new MethodMatcher("String equals(java.lang.Object)"); private static final MethodMatcher STRING_EQUALS_IGNORE_CASE = new MethodMatcher("String equalsIgnoreCase(java.lang.String)"); EqualsAvoidsNullStyle style; @Override - public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, P p) { - J.MethodInvocation m = super.visitMethodInvocation(method, p); - + public J visitMethodInvocation(J.MethodInvocation method, P p) { + J j = super.visitMethodInvocation(method, p); + if (!(j instanceof J.MethodInvocation)) { + return j; + } + J.MethodInvocation m = (J.MethodInvocation) j; if (m.getSelect() == null) { return m; } if ((STRING_EQUALS.matches(m) || (!Boolean.TRUE.equals(style.getIgnoreEqualsIgnoreCase()) && STRING_EQUALS_IGNORE_CASE.matches(m))) && m.getArguments().get(0) instanceof J.Literal && - m.getArguments().get(0).getType() != JavaType.Primitive.Null && !(m.getSelect() instanceof J.Literal)) { Tree parent = getCursor().getParentTreeCursor().getValue(); if (parent instanceof J.Binary) { @@ -62,8 +61,16 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, P p) } } - m = m.withSelect(((J.Literal) m.getArguments().get(0)).withPrefix(m.getSelect().getPrefix())) - .withArguments(singletonList(m.getSelect().withPrefix(Space.EMPTY))); + if (m.getArguments().get(0).getType() == JavaType.Primitive.Null) { + return new J.Binary(Tree.randomId(), m.getPrefix(), Markers.EMPTY, + m.getSelect(), + JLeftPadded.build(J.Binary.Type.Equal).withBefore(Space.SINGLE_SPACE), + m.getArguments().get(0).withPrefix(Space.SINGLE_SPACE), + JavaType.Primitive.Boolean); + } else { + m = m.withSelect(((J.Literal) m.getArguments().get(0)).withPrefix(m.getSelect().getPrefix())) + .withArguments(singletonList(m.getSelect().withPrefix(Space.EMPTY))); + } } return m; diff --git a/src/test/java/org/openrewrite/staticanalysis/EqualsAvoidsNullTest.java b/src/test/java/org/openrewrite/staticanalysis/EqualsAvoidsNullTest.java index bc370a862..68bb55236 100644 --- a/src/test/java/org/openrewrite/staticanalysis/EqualsAvoidsNullTest.java +++ b/src/test/java/org/openrewrite/staticanalysis/EqualsAvoidsNullTest.java @@ -96,6 +96,15 @@ void foo(String s) { } } } + """, + """ + + public class A { + void foo(String s) { + if(s == null) { + } + } + } """) ); }