diff --git a/error-prone-contrib/src/main/java/tech/picnic/errorprone/refasterrules/StreamRules.java b/error-prone-contrib/src/main/java/tech/picnic/errorprone/refasterrules/StreamRules.java index 251fdea9c6..adc1a53fcf 100644 --- a/error-prone-contrib/src/main/java/tech/picnic/errorprone/refasterrules/StreamRules.java +++ b/error-prone-contrib/src/main/java/tech/picnic/errorprone/refasterrules/StreamRules.java @@ -10,6 +10,7 @@ import com.google.errorprone.refaster.Refaster; import com.google.errorprone.refaster.annotation.AfterTemplate; import com.google.errorprone.refaster.annotation.BeforeTemplate; +import com.google.errorprone.refaster.annotation.Matches; import com.google.errorprone.refaster.annotation.MayOptionallyUse; import com.google.errorprone.refaster.annotation.Placeholder; import com.google.errorprone.refaster.annotation.UseImportPolicy; @@ -19,10 +20,14 @@ import java.util.Optional; import java.util.function.Function; import java.util.function.Predicate; +import java.util.function.ToDoubleFunction; +import java.util.function.ToIntFunction; +import java.util.function.ToLongFunction; import java.util.stream.Collector; import java.util.stream.Collectors; import java.util.stream.Stream; import tech.picnic.errorprone.refaster.annotation.OnlineDocumentation; +import tech.picnic.errorprone.refaster.matchers.IsLambdaExpressionOrMethodReference; /** Refaster rules related to expressions dealing with {@link Stream}s. */ @OnlineDocumentation @@ -379,4 +384,46 @@ boolean after(Stream stream) { return stream.allMatch(e -> test(e)); } } + + static final class StreamMapToIntSum { + @BeforeTemplate + int before( + Stream stream, + @Matches(IsLambdaExpressionOrMethodReference.class) Function mapper) { + return stream.map(mapper).reduce(0, Integer::sum); + } + + @AfterTemplate + int after(Stream stream, ToIntFunction mapper) { + return stream.mapToInt(mapper).sum(); + } + } + + static final class StreamMapToDoubleSum { + @BeforeTemplate + double before( + Stream stream, + @Matches(IsLambdaExpressionOrMethodReference.class) Function mapper) { + return stream.map(mapper).reduce(0.0, Double::sum); + } + + @AfterTemplate + double after(Stream stream, ToDoubleFunction mapper) { + return stream.mapToDouble(mapper).sum(); + } + } + + static final class StreamMapToLongSum { + @BeforeTemplate + long before( + Stream stream, + @Matches(IsLambdaExpressionOrMethodReference.class) Function mapper) { + return stream.map(mapper).reduce(0L, Long::sum); + } + + @AfterTemplate + long after(Stream stream, ToLongFunction mapper) { + return stream.mapToLong(mapper).sum(); + } + } } diff --git a/error-prone-contrib/src/test/resources/tech/picnic/errorprone/refasterrules/StreamRulesTestInput.java b/error-prone-contrib/src/test/resources/tech/picnic/errorprone/refasterrules/StreamRulesTestInput.java index 0f291c343e..f11d04964a 100644 --- a/error-prone-contrib/src/test/resources/tech/picnic/errorprone/refasterrules/StreamRulesTestInput.java +++ b/error-prone-contrib/src/test/resources/tech/picnic/errorprone/refasterrules/StreamRulesTestInput.java @@ -9,6 +9,7 @@ import com.google.common.collect.Streams; import java.util.Objects; import java.util.Optional; +import java.util.function.Function; import java.util.function.Predicate; import java.util.stream.Stream; import tech.picnic.errorprone.refaster.test.RefasterRuleCollectionTestCase; @@ -138,4 +139,28 @@ ImmutableSet testStreamAllMatch() { boolean testStreamAllMatch2() { return Stream.of("foo").noneMatch(s -> !s.isBlank()); } + + ImmutableSet testStreamMapToIntSum() { + Function parseIntFunction = Integer::parseInt; + return ImmutableSet.of( + Stream.of(1).map(i -> i * 2).reduce(0, Integer::sum), + Stream.of("2").map(Integer::parseInt).reduce(0, Integer::sum), + Stream.of("3").map(parseIntFunction).reduce(0, Integer::sum)); + } + + ImmutableSet testStreamMapToDoubleSum() { + Function parseDoubleFunction = Double::parseDouble; + return ImmutableSet.of( + Stream.of(1).map(i -> i * 2.0).reduce(0.0, Double::sum), + Stream.of("2").map(Double::parseDouble).reduce(0.0, Double::sum), + Stream.of("3").map(parseDoubleFunction).reduce(0.0, Double::sum)); + } + + ImmutableSet testStreamMapToLongSum() { + Function parseLongFunction = Long::parseLong; + return ImmutableSet.of( + Stream.of(1).map(i -> i * 2L).reduce(0L, Long::sum), + Stream.of("2").map(Long::parseLong).reduce(0L, Long::sum), + Stream.of("3").map(parseLongFunction).reduce(0L, Long::sum)); + } } diff --git a/error-prone-contrib/src/test/resources/tech/picnic/errorprone/refasterrules/StreamRulesTestOutput.java b/error-prone-contrib/src/test/resources/tech/picnic/errorprone/refasterrules/StreamRulesTestOutput.java index 19e7d9fdda..8c89199f0d 100644 --- a/error-prone-contrib/src/test/resources/tech/picnic/errorprone/refasterrules/StreamRulesTestOutput.java +++ b/error-prone-contrib/src/test/resources/tech/picnic/errorprone/refasterrules/StreamRulesTestOutput.java @@ -11,6 +11,7 @@ import java.util.Arrays; import java.util.Objects; import java.util.Optional; +import java.util.function.Function; import java.util.function.Predicate; import java.util.stream.Stream; import tech.picnic.errorprone.refaster.test.RefasterRuleCollectionTestCase; @@ -137,4 +138,28 @@ ImmutableSet testStreamAllMatch() { boolean testStreamAllMatch2() { return Stream.of("foo").allMatch(s -> s.isBlank()); } + + ImmutableSet testStreamMapToIntSum() { + Function parseIntFunction = Integer::parseInt; + return ImmutableSet.of( + Stream.of(1).mapToInt(i -> i * 2).sum(), + Stream.of("2").mapToInt(Integer::parseInt).sum(), + Stream.of("3").map(parseIntFunction).reduce(0, Integer::sum)); + } + + ImmutableSet testStreamMapToDoubleSum() { + Function parseDoubleFunction = Double::parseDouble; + return ImmutableSet.of( + Stream.of(1).mapToDouble(i -> i * 2.0).sum(), + Stream.of("2").mapToDouble(Double::parseDouble).sum(), + Stream.of("3").map(parseDoubleFunction).reduce(0.0, Double::sum)); + } + + ImmutableSet testStreamMapToLongSum() { + Function parseLongFunction = Long::parseLong; + return ImmutableSet.of( + Stream.of(1).mapToLong(i -> i * 2L).sum(), + Stream.of("2").mapToLong(Long::parseLong).sum(), + Stream.of("3").map(parseLongFunction).reduce(0L, Long::sum)); + } } diff --git a/refaster-support/src/main/java/tech/picnic/errorprone/refaster/matchers/IsLambdaExpressionOrMethodReference.java b/refaster-support/src/main/java/tech/picnic/errorprone/refaster/matchers/IsLambdaExpressionOrMethodReference.java new file mode 100644 index 0000000000..23721f73ae --- /dev/null +++ b/refaster-support/src/main/java/tech/picnic/errorprone/refaster/matchers/IsLambdaExpressionOrMethodReference.java @@ -0,0 +1,20 @@ +package tech.picnic.errorprone.refaster.matchers; + +import com.google.errorprone.VisitorState; +import com.google.errorprone.matchers.Matcher; +import com.sun.source.tree.ExpressionTree; +import com.sun.source.tree.LambdaExpressionTree; +import com.sun.source.tree.MemberReferenceTree; + +/** A matcher of lambda expressions and method references. */ +public final class IsLambdaExpressionOrMethodReference implements Matcher { + private static final long serialVersionUID = 1L; + + /** Instantiates a new {@link IsLambdaExpressionOrMethodReference} instance. */ + public IsLambdaExpressionOrMethodReference() {} + + @Override + public boolean matches(ExpressionTree tree, VisitorState state) { + return tree instanceof LambdaExpressionTree || tree instanceof MemberReferenceTree; + } +} diff --git a/refaster-support/src/test/java/tech/picnic/errorprone/refaster/matchers/IsLambdaExpressionOrMethodReferenceTest.java b/refaster-support/src/test/java/tech/picnic/errorprone/refaster/matchers/IsLambdaExpressionOrMethodReferenceTest.java new file mode 100644 index 0000000000..3280024cda --- /dev/null +++ b/refaster-support/src/test/java/tech/picnic/errorprone/refaster/matchers/IsLambdaExpressionOrMethodReferenceTest.java @@ -0,0 +1,72 @@ +package tech.picnic.errorprone.refaster.matchers; + +import static com.google.errorprone.BugPattern.SeverityLevel.ERROR; + +import com.google.errorprone.BugPattern; +import com.google.errorprone.CompilationTestHelper; +import com.google.errorprone.bugpatterns.BugChecker; +import org.junit.jupiter.api.Test; + +final class IsLambdaExpressionOrMethodReferenceTest { + @Test + void matches() { + CompilationTestHelper.newInstance(MatcherTestChecker.class, getClass()) + .addSourceLines( + "A.java", + "import com.google.common.base.Predicates;", + "import java.util.function.Function;", + "import java.util.function.Predicate;", + "", + "class A {", + " boolean negative1() {", + " return true;", + " }", + "", + " String negative2() {", + " return new String(new byte[0]);", + " }", + "", + " Predicate negative3() {", + " return Predicates.alwaysTrue();", + " }", + "", + " Predicate positive1() {", + " // BUG: Diagnostic contains:", + " return str -> true;", + " }", + "", + " Predicate positive2() {", + " // BUG: Diagnostic contains:", + " return str -> {", + " return true;", + " };", + " }", + "", + " Predicate positive3() {", + " // BUG: Diagnostic contains:", + " return String::isEmpty;", + " }", + "", + " Function positive4() {", + " // BUG: Diagnostic contains:", + " return String::new;", + " }", + "}") + .doTest(); + } + + /** A {@link BugChecker} that simply delegates to {@link IsLambdaExpressionOrMethodReference}. */ + @BugPattern( + summary = "Flags expressions matched by `IsLambdaExpressionOrMethodReference`", + severity = ERROR) + public static final class MatcherTestChecker extends AbstractMatcherTestChecker { + private static final long serialVersionUID = 1L; + + // XXX: This is a false positive reported by Checkstyle. See + // https://github.com/checkstyle/checkstyle/issues/10161#issuecomment-1242732120. + @SuppressWarnings("RedundantModifier") + public MatcherTestChecker() { + super(new IsLambdaExpressionOrMethodReference()); + } + } +}