diff --git a/error-prone-contrib/src/main/java/tech/picnic/errorprone/refasterrules/StreamRules.java b/error-prone-contrib/src/main/java/tech/picnic/errorprone/refasterrules/StreamRules.java index fb922ca499..c1b1180bcc 100644 --- a/error-prone-contrib/src/main/java/tech/picnic/errorprone/refasterrules/StreamRules.java +++ b/error-prone-contrib/src/main/java/tech/picnic/errorprone/refasterrules/StreamRules.java @@ -4,7 +4,20 @@ import static java.util.Comparator.naturalOrder; import static java.util.Comparator.reverseOrder; import static java.util.function.Predicate.not; +import static java.util.stream.Collectors.counting; +import static java.util.stream.Collectors.filtering; +import static java.util.stream.Collectors.flatMapping; import static java.util.stream.Collectors.joining; +import static java.util.stream.Collectors.mapping; +import static java.util.stream.Collectors.maxBy; +import static java.util.stream.Collectors.minBy; +import static java.util.stream.Collectors.reducing; +import static java.util.stream.Collectors.summarizingDouble; +import static java.util.stream.Collectors.summarizingInt; +import static java.util.stream.Collectors.summarizingLong; +import static java.util.stream.Collectors.summingDouble; +import static java.util.stream.Collectors.summingInt; +import static java.util.stream.Collectors.summingLong; import com.google.common.collect.Streams; import com.google.errorprone.refaster.Refaster; @@ -16,8 +29,12 @@ import com.google.errorprone.refaster.annotation.UseImportPolicy; import java.util.Arrays; import java.util.Comparator; +import java.util.DoubleSummaryStatistics; +import java.util.IntSummaryStatistics; +import java.util.LongSummaryStatistics; import java.util.Objects; import java.util.Optional; +import java.util.function.BinaryOperator; import java.util.function.Function; import java.util.function.Predicate; import java.util.function.ToDoubleFunction; @@ -263,9 +280,12 @@ boolean after(Stream stream) { static final class StreamMin { @BeforeTemplate + @SuppressWarnings("java:S4266" /* This violation will be rewritten. */) Optional before(Stream stream, Comparator comparator) { return Refaster.anyOf( - stream.max(comparator.reversed()), stream.sorted(comparator).findFirst()); + stream.max(comparator.reversed()), + stream.sorted(comparator).findFirst(), + stream.collect(minBy(comparator))); } @AfterTemplate @@ -289,9 +309,12 @@ Optional after(Stream stream) { static final class StreamMax { @BeforeTemplate + @SuppressWarnings("java:S4266" /* This violation will be rewritten. */) Optional before(Stream stream, Comparator comparator) { return Refaster.anyOf( - stream.min(comparator.reversed()), Streams.findLast(stream.sorted(comparator))); + stream.min(comparator.reversed()), + Streams.findLast(stream.sorted(comparator)), + stream.collect(maxBy(comparator))); } @AfterTemplate @@ -389,7 +412,13 @@ boolean after(Stream stream) { static final class StreamMapToIntSum { @BeforeTemplate - int before( + @SuppressWarnings("java:S4266" /* This violation will be rewritten. */) + long before(Stream stream, ToIntFunction mapper) { + return stream.collect(summingInt(mapper)); + } + + @BeforeTemplate + int before2( Stream stream, @Matches(IsLambdaExpressionOrMethodReference.class) Function mapper) { return stream.map(mapper).reduce(0, Integer::sum); @@ -403,7 +432,13 @@ int after(Stream stream, ToIntFunction mapper) { static final class StreamMapToDoubleSum { @BeforeTemplate - double before( + @SuppressWarnings("java:S4266" /* This violation will be rewritten. */) + double before(Stream stream, ToDoubleFunction mapper) { + return stream.collect(summingDouble(mapper)); + } + + @BeforeTemplate + double before2( Stream stream, @Matches(IsLambdaExpressionOrMethodReference.class) Function mapper) { return stream.map(mapper).reduce(0.0, Double::sum); @@ -417,7 +452,13 @@ static final class StreamMapToDoubleSum { static final class StreamMapToLongSum { @BeforeTemplate - long before( + @SuppressWarnings("java:S4266" /* This violation will be rewritten. */) + long before(Stream stream, ToLongFunction mapper) { + return stream.collect(summingLong(mapper)); + } + + @BeforeTemplate + long before2( Stream stream, @Matches(IsLambdaExpressionOrMethodReference.class) Function mapper) { return stream.map(mapper).reduce(0L, Long::sum); @@ -428,4 +469,130 @@ long after(Stream stream, ToLongFunction mapper) { return stream.mapToLong(mapper).sum(); } } + + static final class StreamMapToIntSummaryStatistics { + @BeforeTemplate + IntSummaryStatistics before(Stream stream, ToIntFunction mapper) { + return stream.collect(summarizingInt(mapper)); + } + + @AfterTemplate + IntSummaryStatistics after(Stream stream, ToIntFunction mapper) { + return stream.mapToInt(mapper).summaryStatistics(); + } + } + + static final class StreamMapToDoubleSummaryStatistics { + @BeforeTemplate + DoubleSummaryStatistics before(Stream stream, ToDoubleFunction mapper) { + return stream.collect(summarizingDouble(mapper)); + } + + @AfterTemplate + DoubleSummaryStatistics after(Stream stream, ToDoubleFunction mapper) { + return stream.mapToDouble(mapper).summaryStatistics(); + } + } + + static final class StreamMapToLongSummaryStatistics { + @BeforeTemplate + LongSummaryStatistics before(Stream stream, ToLongFunction mapper) { + return stream.collect(summarizingLong(mapper)); + } + + @AfterTemplate + LongSummaryStatistics after(Stream stream, ToLongFunction mapper) { + return stream.mapToLong(mapper).summaryStatistics(); + } + } + + static final class StreamCount { + @BeforeTemplate + @SuppressWarnings("java:S4266" /* This violation will be rewritten. */) + long before(Stream stream) { + return stream.collect(counting()); + } + + @AfterTemplate + long after(Stream stream) { + return stream.count(); + } + } + + static final class StreamReduce { + @BeforeTemplate + @SuppressWarnings("java:S4266" /* This violation will be rewritten. */) + Optional before(Stream stream, BinaryOperator accumulator) { + return stream.collect(reducing(accumulator)); + } + + @AfterTemplate + Optional after(Stream stream, BinaryOperator accumulator) { + return stream.reduce(accumulator); + } + } + + static final class StreamReduceWithIdentity { + @BeforeTemplate + @SuppressWarnings("java:S4266" /* This violation will be rewritten. */) + T before(Stream stream, T identity, BinaryOperator accumulator) { + return stream.collect(reducing(identity, accumulator)); + } + + @AfterTemplate + T after(Stream stream, T identity, BinaryOperator accumulator) { + return stream.reduce(identity, accumulator); + } + } + + static final class StreamFilterCollect { + @BeforeTemplate + R before( + Stream stream, Predicate predicate, Collector collector) { + return stream.collect(filtering(predicate, collector)); + } + + @AfterTemplate + R after( + Stream stream, Predicate predicate, Collector collector) { + return stream.filter(predicate).collect(collector); + } + } + + static final class StreamMapCollect { + @BeforeTemplate + @SuppressWarnings("java:S4266" /* This violation will be rewritten. */) + R before( + Stream stream, + Function mapper, + Collector collector) { + return stream.collect(mapping(mapper, collector)); + } + + @AfterTemplate + R after( + Stream stream, + Function mapper, + Collector collector) { + return stream.map(mapper).collect(collector); + } + } + + static final class StreamFlatMapCollect { + @BeforeTemplate + R before( + Stream stream, + Function> mapper, + Collector collector) { + return stream.collect(flatMapping(mapper, collector)); + } + + @AfterTemplate + R after( + Stream stream, + Function> mapper, + Collector collector) { + return stream.flatMap(mapper).collect(collector); + } + } } diff --git a/error-prone-contrib/src/test/resources/tech/picnic/errorprone/refasterrules/StreamRulesTestInput.java b/error-prone-contrib/src/test/resources/tech/picnic/errorprone/refasterrules/StreamRulesTestInput.java index 1ef30d1e3d..c568e49507 100644 --- a/error-prone-contrib/src/test/resources/tech/picnic/errorprone/refasterrules/StreamRulesTestInput.java +++ b/error-prone-contrib/src/test/resources/tech/picnic/errorprone/refasterrules/StreamRulesTestInput.java @@ -1,12 +1,29 @@ package tech.picnic.errorprone.refasterrules; +import static com.google.common.collect.ImmutableSet.toImmutableSet; import static java.util.Comparator.comparingInt; import static java.util.Comparator.reverseOrder; import static java.util.function.Predicate.not; +import static java.util.stream.Collectors.counting; +import static java.util.stream.Collectors.filtering; +import static java.util.stream.Collectors.flatMapping; import static java.util.stream.Collectors.joining; +import static java.util.stream.Collectors.mapping; +import static java.util.stream.Collectors.maxBy; +import static java.util.stream.Collectors.minBy; +import static java.util.stream.Collectors.reducing; +import static java.util.stream.Collectors.summarizingDouble; +import static java.util.stream.Collectors.summarizingInt; +import static java.util.stream.Collectors.summarizingLong; +import static java.util.stream.Collectors.summingDouble; +import static java.util.stream.Collectors.summingInt; +import static java.util.stream.Collectors.summingLong; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Streams; +import java.util.DoubleSummaryStatistics; +import java.util.IntSummaryStatistics; +import java.util.LongSummaryStatistics; import java.util.Objects; import java.util.Optional; import java.util.function.Function; @@ -17,7 +34,23 @@ final class StreamRulesTest implements RefasterRuleCollectionTestCase { @Override public ImmutableSet elidedTypesAndStaticImports() { - return ImmutableSet.of(Objects.class, Streams.class, not(null)); + return ImmutableSet.of( + Objects.class, + Streams.class, + counting(), + filtering(null, null), + flatMapping(null, null), + mapping(null, null), + maxBy(null), + minBy(null), + not(null), + reducing(null), + summarizingDouble(null), + summarizingInt(null), + summarizingLong(null), + summingDouble(null), + summingInt(null), + summingLong(null)); } String testJoining() { @@ -90,7 +123,8 @@ ImmutableSet testStreamIsNotEmpty() { ImmutableSet> testStreamMin() { return ImmutableSet.of( Stream.of("foo").max(comparingInt(String::length).reversed()), - Stream.of("bar").sorted(comparingInt(String::length)).findFirst()); + Stream.of("bar").sorted(comparingInt(String::length)).findFirst(), + Stream.of("baz").collect(minBy(comparingInt(String::length)))); } ImmutableSet> testStreamMinNaturalOrder() { @@ -101,7 +135,8 @@ ImmutableSet> testStreamMinNaturalOrder() { ImmutableSet> testStreamMax() { return ImmutableSet.of( Stream.of("foo").min(comparingInt(String::length).reversed()), - Streams.findLast(Stream.of("bar").sorted(comparingInt(String::length)))); + Streams.findLast(Stream.of("bar").sorted(comparingInt(String::length))), + Stream.of("baz").collect(maxBy(comparingInt(String::length)))); } ImmutableSet> testStreamMaxNaturalOrder() { @@ -143,24 +178,63 @@ boolean testStreamAllMatch2() { ImmutableSet testStreamMapToIntSum() { Function parseIntFunction = Integer::parseInt; return ImmutableSet.of( - Stream.of(1).map(i -> i * 2).reduce(0, Integer::sum), - Stream.of("2").map(Integer::parseInt).reduce(0, Integer::sum), - Stream.of("3").map(parseIntFunction).reduce(0, Integer::sum)); + Stream.of("1").collect(summingInt(Integer::parseInt)), + Stream.of(2).map(i -> i * 2).reduce(0, Integer::sum), + Stream.of("3").map(Integer::parseInt).reduce(0, Integer::sum), + Stream.of("4").map(parseIntFunction).reduce(0, Integer::sum)); } ImmutableSet testStreamMapToDoubleSum() { Function parseDoubleFunction = Double::parseDouble; return ImmutableSet.of( - Stream.of(1).map(i -> i * 2.0).reduce(0.0, Double::sum), - Stream.of("2").map(Double::parseDouble).reduce(0.0, Double::sum), - Stream.of("3").map(parseDoubleFunction).reduce(0.0, Double::sum)); + Stream.of("1").collect(summingDouble(Double::parseDouble)), + Stream.of(2).map(i -> i * 2.0).reduce(0.0, Double::sum), + Stream.of("3").map(Double::parseDouble).reduce(0.0, Double::sum), + Stream.of("4").map(parseDoubleFunction).reduce(0.0, Double::sum)); } ImmutableSet testStreamMapToLongSum() { Function parseLongFunction = Long::parseLong; return ImmutableSet.of( - Stream.of(1).map(i -> i * 2L).reduce(0L, Long::sum), - Stream.of("2").map(Long::parseLong).reduce(0L, Long::sum), - Stream.of("3").map(parseLongFunction).reduce(0L, Long::sum)); + Stream.of("1").collect(summingLong(Long::parseLong)), + Stream.of(2).map(i -> i * 2L).reduce(0L, Long::sum), + Stream.of("3").map(Long::parseLong).reduce(0L, Long::sum), + Stream.of("4").map(parseLongFunction).reduce(0L, Long::sum)); + } + + IntSummaryStatistics testStreamMapToIntSummaryStatistics() { + return Stream.of("1").collect(summarizingInt(Integer::parseInt)); + } + + DoubleSummaryStatistics testStreamMapToDoubleSummaryStatistics() { + return Stream.of("1").collect(summarizingDouble(Double::parseDouble)); + } + + LongSummaryStatistics testStreamMapToLongSummaryStatistics() { + return Stream.of("1").collect(summarizingLong(Long::parseLong)); + } + + Long testStreamCount() { + return Stream.of(1).collect(counting()); + } + + Optional testStreamReduce() { + return Stream.of(1).collect(reducing(Integer::sum)); + } + + Integer testStreamReduceWithIdentity() { + return Stream.of(1).collect(reducing(0, Integer::sum)); + } + + ImmutableSet testStreamFilterCollect() { + return Stream.of("1").collect(filtering(String::isEmpty, toImmutableSet())); + } + + ImmutableSet testStreamMapCollect() { + return Stream.of("1").collect(mapping(Integer::parseInt, toImmutableSet())); + } + + ImmutableSet testStreamFlatMapCollect() { + return Stream.of(1).collect(flatMapping(n -> Stream.of(n, n), toImmutableSet())); } } diff --git a/error-prone-contrib/src/test/resources/tech/picnic/errorprone/refasterrules/StreamRulesTestOutput.java b/error-prone-contrib/src/test/resources/tech/picnic/errorprone/refasterrules/StreamRulesTestOutput.java index 549df57117..fb7484f6b5 100644 --- a/error-prone-contrib/src/test/resources/tech/picnic/errorprone/refasterrules/StreamRulesTestOutput.java +++ b/error-prone-contrib/src/test/resources/tech/picnic/errorprone/refasterrules/StreamRulesTestOutput.java @@ -1,14 +1,31 @@ package tech.picnic.errorprone.refasterrules; +import static com.google.common.collect.ImmutableSet.toImmutableSet; import static java.util.Comparator.comparingInt; import static java.util.Comparator.naturalOrder; import static java.util.Comparator.reverseOrder; import static java.util.function.Predicate.not; +import static java.util.stream.Collectors.counting; +import static java.util.stream.Collectors.filtering; +import static java.util.stream.Collectors.flatMapping; import static java.util.stream.Collectors.joining; +import static java.util.stream.Collectors.mapping; +import static java.util.stream.Collectors.maxBy; +import static java.util.stream.Collectors.minBy; +import static java.util.stream.Collectors.reducing; +import static java.util.stream.Collectors.summarizingDouble; +import static java.util.stream.Collectors.summarizingInt; +import static java.util.stream.Collectors.summarizingLong; +import static java.util.stream.Collectors.summingDouble; +import static java.util.stream.Collectors.summingInt; +import static java.util.stream.Collectors.summingLong; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Streams; import java.util.Arrays; +import java.util.DoubleSummaryStatistics; +import java.util.IntSummaryStatistics; +import java.util.LongSummaryStatistics; import java.util.Objects; import java.util.Optional; import java.util.function.Function; @@ -19,7 +36,23 @@ final class StreamRulesTest implements RefasterRuleCollectionTestCase { @Override public ImmutableSet elidedTypesAndStaticImports() { - return ImmutableSet.of(Objects.class, Streams.class, not(null)); + return ImmutableSet.of( + Objects.class, + Streams.class, + counting(), + filtering(null, null), + flatMapping(null, null), + mapping(null, null), + maxBy(null), + minBy(null), + not(null), + reducing(null), + summarizingDouble(null), + summarizingInt(null), + summarizingLong(null), + summingDouble(null), + summingInt(null), + summingLong(null)); } String testJoining() { @@ -91,7 +124,8 @@ ImmutableSet testStreamIsNotEmpty() { ImmutableSet> testStreamMin() { return ImmutableSet.of( Stream.of("foo").min(comparingInt(String::length)), - Stream.of("bar").min(comparingInt(String::length))); + Stream.of("bar").min(comparingInt(String::length)), + Stream.of("baz").min(comparingInt(String::length))); } ImmutableSet> testStreamMinNaturalOrder() { @@ -102,7 +136,8 @@ ImmutableSet> testStreamMinNaturalOrder() { ImmutableSet> testStreamMax() { return ImmutableSet.of( Stream.of("foo").max(comparingInt(String::length)), - Stream.of("bar").max(comparingInt(String::length))); + Stream.of("bar").max(comparingInt(String::length)), + Stream.of("baz").max(comparingInt(String::length))); } ImmutableSet> testStreamMaxNaturalOrder() { @@ -142,24 +177,63 @@ boolean testStreamAllMatch2() { ImmutableSet testStreamMapToIntSum() { Function parseIntFunction = Integer::parseInt; return ImmutableSet.of( - Stream.of(1).mapToInt(i -> i * 2).sum(), - Stream.of("2").mapToInt(Integer::parseInt).sum(), - Stream.of("3").map(parseIntFunction).reduce(0, Integer::sum)); + Stream.of("1").mapToInt(Integer::parseInt).sum(), + Stream.of(2).mapToInt(i -> i * 2).sum(), + Stream.of("3").mapToInt(Integer::parseInt).sum(), + Stream.of("4").map(parseIntFunction).reduce(0, Integer::sum)); } ImmutableSet testStreamMapToDoubleSum() { Function parseDoubleFunction = Double::parseDouble; return ImmutableSet.of( - Stream.of(1).mapToDouble(i -> i * 2.0).sum(), - Stream.of("2").mapToDouble(Double::parseDouble).sum(), - Stream.of("3").map(parseDoubleFunction).reduce(0.0, Double::sum)); + Stream.of("1").mapToDouble(Double::parseDouble).sum(), + Stream.of(2).mapToDouble(i -> i * 2.0).sum(), + Stream.of("3").mapToDouble(Double::parseDouble).sum(), + Stream.of("4").map(parseDoubleFunction).reduce(0.0, Double::sum)); } ImmutableSet testStreamMapToLongSum() { Function parseLongFunction = Long::parseLong; return ImmutableSet.of( - Stream.of(1).mapToLong(i -> i * 2L).sum(), - Stream.of("2").mapToLong(Long::parseLong).sum(), - Stream.of("3").map(parseLongFunction).reduce(0L, Long::sum)); + Stream.of("1").mapToLong(Long::parseLong).sum(), + Stream.of(2).mapToLong(i -> i * 2L).sum(), + Stream.of("3").mapToLong(Long::parseLong).sum(), + Stream.of("4").map(parseLongFunction).reduce(0L, Long::sum)); + } + + IntSummaryStatistics testStreamMapToIntSummaryStatistics() { + return Stream.of("1").mapToInt(Integer::parseInt).summaryStatistics(); + } + + DoubleSummaryStatistics testStreamMapToDoubleSummaryStatistics() { + return Stream.of("1").mapToDouble(Double::parseDouble).summaryStatistics(); + } + + LongSummaryStatistics testStreamMapToLongSummaryStatistics() { + return Stream.of("1").mapToLong(Long::parseLong).summaryStatistics(); + } + + Long testStreamCount() { + return Stream.of(1).count(); + } + + Optional testStreamReduce() { + return Stream.of(1).reduce(Integer::sum); + } + + Integer testStreamReduceWithIdentity() { + return Stream.of(1).reduce(0, Integer::sum); + } + + ImmutableSet testStreamFilterCollect() { + return Stream.of("1").filter(String::isEmpty).collect(toImmutableSet()); + } + + ImmutableSet testStreamMapCollect() { + return Stream.of("1").map(Integer::parseInt).collect(toImmutableSet()); + } + + ImmutableSet testStreamFlatMapCollect() { + return Stream.of(1).flatMap(n -> Stream.of(n, n)).collect(toImmutableSet()); } }