diff --git a/src/main/java/org/openrewrite/staticanalysis/EqualsAvoidsNullVisitor.java b/src/main/java/org/openrewrite/staticanalysis/EqualsAvoidsNullVisitor.java index 30d28769a..f6703c82b 100644 --- a/src/main/java/org/openrewrite/staticanalysis/EqualsAvoidsNullVisitor.java +++ b/src/main/java/org/openrewrite/staticanalysis/EqualsAvoidsNullVisitor.java @@ -47,15 +47,16 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, P p) } if ((STRING_EQUALS.matches(m) || (!Boolean.TRUE.equals(style.getIgnoreEqualsIgnoreCase()) && STRING_EQUALS_IGNORE_CASE.matches(m))) && - m.getArguments().get(0) instanceof J.Literal && - !(m.getSelect() instanceof J.Literal)) { + m.getArguments().get(0) instanceof J.Literal && + m.getArguments().get(0).getType() != JavaType.Primitive.Null && + !(m.getSelect() instanceof J.Literal)) { Tree parent = getCursor().getParentTreeCursor().getValue(); if (parent instanceof J.Binary) { J.Binary binary = (J.Binary) parent; if (binary.getOperator() == J.Binary.Type.And && binary.getLeft() instanceof J.Binary) { J.Binary potentialNullCheck = (J.Binary) binary.getLeft(); if ((isNullLiteral(potentialNullCheck.getLeft()) && matchesSelect(potentialNullCheck.getRight(), m.getSelect())) || - (isNullLiteral(potentialNullCheck.getRight()) && matchesSelect(potentialNullCheck.getLeft(), m.getSelect()))) { + (isNullLiteral(potentialNullCheck.getRight()) && matchesSelect(potentialNullCheck.getLeft(), m.getSelect()))) { doAfterVisit(new RemoveUnnecessaryNullCheck<>(binary)); } } diff --git a/src/test/java/org/openrewrite/staticanalysis/EqualsAvoidsNullTest.java b/src/test/java/org/openrewrite/staticanalysis/EqualsAvoidsNullTest.java index 5f87cbe42..bc370a862 100644 --- a/src/test/java/org/openrewrite/staticanalysis/EqualsAvoidsNullTest.java +++ b/src/test/java/org/openrewrite/staticanalysis/EqualsAvoidsNullTest.java @@ -84,4 +84,19 @@ public class A { ) ); } + + @Test + void nullLiteral() { + rewriteRun( + //language=java + java(""" + public class A { + void foo(String s) { + if(s.equals(null)) { + } + } + } + """) + ); + } }