diff --git a/rewrite-java/build.gradle.kts b/rewrite-java/build.gradle.kts index f84627b4c2b..92121e1ec1b 100644 --- a/rewrite-java/build.gradle.kts +++ b/rewrite-java/build.gradle.kts @@ -64,6 +64,7 @@ dependencies { testRuntimeOnly(project(":rewrite-java-17")) testImplementation("com.tngtech.archunit:archunit:1.0.1") testImplementation("com.tngtech.archunit:archunit-junit5:1.0.1") + testImplementation("org.junit-pioneer:junit-pioneer:2.0.0") // For use in ClassGraphTypeMappingTest testRuntimeOnly("org.eclipse.persistence:org.eclipse.persistence.core:3.0.2") diff --git a/rewrite-java/src/main/java/org/openrewrite/java/RemoveMethodInvocationsVisitor.java b/rewrite-java/src/main/java/org/openrewrite/java/RemoveMethodInvocationsVisitor.java new file mode 100644 index 00000000000..7cf86789aeb --- /dev/null +++ b/rewrite-java/src/main/java/org/openrewrite/java/RemoveMethodInvocationsVisitor.java @@ -0,0 +1,256 @@ +/* + * Copyright 2023 the original author or authors. + *

+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * https://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.openrewrite.java; + +import lombok.Value; +import lombok.With; +import org.jspecify.annotations.Nullable; +import org.openrewrite.*; +import org.openrewrite.internal.ListUtils; +import org.openrewrite.java.tree.*; +import org.openrewrite.marker.Marker; + +import java.util.*; +import java.util.function.Predicate; +import java.util.stream.Collectors; + +import static org.openrewrite.Tree.randomId; + +/** + * This visitor removes method calls matching some criteria. + * Tries to intelligently remove within chains without breaking other methods in the chain. + */ +public class RemoveMethodInvocationsVisitor extends JavaVisitor { + private final Map>> matchers; + + public RemoveMethodInvocationsVisitor(Map>> matchers) { + this.matchers = matchers; + } + + public RemoveMethodInvocationsVisitor(List methodSignatures) { + this(methodSignatures.stream().collect(Collectors.toMap( + MethodMatcher::new, + signature -> args -> true + ))); + } + + @Override + public J visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) { + J.MethodInvocation m = (J.MethodInvocation) super.visitMethodInvocation(method, ctx); + + if (inMethodCallChain()) { + List newArgs = ListUtils.map(m.getArguments(), arg -> (Expression) this.visit(arg, ctx)); + return m.withArguments(newArgs); + } + + J j = removeMethods(m, 0, isLambdaBody(), new Stack<>()); + if (j != null) { + j = j.withPrefix(m.getPrefix()); + // There should always be + if (!m.getArguments().isEmpty() && m.getArguments().stream().allMatch(ToBeRemoved::hasMarker)) { + return ToBeRemoved.withMarker(j); + } + } + + //noinspection DataFlowIssue allow returning null to remove the element + return j; + } + + private @Nullable J removeMethods(Expression expression, int depth, boolean isLambdaBody, Stack selectAfter) { + if (!(expression instanceof J.MethodInvocation)) { + return expression; + } + + boolean isStatement = isStatement(); + J.MethodInvocation m = (J.MethodInvocation) expression; + + if (m.getMethodType() == null || m.getSelect() == null) { + return expression; + } + + if (matchers.entrySet().stream().anyMatch(entry -> matches(m, entry.getKey(), entry.getValue()))) { + boolean hasSameReturnType = TypeUtils.isAssignableTo(m.getMethodType().getReturnType(), m.getSelect().getType()); + boolean removable = (isStatement && depth == 0) || hasSameReturnType; + if (!removable) { + return expression; + } + + if (m.getSelect() instanceof J.Identifier || m.getSelect() instanceof J.NewClass) { + boolean keepSelect = depth != 0; + if (keepSelect) { + selectAfter.add(getSelectAfter(m)); + return m.getSelect(); + } else { + if (isStatement) { + return null; + } else if (isLambdaBody) { + return ToBeRemoved.withMarker(J.Block.createEmptyBlock()); + } else { + return m.getSelect(); + } + } + } else if (m.getSelect() instanceof J.MethodInvocation) { + return removeMethods(m.getSelect(), depth, isLambdaBody, selectAfter); + } + } + + J.MethodInvocation method = m.withSelect((Expression) removeMethods(m.getSelect(), depth + 1, isLambdaBody, selectAfter)); + + // inherit prefix + if (!selectAfter.isEmpty()) { + method = inheritSelectAfter(method, selectAfter); + } + + return method; + } + + private boolean matches(J.MethodInvocation m, MethodMatcher matcher, Predicate> argsMatches) { + return matcher.matches(m) && argsMatches.test(m.getArguments()); + } + + private boolean isStatement() { + return getCursor().dropParentUntil(p -> p instanceof J.Block || + p instanceof J.Assignment || + p instanceof J.VariableDeclarations.NamedVariable || + p instanceof J.Return || + p instanceof JContainer || + p == Cursor.ROOT_VALUE + ).getValue() instanceof J.Block; + } + + private boolean isLambdaBody() { + if (getCursor().getParent() == null) { + return false; + } + Object parent = getCursor().getParent().getValue(); + return parent instanceof J.Lambda && ((J.Lambda) parent).getBody() == getCursor().getValue(); + } + + private boolean inMethodCallChain() { + return getCursor().dropParentUntil(p -> !(p instanceof JRightPadded)).getValue() instanceof J.MethodInvocation; + } + + private J.MethodInvocation inheritSelectAfter(J.MethodInvocation method, Stack prefix) { + return (J.MethodInvocation) new JavaIsoVisitor() { + @Override + public @Nullable JRightPadded visitRightPadded(@Nullable JRightPadded right, + JRightPadded.Location loc, + ExecutionContext executionContext) { + if (right == null) return null; + return prefix.isEmpty() ? right : right.withAfter(prefix.pop()); + } + }.visitNonNull(method, new InMemoryExecutionContext()); + } + + private Space getSelectAfter(J.MethodInvocation method) { + return new JavaIsoVisitor>() { + @Override + public @Nullable JRightPadded visitRightPadded(@Nullable JRightPadded right, + JRightPadded.Location loc, + List selectAfter) { + if (selectAfter.isEmpty()) { + selectAfter.add(right == null ? Space.EMPTY : right.getAfter()); + } + return right; + } + }.reduce(method, new ArrayList<>()).get(0); + } + + @SuppressWarnings("unused") // used in rewrite-spring / convenient for consumers + public static Predicate> isTrueArgument() { + return args -> args.size() == 1 && isTrue(args.get(0)); + } + + @SuppressWarnings("unused") // used in rewrite-spring / convenient for consumers + public static Predicate> isFalseArgument() { + return args -> args.size() == 1 && isFalse(args.get(0)); + } + + public static boolean isTrue(Expression expression) { + return isBoolean(expression, Boolean.TRUE); + } + + public static boolean isFalse(Expression expression) { + return isBoolean(expression, Boolean.FALSE); + } + + private static boolean isBoolean(Expression expression, Boolean b) { + if (expression instanceof J.Literal) { + return expression.getType() == JavaType.Primitive.Boolean && b.equals(((J.Literal) expression).getValue()); + } + return false; + } + + @Override + public J.Lambda visitLambda(J.Lambda lambda, ExecutionContext ctx) { + lambda = (J.Lambda) super.visitLambda(lambda, ctx); + J body = lambda.getBody(); + if (body instanceof J.MethodInvocation && ToBeRemoved.hasMarker(body)) { + Expression select = ((J.MethodInvocation) body).getSelect(); + List parameters = lambda.getParameters().getParameters(); + if (select instanceof J.Identifier && !parameters.isEmpty() && parameters.get(0) instanceof J.VariableDeclarations) { + J.VariableDeclarations declarations = (J.VariableDeclarations) parameters.get(0); + if (((J.Identifier) select).getSimpleName().equals(declarations.getVariables().get(0).getSimpleName())) { + return ToBeRemoved.withMarker(lambda); + } + } else if (select instanceof J.MethodInvocation) { + return lambda.withBody(select.withPrefix(body.getPrefix())); + } + } else if (body instanceof J.Block && ToBeRemoved.hasMarker(body)) { + return ToBeRemoved.withMarker(lambda.withBody(ToBeRemoved.removeMarker(body))); + } + return lambda; + } + + @Override + public J.Block visitBlock(J.Block block, ExecutionContext ctx) { + int statementsCount = block.getStatements().size(); + + block = (J.Block) super.visitBlock(block, ctx); + List statements = block.getStatements(); + if (!statements.isEmpty() && statements.stream().allMatch(ToBeRemoved::hasMarker)) { + return ToBeRemoved.withMarker(block.withStatements(Collections.emptyList())); + } + + if (statementsCount > 0 && statements.isEmpty()) { + return ToBeRemoved.withMarker(block.withStatements(Collections.emptyList())); + } + + if (statements.stream().anyMatch(ToBeRemoved::hasMarker)) { + //noinspection DataFlowIssue + return block.withStatements(statements.stream() + .filter(s -> !ToBeRemoved.hasMarker(s) || s instanceof J.MethodInvocation && ((J.MethodInvocation) s).getSelect() instanceof J.MethodInvocation) + .map(s -> s instanceof J.MethodInvocation && ToBeRemoved.hasMarker(s) ? ((J.MethodInvocation) s).getSelect().withPrefix(s.getPrefix()) : s) + .collect(Collectors.toList())); + } + return block; + } + + @Value + @With + static class ToBeRemoved implements Marker { + UUID id; + static J2 withMarker(J2 j) { + return j.withMarkers(j.getMarkers().addIfAbsent(new ToBeRemoved(randomId()))); + } + static J2 removeMarker(J2 j) { + return j.withMarkers(j.getMarkers().removeByType(ToBeRemoved.class)); + } + static boolean hasMarker(J j) { + return j.getMarkers().findFirst(ToBeRemoved.class).isPresent(); + } + } +} diff --git a/rewrite-java/src/test/java/org/openrewrite/java/RemoveMethodInvocationsVisitorTest.java b/rewrite-java/src/test/java/org/openrewrite/java/RemoveMethodInvocationsVisitorTest.java new file mode 100644 index 00000000000..c94f455eb16 --- /dev/null +++ b/rewrite-java/src/test/java/org/openrewrite/java/RemoveMethodInvocationsVisitorTest.java @@ -0,0 +1,495 @@ +/* + * Copyright 2023 the original author or authors. + *

+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * https://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.openrewrite.java; + +import org.junit.jupiter.api.Test; +import org.junitpioneer.jupiter.ExpectedToFail; +import org.openrewrite.DocumentExample; +import org.openrewrite.Recipe; +import org.openrewrite.test.RewriteTest; + +import java.util.List; + +import static org.openrewrite.java.Assertions.java; +import static org.openrewrite.test.RewriteTest.toRecipe; + +@SuppressWarnings({"ResultOfMethodCallIgnored", "CodeBlock2Expr", "RedundantThrows", "Convert2MethodRef", "EmptyTryBlock", "CatchMayIgnoreException", "EmptyFinallyBlock", "StringBufferReplaceableByString", "UnnecessaryLocalVariable"}) +class RemoveMethodInvocationsVisitorTest implements RewriteTest { + + private Recipe createRemoveMethodsRecipe(String... methods) { + return toRecipe(() -> new RemoveMethodInvocationsVisitor(List.of(methods))); + } + + @DocumentExample + @Test + void removeFromEnd() { + rewriteRun( + spec -> spec.recipe(createRemoveMethodsRecipe("java.lang.StringBuilder toString()")) + , + //language=java + java( + """ + public class Test { + void method() { + StringBuilder sb = new StringBuilder(); + sb.append("Hello") + .append(" ") + .append("World") + .reverse() + .append(" ") + .reverse() + .append("Yeah") + .toString(); + } + } + """, + """ + public class Test { + void method() { + StringBuilder sb = new StringBuilder(); + sb.append("Hello") + .append(" ") + .append("World") + .reverse() + .append(" ") + .reverse() + .append("Yeah"); + } + } + """ + ) + ); + } + + @Test + void removeMultipleMethodsFromEnd() { + rewriteRun( + spec -> spec.recipe(createRemoveMethodsRecipe("java.lang.StringBuilder toString()", "java.lang.StringBuilder append(java.lang.String)")), + //language=java + java( + """ + public class Test { + void method() { + StringBuilder sb = new StringBuilder(); + sb.append("Hello") + .append(" ") + .append("World") + .reverse() + .append(" ") + .reverse() + .append("Yeah") + .toString(); + } + } + """, + """ + public class Test { + void method() { + StringBuilder sb = new StringBuilder(); + sb.reverse() + .reverse(); + } + } + """ + ) + ); + } + + @Test + void removeFromMiddle() { + rewriteRun( + spec -> spec.recipe(createRemoveMethodsRecipe("java.lang.StringBuilder reverse()")), + //language=java + java( + """ + public class Test { + void method() { + StringBuilder sb = new StringBuilder(); + sb.append("Hello") + .append(" ") + .append("World") + .reverse() + .append(" ") + .reverse() + .append("Yeah") + .toString(); + } + } + """, + """ + public class Test { + void method() { + StringBuilder sb = new StringBuilder(); + sb.append("Hello") + .append(" ") + .append("World") + .append(" ") + .append("Yeah") + .toString(); + } + } + """ + ) + ); + } + + @Test + void removeEntireStatement() { + rewriteRun( + spec -> spec.recipe(createRemoveMethodsRecipe("java.lang.StringBuilder append(java.lang.String)")), + //language=java + java( + """ + public class Test { + void method() { + StringBuilder sb = new StringBuilder(); + sb.append("Hello"); + } + } + """, + """ + public class Test { + void method() { + StringBuilder sb = new StringBuilder(); + } + } + """ + ) + ); + } + + @Test + @ExpectedToFail + void removeWithoutSelect() { + rewriteRun( + spec -> spec.recipe(createRemoveMethodsRecipe("Test foo()")), + //language=java + java( + """ + public class Test { + void foo() {} + void method() { + StringBuilder sb = new StringBuilder(); + foo(); + } + } + """, + """ + public class Test { + void foo() {} + void method() { + StringBuilder sb = new StringBuilder(); + } + } + """ + ) + ); + } + + @Test + void removeFromWithinArguments() { + rewriteRun( + spec -> spec.recipe(createRemoveMethodsRecipe("java.lang.StringBuilder append(java.lang.String)")), + //language=java + java( + """ + public class Test { + void method() { + StringBuilder sb = new StringBuilder(); + StringBuilder sb2 = new StringBuilder(); + sb.append(1) + .append(((java.util.function.Supplier) () -> sb2 + .append("foo") + .append('b') + .toString() + .charAt(0))) + .append(2) + .toString(); + } + } + """, + """ + public class Test { + void method() { + StringBuilder sb = new StringBuilder(); + StringBuilder sb2 = new StringBuilder(); + sb.append(1) + .append(((java.util.function.Supplier) () -> sb2 + .append('b') + .toString() + .charAt(0))) + .append(2) + .toString(); + } + } + """ + ) + ); + } + + @Test + void keepSelectForAssignment() { + rewriteRun( + spec -> spec.recipe(createRemoveMethodsRecipe("java.lang.StringBuilder append(java.lang.String)")), + //language=java + java( + """ + public class Test { + void method() { + StringBuilder sb = new StringBuilder(); + StringBuilder sb2 = sb.append("foo"); + sb2.append("bar"); + sb2.reverse(); + } + } + """, + """ + public class Test { + void method() { + StringBuilder sb = new StringBuilder(); + StringBuilder sb2 = sb; + sb2.reverse(); + } + } + """ + ) + ); + } + + @Test + void chainedCallsAsParameter() { + rewriteRun( + spec -> spec.recipe(createRemoveMethodsRecipe("java.lang.StringBuilder append(java.lang.String)")), + // language=java + java( + """ + class Test { + void method() { + print(new StringBuilder() + .append("Hello") + .append(" ") + .append("World") + .reverse() + .append(" ") + .reverse() + .append("Yeah") + .toString()); + } + void print(String str) {} + } + """, + """ + class Test { + void method() { + print(new StringBuilder() + .reverse() + .reverse() + .toString()); + } + void print(String str) {} + } + """ + ) + ); + } + + @Test + void removeFromLambda() { + rewriteRun( + spec -> spec.recipe(createRemoveMethodsRecipe("java.lang.StringBuilder append(java.lang.String)")), + // language=java + java( + """ + import java.util.List; + + public class Test { + void method(List names) { + names.forEach(name -> new StringBuilder() + .append("hello") + .append(" ") + .append(name) + .reverse() + .toString()); + } + } + """, + """ + import java.util.List; + + public class Test { + void method(List names) { + names.forEach(name -> new StringBuilder() + .reverse() + .toString()); + } + } + """ + ) + ); + } + + @Test + void complexCase() { + rewriteRun( + spec -> spec.recipe(createRemoveMethodsRecipe("java.lang.StringBuilder append(java.lang.String)")), + // language=java + java( + """ + import java.util.List; + import java.util.function.Consumer; + + public class Test { + void method(List names) { + this.consume(s -> names.forEach(name -> { + new StringBuilder() + .append("hello") + .append(" ") + .append(name) + .reverse() + .toString(); + } + ) + ).toString(); + } + StringBuilder consume(Consumer consumer) {return new StringBuilder();} + } + """, + """ + import java.util.List; + import java.util.function.Consumer; + + public class Test { + void method(List names) { + this.consume(s -> names.forEach(name -> { + new StringBuilder() + .reverse() + .toString(); + } + ) + ).toString(); + } + StringBuilder consume(Consumer consumer) {return new StringBuilder();} + } + """ + ) + ); + } + + @Test + void returnEmptyLambdaBody() { + rewriteRun( + spec -> spec.recipe(createRemoveMethodsRecipe("java.lang.StringBuilder append(java.lang.String)")), + // language=java + java( + """ + import java.util.function.Consumer; + + public class Test { + public void method() throws Exception { + this.customize(sb -> sb + .append("Hello") + ); + } + + public void customize(Consumer securityContextCustomizer) { + } + } + """, + """ + import java.util.function.Consumer; + + public class Test { + public void method() throws Exception { + } + + public void customize(Consumer securityContextCustomizer) { + } + } + """ + ) + ); + } + + @Test + void lambdaAssignment() { + rewriteRun( + spec -> spec.recipe(createRemoveMethodsRecipe("java.lang.StringBuilder append(java.lang.String)")), + // language=java + java( + """ + import java.util.function.Consumer; + + public class Test { + public void method() { + StringBuilder sb = new StringBuilder(); + Consumer consumer = name -> { + sb.append(name); + }; + consumer.accept("hello"); + } + } + """, + """ + import java.util.function.Consumer; + + public class Test { + public void method() { + StringBuilder sb = new StringBuilder(); + Consumer consumer = name -> { + }; + consumer.accept("hello"); + } + } + """ + ) + ); + } + + @Test + void tryCatchBlocks() { + rewriteRun( + spec -> spec.recipe(createRemoveMethodsRecipe("java.lang.StringBuilder append(java.lang.String)")), + // language=java + java( + """ + public class Test { + public void method() { + StringBuilder sb = new StringBuilder(); + try { + sb.append("Hello"); + } catch (Exception e) { + sb.append("Hello"); + } finally { + sb.append("Hello"); + } + } + } + """, + """ + public class Test { + public void method() { + StringBuilder sb = new StringBuilder(); + try { + } catch (Exception e) { + } finally { + } + } + } + """ + ) + ); + } +}