From 2f696a96f30bb361d3a6e92d9e024a74a1993a65 Mon Sep 17 00:00:00 2001 From: "Robert (Bobby) Evans" Date: Fri, 30 Jul 2021 13:19:04 -0500 Subject: [PATCH 1/7] Refactor rolling aggregations to make it cleaner --- .../main/java/ai/rapids/cudf/Aggregation.java | 71 ++----- .../ai/rapids/cudf/AggregationOnColumn.java | 36 +--- .../ai/rapids/cudf/AggregationOverWindow.java | 43 ++-- .../main/java/ai/rapids/cudf/ColumnView.java | 5 +- .../ai/rapids/cudf/RollingAggregation.java | 185 +++++++++++++++++- .../cudf/RollingAggregationOnColumn.java | 65 ++++++ java/src/main/java/ai/rapids/cudf/Table.java | 4 +- .../java/ai/rapids/cudf/ColumnVectorTest.java | 26 +-- .../test/java/ai/rapids/cudf/TableTest.java | 182 ++++++++--------- 9 files changed, 400 insertions(+), 217 deletions(-) create mode 100644 java/src/main/java/ai/rapids/cudf/RollingAggregationOnColumn.java diff --git a/java/src/main/java/ai/rapids/cudf/Aggregation.java b/java/src/main/java/ai/rapids/cudf/Aggregation.java index 49c6d2b6ffc..62ca27c732e 100644 --- a/java/src/main/java/ai/rapids/cudf/Aggregation.java +++ b/java/src/main/java/ai/rapids/cudf/Aggregation.java @@ -65,7 +65,7 @@ enum Kind { M2(26), MERGE_M2(27), RANK(28), - DENSE_RANK(29);; + DENSE_RANK(29); final int nativeId; @@ -275,8 +275,7 @@ long getDefaultOutput() { } } - public static final class CollectListAggregation extends Aggregation - implements RollingAggregation { + public static final class CollectListAggregation extends Aggregation { private final NullPolicy nullPolicy; private CollectListAggregation(NullPolicy nullPolicy) { @@ -306,8 +305,7 @@ public boolean equals(Object other) { } } - public static final class CollectSetAggregation extends Aggregation - implements RollingAggregation { + public static final class CollectSetAggregation extends Aggregation { private final NullPolicy nullPolicy; private final NullEquality nullEquality; private final NaNEquality nanEquality; @@ -392,8 +390,8 @@ protected Aggregation(Kind kind) { * Add a 
column to the Aggregation so it can be used on a specific column of data. * @param columnIndex the index of the column to operate on. */ - public AggregationOnColumn onColumn(int columnIndex) { - return new AggregationOnColumn((T)this, columnIndex); + public AggregationOnColumn onColumn(int columnIndex) { + return new AggregationOnColumn(this, columnIndex); } /** @@ -433,8 +431,7 @@ static void close(long[] ptrs) { static native void close(long ptr); - public static class SumAggregation extends NoParamAggregation - implements RollingAggregation { + public static class SumAggregation extends NoParamAggregation { private SumAggregation() { super(Kind.SUM); } @@ -460,8 +457,7 @@ public static ProductAggregation product() { return new ProductAggregation(); } - public static class MinAggregation extends NoParamAggregation - implements RollingAggregation { + public static class MinAggregation extends NoParamAggregation { private MinAggregation() { super(Kind.MIN); } @@ -474,8 +470,7 @@ public static MinAggregation min() { return new MinAggregation(); } - public static class MaxAggregation extends NoParamAggregation - implements RollingAggregation { + public static class MaxAggregation extends NoParamAggregation { private MaxAggregation() { super(Kind.MAX); } @@ -488,8 +483,7 @@ public static MaxAggregation max() { return new MaxAggregation(); } - public static class CountAggregation extends CountLikeAggregation - implements RollingAggregation { + public static class CountAggregation extends CountLikeAggregation { private CountAggregation(NullPolicy nullPolicy) { super(Kind.COUNT, nullPolicy); } @@ -555,8 +549,7 @@ public static SumOfSquaresAggregation sumOfSquares() { return new SumOfSquaresAggregation(); } - public static class MeanAggregation extends NoParamAggregation - implements RollingAggregation{ + public static class MeanAggregation extends NoParamAggregation { private MeanAggregation() { super(Kind.MEAN); } @@ -654,8 +647,7 @@ public static QuantileAggregation 
quantile(QuantileMethod method, double ... qua return new QuantileAggregation(method, quantiles); } - public static class ArgMaxAggregation extends NoParamAggregation - implements RollingAggregation{ + public static class ArgMaxAggregation extends NoParamAggregation { private ArgMaxAggregation() { super(Kind.ARGMAX); } @@ -671,8 +663,7 @@ public static ArgMaxAggregation argMax() { return new ArgMaxAggregation(); } - public static class ArgMinAggregation extends NoParamAggregation - implements RollingAggregation{ + public static class ArgMinAggregation extends NoParamAggregation { private ArgMinAggregation() { super(Kind.ARGMIN); } @@ -731,8 +722,7 @@ public static NthAggregation nth(int offset, NullPolicy nullPolicy) { return new NthAggregation(offset, nullPolicy); } - public static class RowNumberAggregation extends NoParamAggregation - implements RollingAggregation{ + static class RowNumberAggregation extends NoParamAggregation { private RowNumberAggregation() { super(Kind.ROW_NUMBER); } @@ -741,12 +731,11 @@ private RowNumberAggregation() { /** * Get the row number, only makes sense for a window operations. 
*/ - public static RowNumberAggregation rowNumber() { + static RowNumberAggregation rowNumber() { return new RowNumberAggregation(); } - public static class RankAggregation extends NoParamAggregation - implements RollingAggregation{ + public static class RankAggregation extends NoParamAggregation { private RankAggregation() { super(Kind.RANK); } @@ -759,8 +748,7 @@ public static RankAggregation rank() { return new RankAggregation(); } - public static class DenseRankAggregation extends NoParamAggregation - implements RollingAggregation{ + public static class DenseRankAggregation extends NoParamAggregation { private DenseRankAggregation() { super(Kind.DENSE_RANK); } @@ -840,54 +828,35 @@ public static MergeSetsAggregation mergeSets(NullEquality nullEquality, NaNEqual return new MergeSetsAggregation(nullEquality, nanEquality); } - public static class LeadAggregation extends LeadLagAggregation - implements RollingAggregation { + static class LeadAggregation extends LeadLagAggregation { private LeadAggregation(int offset, ColumnVector defaultOutput) { super(Kind.LEAD, offset, defaultOutput); } } - /** - * In a rolling window return the value offset entries ahead or null if it is outside of the - * window. - */ - public static LeadAggregation lead(int offset) { - return lead(offset, null); - } - /** * In a rolling window return the value offset entries ahead or the corresponding value from * defaultOutput if it is outside of the window. Note that this does not take any ownership of * defaultOutput and the caller mush ensure that defaultOutput remains valid during the life * time of this aggregation operation. 
*/ - public static LeadAggregation lead(int offset, ColumnVector defaultOutput) { + static LeadAggregation lead(int offset, ColumnVector defaultOutput) { return new LeadAggregation(offset, defaultOutput); } - public static class LagAggregation extends LeadLagAggregation - implements RollingAggregation{ + static class LagAggregation extends LeadLagAggregation { private LagAggregation(int offset, ColumnVector defaultOutput) { super(Kind.LAG, offset, defaultOutput); } } - - /** - * In a rolling window return the value offset entries behind or null if it is outside of the - * window. - */ - public static LagAggregation lag(int offset) { - return lag(offset, null); - } - /** * In a rolling window return the value offset entries behind or the corresponding value from * defaultOutput if it is outside of the window. Note that this does not take any ownership of * defaultOutput and the caller mush ensure that defaultOutput remains valid during the life * time of this aggregation operation. */ - public static LagAggregation lag(int offset, ColumnVector defaultOutput) { + static LagAggregation lag(int offset, ColumnVector defaultOutput) { return new LagAggregation(offset, defaultOutput); } diff --git a/java/src/main/java/ai/rapids/cudf/AggregationOnColumn.java b/java/src/main/java/ai/rapids/cudf/AggregationOnColumn.java index bb1404e5a07..2d9364b9705 100644 --- a/java/src/main/java/ai/rapids/cudf/AggregationOnColumn.java +++ b/java/src/main/java/ai/rapids/cudf/AggregationOnColumn.java @@ -19,47 +19,23 @@ package ai.rapids.cudf; /** - * An Aggregation instance that also holds a column number so the aggregation can be done on - * a specific column of data in a table. + * An Aggregation for a specific column in a table. 
*/ -public class AggregationOnColumn extends Aggregation { - protected final T wrapped; +public class AggregationOnColumn { + protected final Aggregation wrapped; protected final int columnIndex; - AggregationOnColumn(T wrapped, int columnIndex) { - super(wrapped.kind); + AggregationOnColumn(Aggregation wrapped, int columnIndex) { this.wrapped = wrapped; this.columnIndex = columnIndex; } - @Override - public AggregationOnColumn onColumn(int columnIndex) { - if (columnIndex == getColumnIndex()) { - return this; // NOOP - } else { - return new AggregationOnColumn(this.wrapped, columnIndex); - } - } - - /** - * Do the aggregation over a given Window. - */ - public > AggregationOverWindow overWindow(WindowOptions windowOptions) { - return new AggregationOverWindow(wrapped, columnIndex, windowOptions); - } - public int getColumnIndex() { return columnIndex; } - @Override - long createNativeInstance() { - return wrapped.createNativeInstance(); - } - - @Override - long getDefaultOutput() { - return wrapped.getDefaultOutput(); + Aggregation getWrapped() { + return wrapped; } @Override diff --git a/java/src/main/java/ai/rapids/cudf/AggregationOverWindow.java b/java/src/main/java/ai/rapids/cudf/AggregationOverWindow.java index abce287c9b0..9a82eae65bf 100644 --- a/java/src/main/java/ai/rapids/cudf/AggregationOverWindow.java +++ b/java/src/main/java/ai/rapids/cudf/AggregationOverWindow.java @@ -22,12 +22,12 @@ * An Aggregation instance that also holds a column number and window metadata so the aggregation * can be done over a specific window. 
*/ -public class AggregationOverWindow> - extends AggregationOnColumn { +public class AggregationOverWindow { + private final RollingAggregationOnColumn wrapped; protected final WindowOptions windowOptions; - AggregationOverWindow(T wrapped, int columnIndex, WindowOptions windowOptions) { - super(wrapped, columnIndex); + AggregationOverWindow(RollingAggregationOnColumn wrapped, WindowOptions windowOptions) { + this.wrapped = wrapped; this.windowOptions = windowOptions; if (windowOptions == null) { @@ -43,23 +43,6 @@ public WindowOptions getWindowOptions() { return windowOptions; } - @Override - public AggregationOnColumn onColumn(int columnIndex) { - if (columnIndex == getColumnIndex()) { - return this; // NOOP - } else { - return new AggregationOverWindow(this.wrapped, columnIndex, windowOptions); - } - } - - @Override - public AggregationOverWindow overWindow(WindowOptions windowOptions) { - if (this.windowOptions.equals(windowOptions)) { - return this; - } - return new AggregationOverWindow(wrapped, columnIndex, windowOptions); - } - @Override public int hashCode() { return 31 * super.hashCode() + windowOptions.hashCode(); @@ -69,10 +52,22 @@ public int hashCode() { public boolean equals(Object other) { if (other == this) { return true; - } else if (other instanceof AggregationOnColumn) { - AggregationOnColumn o = (AggregationOnColumn) other; - return wrapped.equals(o.wrapped) && columnIndex == o.columnIndex; + } else if (other instanceof AggregationOverWindow) { + AggregationOverWindow o = (AggregationOverWindow) other; + return wrapped.equals(o.wrapped) && windowOptions.equals(o.windowOptions); } return false; } + + int getColumnIndex() { + return wrapped.getColumnIndex(); + } + + long createNativeInstance() { + return wrapped.createNativeInstance(); + } + + long getDefaultOutput() { + return wrapped.getDefaultOutput(); + } } diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index 
4a1ed3a178e..1e2790d42c5 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -1390,20 +1390,19 @@ public final ColumnVector quantile(QuantileMethod method, double[] quantiles) { * @throws IllegalArgumentException if unsupported window specification * (i.e. other than {@link WindowOptions.FrameType#ROWS} is used. */ public final ColumnVector rollingWindow(RollingAggregation op, WindowOptions options) { - Aggregation agg = op.getBaseAggregation(); // Check that only row-based windows are used. if (!options.getFrameType().equals(WindowOptions.FrameType.ROWS)) { throw new IllegalArgumentException("Expected ROWS-based window specification. Unexpected window type: " + options.getFrameType()); } - long nativePtr = agg.createNativeInstance(); + long nativePtr = op.createNativeInstance(); try { Scalar p = options.getPrecedingScalar(); Scalar f = options.getFollowingScalar(); return new ColumnVector( rollingWindow(this.getNativeView(), - agg.getDefaultOutput(), + op.getDefaultOutput(), options.getMinPeriods(), nativePtr, p == null || !p.isValid() ? 0 : p.getInt(), diff --git a/java/src/main/java/ai/rapids/cudf/RollingAggregation.java b/java/src/main/java/ai/rapids/cudf/RollingAggregation.java index 9b80924463a..d9f026d79c5 100644 --- a/java/src/main/java/ai/rapids/cudf/RollingAggregation.java +++ b/java/src/main/java/ai/rapids/cudf/RollingAggregation.java @@ -22,8 +22,187 @@ * Used to tag an aggregation as something that is compatible with rolling window operations. 
* Do not try to implement this yourself */ -public interface RollingAggregation { - default T getBaseAggregation() { - return (T)this; +public class RollingAggregation { + private final Aggregation wrapped; + + private RollingAggregation(Aggregation wrapped) { + this.wrapped = wrapped; + } + + long createNativeInstance() { + return wrapped.createNativeInstance(); + } + + long getDefaultOutput() { + return wrapped.getDefaultOutput(); + } + + /** + * Add a column to the Aggregation so it can be used on a specific column of data. + * @param columnIndex the index of the column to operate on. + */ + public RollingAggregationOnColumn onColumn(int columnIndex) { + return new RollingAggregationOnColumn(this, columnIndex); + } + + @Override + public int hashCode() { + return wrapped.hashCode(); + } + + @Override + public boolean equals(Object other) { + if (other == this) { + return true; + } else if (other instanceof RollingAggregation) { + RollingAggregation o = (RollingAggregation) other; + return wrapped.equals(o.wrapped); + } + return false; + } + + /** + * Rolling Window Sum + */ + public static RollingAggregation sum() { + return new RollingAggregation(Aggregation.sum()); + } + + + /** + * Rolling Window Min + */ + public static RollingAggregation min() { + return new RollingAggregation(Aggregation.min()); + } + + /** + * Rolling Window Max + */ + public static RollingAggregation max() { + return new RollingAggregation(Aggregation.max()); + } + + + /** + * Count number of valid, a.k.a. non-null, elements. + */ + public static RollingAggregation count() { + return count(NullPolicy.EXCLUDE); + } + + /** + * Count number of elements. + * @param nullPolicy INCLUDE if nulls should be counted. EXCLUDE if only non-null values + * should be counted. 
+ */ + public static RollingAggregation count(NullPolicy nullPolicy) { + return new RollingAggregation(Aggregation.count(nullPolicy)); + } + + /** + * Arithmetic Mean + */ + public static RollingAggregation mean() { + return new RollingAggregation(Aggregation.mean()); + } + + + /** + * Index of max element. + */ + public static RollingAggregation argMax() { + return new RollingAggregation(Aggregation.argMax()); + } + + /** + * Index of min element. + */ + public static RollingAggregation argMin() { + return new RollingAggregation(Aggregation.argMin()); + } + + + /** + * Get the row number. + */ + public static RollingAggregation rowNumber() { + return new RollingAggregation(Aggregation.rowNumber()); + } + + + /** + * In a rolling window return the value offset entries ahead or null if it is outside of the + * window. + */ + public static RollingAggregation lead(int offset) { + return lead(offset, null); + } + + /** + * In a rolling window return the value offset entries ahead or the corresponding value from + * defaultOutput if it is outside of the window. Note that this does not take any ownership of + * defaultOutput and the caller mush ensure that defaultOutput remains valid during the life + * time of this aggregation operation. + */ + public static RollingAggregation lead(int offset, ColumnVector defaultOutput) { + return new RollingAggregation(Aggregation.lead(offset, defaultOutput)); + } + + + + /** + * In a rolling window return the value offset entries behind or null if it is outside of the + * window. + */ + public static RollingAggregation lag(int offset) { + return lag(offset, null); + } + + /** + * In a rolling window return the value offset entries behind or the corresponding value from + * defaultOutput if it is outside of the window. Note that this does not take any ownership of + * defaultOutput and the caller mush ensure that defaultOutput remains valid during the life + * time of this aggregation operation. 
+ */ + public static RollingAggregation lag(int offset, ColumnVector defaultOutput) { + return new RollingAggregation(Aggregation.lag(offset, defaultOutput)); + } + + + /** + * Collect the values into a list. Nulls will be skipped. + */ + public static RollingAggregation collectList() { + return new RollingAggregation(Aggregation.collectList()); + } + + /** + * Collect the values into a list. + * + * @param nullPolicy Indicates whether to include/exclude nulls during collection. + */ + public static RollingAggregation collectList(NullPolicy nullPolicy) { + return new RollingAggregation(Aggregation.collectList(nullPolicy)); + } + + + /** + * Collect the values into a set. All null values will be excluded, and all nan values are regarded as + * unique instances. + */ + public static RollingAggregation collectSet() { + return new RollingAggregation(Aggregation.collectSet()); + } + + /** + * Collect the values into a set. + * + * @param nullPolicy Indicates whether to include/exclude nulls during collection. + * @param nullEquality Flag to specify whether null entries within each list should be considered equal. + * @param nanEquality Flag to specify whether NaN values in floating point column should be considered equal. + */ + public static RollingAggregation collectSet(NullPolicy nullPolicy, NullEquality nullEquality, NaNEquality nanEquality) { + return new RollingAggregation(Aggregation.collectSet(nullPolicy, nullEquality, nanEquality)); } } diff --git a/java/src/main/java/ai/rapids/cudf/RollingAggregationOnColumn.java b/java/src/main/java/ai/rapids/cudf/RollingAggregationOnColumn.java new file mode 100644 index 00000000000..7fde7c30b3f --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/RollingAggregationOnColumn.java @@ -0,0 +1,65 @@ +/* + * + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package ai.rapids.cudf; + +/** + * A RollingAggregation for a specific column in a table. + */ +public class RollingAggregationOnColumn { + protected final RollingAggregation wrapped; + protected final int columnIndex; + + RollingAggregationOnColumn(RollingAggregation wrapped, int columnIndex) { + this.wrapped = wrapped; + this.columnIndex = columnIndex; + } + + public int getColumnIndex() { + return columnIndex; + } + + + public AggregationOverWindow overWindow(WindowOptions windowOptions) { + return new AggregationOverWindow(this, windowOptions); + } + + @Override + public int hashCode() { + return 31 * wrapped.hashCode() + columnIndex; + } + + @Override + public boolean equals(Object other) { + if (other == this) { + return true; + } else if (other instanceof RollingAggregationOnColumn) { + RollingAggregationOnColumn o = (RollingAggregationOnColumn) other; + return wrapped.equals(o.wrapped) && columnIndex == o.columnIndex; + } + return false; + } + + long createNativeInstance() { + return wrapped.createNativeInstance(); + } + + long getDefaultOutput() { + return wrapped.getDefaultOutput(); + } +} diff --git a/java/src/main/java/ai/rapids/cudf/Table.java b/java/src/main/java/ai/rapids/cudf/Table.java index 96a9b608f06..887c22e388a 100644 --- a/java/src/main/java/ai/rapids/cudf/Table.java +++ b/java/src/main/java/ai/rapids/cudf/Table.java @@ -2471,7 +2471,7 @@ public Table aggregate(AggregationOnColumn... 
aggregates) { for (int outputIndex = 0; outputIndex < aggregates.length; outputIndex++) { AggregationOnColumn agg = aggregates[outputIndex]; ColumnOps ops = groupedOps.computeIfAbsent(agg.getColumnIndex(), (idx) -> new ColumnOps()); - totalOps += ops.add(agg, outputIndex + keysLength); + totalOps += ops.add(agg.getWrapped(), outputIndex + keysLength); } int[] aggColumnIndexes = new int[totalOps]; long[] aggOperationInstances = new long[totalOps]; @@ -2823,7 +2823,7 @@ public Table scan(AggregationOnColumn... aggregates) { for (int outputIndex = 0; outputIndex < aggregates.length; outputIndex++) { AggregationOnColumn agg = aggregates[outputIndex]; ColumnOps ops = groupedOps.computeIfAbsent(agg.getColumnIndex(), (idx) -> new ColumnOps()); - totalOps += ops.add(agg, outputIndex + keysLength); + totalOps += ops.add(agg.getWrapped(), outputIndex + keysLength); } int[] aggColumnIndexes = new int[totalOps]; long[] aggOperationInstances = new long[totalOps]; diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index d3fdb0e19bb..59efa14b3f9 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -3058,39 +3058,39 @@ void testWindowStatic() { .minPeriods(2).build()) { try (ColumnVector v1 = ColumnVector.fromInts(5, 4, 7, 6, 8)) { try (ColumnVector expected = ColumnVector.fromLongs(9, 16, 17, 21, 14); - ColumnVector result = v1.rollingWindow(Aggregation.sum(), options)) { + ColumnVector result = v1.rollingWindow(RollingAggregation.sum(), options)) { assertColumnsAreEqual(expected, result); } try (ColumnVector expected = ColumnVector.fromInts(4, 4, 4, 6, 6); - ColumnVector result = v1.rollingWindow(Aggregation.min(), options)) { + ColumnVector result = v1.rollingWindow(RollingAggregation.min(), options)) { assertColumnsAreEqual(expected, result); } try (ColumnVector expected = ColumnVector.fromInts(5, 7, 7, 8, 8); - 
ColumnVector result = v1.rollingWindow(Aggregation.max(), options)) { + ColumnVector result = v1.rollingWindow(RollingAggregation.max(), options)) { assertColumnsAreEqual(expected, result); } // The rolling window produces the same result type as the input try (ColumnVector expected = ColumnVector.fromDoubles(4.5, 16.0 / 3, 17.0 / 3, 7, 7); - ColumnVector result = v1.rollingWindow(Aggregation.mean(), options)) { + ColumnVector result = v1.rollingWindow(RollingAggregation.mean(), options)) { assertColumnsAreEqual(expected, result); } try (ColumnVector expected = ColumnVector.fromBoxedInts(4, 7, 6, 8, null); - ColumnVector result = v1.rollingWindow(Aggregation.lead(1), options)) { + ColumnVector result = v1.rollingWindow(RollingAggregation.lead(1), options)) { assertColumnsAreEqual(expected, result); } try (ColumnVector expected = ColumnVector.fromBoxedInts(null, 5, 4, 7, 6); - ColumnVector result = v1.rollingWindow(Aggregation.lag(1), options)) { + ColumnVector result = v1.rollingWindow(RollingAggregation.lag(1), options)) { assertColumnsAreEqual(expected, result); } try (ColumnVector defaultOutput = ColumnVector.fromInts(-1, -2, -3, -4, -5); ColumnVector expected = ColumnVector.fromBoxedInts(-1, 5, 4, 7, 6); - ColumnVector result = v1.rollingWindow(Aggregation.lag(1, defaultOutput), options)) { + ColumnVector result = v1.rollingWindow(RollingAggregation.lag(1, defaultOutput), options)) { assertColumnsAreEqual(expected, result); } } @@ -3106,11 +3106,11 @@ void testWindowStaticCounts() { .minPeriods(2).build()) { try (ColumnVector v1 = ColumnVector.fromBoxedInts(5, 4, null, 6, 8)) { try (ColumnVector expected = ColumnVector.fromInts(2, 2, 2, 2, 2); - ColumnVector result = v1.rollingWindow(Aggregation.count(NullPolicy.EXCLUDE), options)) { + ColumnVector result = v1.rollingWindow(RollingAggregation.count(NullPolicy.EXCLUDE), options)) { assertColumnsAreEqual(expected, result); } try (ColumnVector expected = ColumnVector.fromInts(2, 3, 3, 3, 2); - ColumnVector result 
= v1.rollingWindow(Aggregation.count(NullPolicy.INCLUDE), options)) { + ColumnVector result = v1.rollingWindow(RollingAggregation.count(NullPolicy.INCLUDE), options)) { assertColumnsAreEqual(expected, result); } } @@ -3125,7 +3125,7 @@ void testWindowDynamicNegative() { .minPeriods(2).window(precedingCol, followingCol).build()) { try (ColumnVector v1 = ColumnVector.fromInts(5, 4, 7, 6, 8); ColumnVector expected = ColumnVector.fromBoxedLongs(null, null, 9L, 16L, 25L); - ColumnVector result = v1.rollingWindow(Aggregation.sum(), window)) { + ColumnVector result = v1.rollingWindow(RollingAggregation.sum(), window)) { assertColumnsAreEqual(expected, result); } } @@ -3141,7 +3141,7 @@ void testWindowLag() { .window(two, negOne).build()) { try (ColumnVector v1 = ColumnVector.fromInts(5, 4, 7, 6, 8); ColumnVector expected = ColumnVector.fromBoxedInts(null, 5, 4, 7, 6); - ColumnVector result = v1.rollingWindow(Aggregation.max(), window)) { + ColumnVector result = v1.rollingWindow(RollingAggregation.max(), window)) { assertColumnsAreEqual(expected, result); } } @@ -3155,7 +3155,7 @@ void testWindowDynamic() { .window(precedingCol, followingCol).build()) { try (ColumnVector v1 = ColumnVector.fromInts(5, 4, 7, 6, 8); ColumnVector expected = ColumnVector.fromLongs(16, 22, 30, 14, 14); - ColumnVector result = v1.rollingWindow(Aggregation.sum(), window)) { + ColumnVector result = v1.rollingWindow(RollingAggregation.sum(), window)) { assertColumnsAreEqual(expected, result); } } @@ -3181,7 +3181,7 @@ void testWindowThrowsException() { .minPeriods(1) .orderByColumnIndex(0) .build()) { - arraywindowCol.rollingWindow(Aggregation.sum(), options); + arraywindowCol.rollingWindow(RollingAggregation.sum(), options); } }); } diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index 360f3c04f5b..3c18c7f3fcd 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ 
-3048,9 +3048,9 @@ void testWindowingCount() { .build()) { try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation.count().onColumn(3).overWindow(window)); + .aggregateWindows(RollingAggregation.count().onColumn(3).overWindow(window)); Table decWindowAggResults = decSorted.groupBy(0, 4) - .aggregateWindows(Aggregation.count().onColumn(3).overWindow(window)); + .aggregateWindows(RollingAggregation.count().onColumn(3).overWindow(window)); ColumnVector expect = ColumnVector.fromBoxedInts(2, 3, 3, 2, 2, 3, 3, 2, 2, 3, 3, 2)) { assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); assertColumnsAreEqual(expect, decWindowAggResults.getColumn(0)); @@ -3088,9 +3088,9 @@ void testWindowingMin() { .build()) { try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation.min().onColumn(3).overWindow(window)); + .aggregateWindows(RollingAggregation.min().onColumn(3).overWindow(window)); Table decWindowAggResults = decSorted.groupBy(0, 4) - .aggregateWindows(Aggregation.min().onColumn(6).overWindow(window)); + .aggregateWindows(RollingAggregation.min().onColumn(6).overWindow(window)); ColumnVector expect = ColumnVector.fromBoxedInts(5, 1, 1, 1, 7, 7, 2, 2, 0, 0, 0, 6); ColumnVector decExpect = ColumnVector.decimalFromLongs(2, 5, 1, 1, 1, 7, 7, 2, 2, 0, 0, 0, 6)) { assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); @@ -3129,9 +3129,9 @@ void testWindowingMax() { .build()) { try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation.max().onColumn(3).overWindow(window)); + .aggregateWindows(RollingAggregation.max().onColumn(3).overWindow(window)); Table decWindowAggResults = decSorted.groupBy(0, 4) - .aggregateWindows(Aggregation.max().onColumn(6).overWindow(window)); + .aggregateWindows(RollingAggregation.max().onColumn(6).overWindow(window)); ColumnVector expect = ColumnVector.fromBoxedInts(7, 7, 9, 9, 9, 9, 9, 8, 8, 8, 6, 6); ColumnVector decExpect = ColumnVector.decimalFromLongs(2, 7, 7, 
9, 9, 9, 9, 9, 8, 8, 8, 6, 6)) { assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); @@ -3163,7 +3163,7 @@ void testWindowingSum() { .build()) { try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation.sum().onColumn(3).overWindow(window)); + .aggregateWindows(RollingAggregation.sum().onColumn(3).overWindow(window)); ColumnVector expectAggResult = ColumnVector.fromBoxedLongs(12L, 13L, 15L, 10L, 16L, 24L, 19L, 10L, 8L, 14L, 12L, 12L)) { assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); } @@ -3199,12 +3199,12 @@ void testWindowingRowNumber() { WindowOptions options = windowBuilder.window(two, one).build(); WindowOptions options1 = windowBuilder.window(two, one).build()) { try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation + .aggregateWindows(RollingAggregation .rowNumber() .onColumn(3) .overWindow(options)); Table decWindowAggResults = decSorted.groupBy(0, 4) - .aggregateWindows(Aggregation + .aggregateWindows(RollingAggregation .rowNumber() .onColumn(6) .overWindow(options1)); @@ -3219,12 +3219,12 @@ void testWindowingRowNumber() { WindowOptions options = windowBuilder.window(three, two).build(); WindowOptions options1 = windowBuilder.window(three, two).build()) { try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation + .aggregateWindows(RollingAggregation .rowNumber() .onColumn(3) .overWindow(options)); Table decWindowAggResults = decSorted.groupBy(0, 4) - .aggregateWindows(Aggregation + .aggregateWindows(RollingAggregation .rowNumber() .onColumn(6) .overWindow(options1)); @@ -3239,12 +3239,12 @@ void testWindowingRowNumber() { WindowOptions options = windowBuilder.window(four, three).build(); WindowOptions options1 = windowBuilder.window(four, three).build()) { try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation + .aggregateWindows(RollingAggregation .rowNumber() .onColumn(3) .overWindow(options)); Table 
decWindowAggResults = decSorted.groupBy(0, 4) - .aggregateWindows(Aggregation + .aggregateWindows(RollingAggregation .rowNumber() .onColumn(6) .overWindow(options1)); @@ -3259,8 +3259,8 @@ void testWindowingRowNumber() { @Test void testWindowingCollectList() { - Aggregation aggCollectWithNulls = Aggregation.collectList(NullPolicy.INCLUDE); - Aggregation aggCollect = Aggregation.collectList(); + RollingAggregation aggCollectWithNulls = RollingAggregation.collectList(NullPolicy.INCLUDE); + RollingAggregation aggCollect = RollingAggregation.collectList(); try (Scalar two = Scalar.fromInt(2); Scalar one = Scalar.fromInt(1); WindowOptions winOpts = WindowOptions.builder() @@ -3335,12 +3335,12 @@ void testWindowingCollectList() { @Test void testWindowingCollectSet() { - Aggregation aggCollect = Aggregation.collectSet(); - Aggregation aggCollectWithEqNulls = Aggregation.collectSet(NullPolicy.INCLUDE, + RollingAggregation aggCollect = RollingAggregation.collectSet(); + RollingAggregation aggCollectWithEqNulls = RollingAggregation.collectSet(NullPolicy.INCLUDE, NullEquality.EQUAL, NaNEquality.UNEQUAL); - Aggregation aggCollectWithUnEqNulls = Aggregation.collectSet(NullPolicy.INCLUDE, + RollingAggregation aggCollectWithUnEqNulls = RollingAggregation.collectSet(NullPolicy.INCLUDE, NullEquality.UNEQUAL, NaNEquality.UNEQUAL); - Aggregation aggCollectWithEqNaNs = Aggregation.collectSet(NullPolicy.INCLUDE, + RollingAggregation aggCollectWithEqNaNs = RollingAggregation.collectSet(NullPolicy.INCLUDE, NullEquality.EQUAL, NaNEquality.ALL_EQUAL); try (Scalar two = Scalar.fromInt(2); @@ -3473,22 +3473,22 @@ void testWindowingLead() { Scalar one = Scalar.fromInt(1); WindowOptions options = windowBuilder.window(two, one).build(); Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation + .aggregateWindows(RollingAggregation .lead(0) .onColumn(3) // Int Agg Column .overWindow(options)); Table decWindowAggResults = decSorted.groupBy(0, 4) - 
.aggregateWindows(Aggregation + .aggregateWindows(RollingAggregation .lead(0) .onColumn(6) // Decimal Agg Column .overWindow(options)); Table listWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( - Aggregation + RollingAggregation .lead(0) .onColumn(7) // List Agg COLUMN .overWindow(options)); Table structWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( - Aggregation + RollingAggregation .lead(0) .onColumn(8) //STRUCT Agg COLUMN .overWindow(options)); @@ -3517,22 +3517,22 @@ void testWindowingLead() { Scalar one = Scalar.fromInt(1); WindowOptions options = windowBuilder.window(zero, one).build(); Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation + .aggregateWindows(RollingAggregation .lead(1) .onColumn(3) //Int Agg COLUMN .overWindow(options)); Table decWindowAggResults = sorted.groupBy(0, 4) - .aggregateWindows(Aggregation + .aggregateWindows(RollingAggregation .lead(1) .onColumn(6) //Decimal Agg COLUMN .overWindow(options)); Table listWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( - Aggregation + RollingAggregation .lead(1) .onColumn(7) //LIST Agg COLUMN .overWindow(options)); Table structWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( - Aggregation + RollingAggregation .lead(1) .onColumn(8) //STRUCT Agg COLUMN .overWindow(options)); @@ -3575,22 +3575,22 @@ null, new StructData(13, "s13"), new StructData(14, "s14"), null, new StructData(-111, "s111"), new StructData(null, "s112"), new StructData(-222, "s222"), new StructData(-333, "s333")); Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation + .aggregateWindows(RollingAggregation .lead(1, defaultOutput) .onColumn(3) //Int Agg COLUMN .overWindow(options)); Table decWindowAggResults = sorted.groupBy(0, 4) - .aggregateWindows(Aggregation + .aggregateWindows(RollingAggregation .lead(1, decDefaultOutput) .onColumn(6) //Decimal Agg COLUMN .overWindow(options)); Table listWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( 
- Aggregation + RollingAggregation .lead(1, listDefaultOutput) .onColumn(7) //LIST Agg COLUMN .overWindow(options)); Table structWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( - Aggregation + RollingAggregation .lead(1, structDefaultOutput) .onColumn(8) //STRUCT Agg COLUMN .overWindow(options)); @@ -3619,22 +3619,22 @@ null, new StructData(13, "s13"), new StructData(14, "s14"), new StructData(-14, Scalar one = Scalar.fromInt(1); WindowOptions options = windowBuilder.window(zero, one).build(); Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation + .aggregateWindows(RollingAggregation .lead(3) .onColumn(3) //Int Agg COLUMN .overWindow(options)); Table decWindowAggResults = sorted.groupBy(0, 4) - .aggregateWindows(Aggregation + .aggregateWindows(RollingAggregation .lead(3) .onColumn(6) //Decimal Agg COLUMN .overWindow(options)); Table listWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( - Aggregation + RollingAggregation .lead(3) .onColumn(7) //LIST Agg COLUMN .overWindow(options)); Table structWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( - Aggregation + RollingAggregation .lead(3) .onColumn(8) //STRUCT Agg COLUMN .overWindow(options)); @@ -3694,22 +3694,22 @@ void testWindowingLag() { Scalar one = Scalar.fromInt(1); WindowOptions options = windowBuilder.window(two, one).build(); Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation + .aggregateWindows(RollingAggregation .lag(0) .onColumn(3) //Int Agg COLUMN .overWindow(options)); Table decWindowAggResults = sorted.groupBy(0, 4) - .aggregateWindows(Aggregation + .aggregateWindows(RollingAggregation .lag(0) .onColumn(6) //Decimal Agg COLUMN .overWindow(options)); Table listWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( - Aggregation + RollingAggregation .lag(0) .onColumn(7) //LIST Agg COLUMN .overWindow(options)); Table structWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( - Aggregation + RollingAggregation .lag(0) 
.onColumn(8) //STRUCT Agg COLUMN .overWindow(options)); @@ -3737,22 +3737,22 @@ void testWindowingLag() { Scalar two = Scalar.fromInt(2); WindowOptions options = windowBuilder.window(two, zero).build(); Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation + .aggregateWindows(RollingAggregation .lag(1) .onColumn(3) //Int Agg COLUMN .overWindow(options)); Table decWindowAggResults = sorted.groupBy(0, 4) - .aggregateWindows(Aggregation + .aggregateWindows(RollingAggregation .lag(1) .onColumn(6) //Decimal Agg COLUMN .overWindow(options)); Table listWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( - Aggregation + RollingAggregation .lag(1) .onColumn(7) //LIST Agg COLUMN .overWindow(options)); Table structWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( - Aggregation + RollingAggregation .lag(1) .onColumn(8) //STRUCT Agg COLUMN .overWindow(options)); @@ -3794,22 +3794,22 @@ null, new StructData(111, "s111"), new StructData(null, "s112"), new StructData( new StructData(-11, "s11"), null, new StructData(-13, "s13"), new StructData(-14, "s14"), new StructData(-111, "s111"), new StructData(null, "s112"), new StructData(-222, "s222"), new StructData(-333, "s333")); Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation + .aggregateWindows(RollingAggregation .lag(1, defaultOutput) .onColumn(3) //Int Agg COLUMN .overWindow(options)); Table decWindowAggResults = sorted.groupBy(0, 4) - .aggregateWindows(Aggregation + .aggregateWindows(RollingAggregation .lag(1, decDefaultOutput) .onColumn(6) //Decimal Agg COLUMN .overWindow(options)); Table listWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( - Aggregation + RollingAggregation .lag(1, listDefaultOutput) .onColumn(7) //LIST Agg COLUMN .overWindow(options)); Table structWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( - Aggregation + RollingAggregation .lag(1, structDefaultOutput) .onColumn(8) //STRUCT Agg COLUMN .overWindow(options)); @@ -3838,22 
+3838,22 @@ null, new StructData(111, "s111"), new StructData(null, "s112"), new StructData( Scalar one = Scalar.fromInt(1); WindowOptions options = windowBuilder.window(one, zero).build(); Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation + .aggregateWindows(RollingAggregation .lag(3) .onColumn(3) //Int Agg COLUMN .overWindow(options)); Table decWindowAggResults = sorted.groupBy(0, 4) - .aggregateWindows(Aggregation + .aggregateWindows(RollingAggregation .lag(3) .onColumn(6) //Decimal Agg COLUMN .overWindow(options)); Table listWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( - Aggregation + RollingAggregation .lag(3) .onColumn(7) //LIST Agg COLUMN .overWindow(options)); Table structWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( - Aggregation + RollingAggregation .lag(3) .onColumn(8) //STRUCT Agg COLUMN .overWindow(options)); @@ -3896,7 +3896,7 @@ void testWindowingMean() { .build()) { try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation.mean().onColumn(3).overWindow(window)); + .aggregateWindows(RollingAggregation.mean().onColumn(3).overWindow(window)); ColumnVector expect = ColumnVector.fromBoxedDoubles(6.0d, 5.0d, 5.0d, 5.0d, 8.0d, 8.0d, 7.0d, 6.0d, 4.0d, 4.0d, 4.0d, 6.0d)) { assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); } @@ -3941,10 +3941,10 @@ void testWindowingOnMultipleDifferentColumns() { try (Table windowAggResults = sorted.groupBy(0, 1) .aggregateWindows( - Aggregation.sum().onColumn(3).overWindow(window_1), - Aggregation.max().onColumn(3).overWindow(window_1), - Aggregation.sum().onColumn(3).overWindow(window_2), - Aggregation.min().onColumn(2).overWindow(window_3) + RollingAggregation.sum().onColumn(3).overWindow(window_1), + RollingAggregation.max().onColumn(3).overWindow(window_1), + RollingAggregation.sum().onColumn(3).overWindow(window_2), + RollingAggregation.min().onColumn(2).overWindow(window_3) ); ColumnVector expect_0 = 
ColumnVector.fromBoxedLongs(12L, 13L, 15L, 10L, 16L, 24L, 19L, 10L, 8L, 14L, 12L, 12L); ColumnVector expect_1 = ColumnVector.fromBoxedInts(7, 7, 9, 9, 9, 9, 9, 8, 8, 8, 6, 6); @@ -3979,8 +3979,8 @@ void testWindowingWithoutGroupByColumns() { .build()) { try (Table windowAggResults = sorted.groupBy().aggregateWindows( - Aggregation.sum().onColumn(1).overWindow(window)); - ColumnVector expectAggResult = ColumnVector.fromBoxedLongs(12L, 13L, 15L, 17L, 25L, 24L, 19L, 18L, 10L, 14L, 12L, 12L); + RollingAggregation.sum().onColumn(1).overWindow(window)); + ColumnVector expectAggResult = ColumnVector.fromBoxedLongs(12L, 13L, 15L, 17L, 25L, 24L, 19L, 18L, 10L, 14L, 12L, 12L) ) { assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); } @@ -4054,7 +4054,7 @@ void testRangeWindowingCount() { .orderByColumnIndex(orderIndex) .build()) { try (Table windowAggResults = sorted.groupBy(0, 1).aggregateWindowsOverRanges( - Aggregation.count().onColumn(2).overWindow(window)); + RollingAggregation.count().onColumn(2).overWindow(window)); ColumnVector expect = ColumnVector.fromBoxedInts(3, 3, 4, 2, 4, 4, 4, 4, 4, 4, 5, 5, 3)) { assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); } @@ -4098,7 +4098,7 @@ void testRangeWindowingLead() { .build()) { try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindowsOverRanges(Aggregation.lead(1) + .aggregateWindowsOverRanges(RollingAggregation.lead(1) .onColumn(2) .overWindow(window)); ColumnVector expect = ColumnVector.fromBoxedInts(5, 1, 9, null, 9, 8, 2, null, 0, 6, 6, 8, null)) { @@ -4144,7 +4144,7 @@ void testRangeWindowingMax() { .build()) { try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindowsOverRanges(Aggregation.max().onColumn(2).overWindow(window)); + .aggregateWindowsOverRanges(RollingAggregation.max().onColumn(2).overWindow(window)); ColumnVector expect = ColumnVector.fromBoxedInts(7, 7, 9, 9, 9, 9, 9, 9, 8, 8, 8, 8, 8)) { assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); 
} @@ -4158,7 +4158,7 @@ void testRangeWindowingMax() { .build()) { try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation.max().onColumn(2).overWindow(window)); + .aggregateWindows(RollingAggregation.max().onColumn(2).overWindow(window)); ColumnVector expect = ColumnVector.fromBoxedInts(7, 7, 9, 9, 9, 9, 9, 8, 8, 8, 6, 8, 8)) { assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); } @@ -4202,7 +4202,7 @@ void testRangeWindowingRowNumber() { .build()) { try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindowsOverRanges(Aggregation.rowNumber().onColumn(2).overWindow(window)); + .aggregateWindowsOverRanges(RollingAggregation.rowNumber().onColumn(2).overWindow(window)); ColumnVector expect = ColumnVector.fromBoxedInts(1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 5)) { assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); } @@ -4254,12 +4254,12 @@ void testRangeWindowingCountDescendingTimestamps() { .window(preceding_1, following_1) .orderByColumnIndex(orderIndex) .orderByDescending() - .build();) { + .build()) { try (Table windowAggResults = sorted.groupBy(0, 1) .aggregateWindowsOverRanges( - Aggregation.count().onColumn(2).overWindow(window_0), - Aggregation.sum().onColumn(2).overWindow(window_1)); + RollingAggregation.count().onColumn(2).overWindow(window_0), + RollingAggregation.sum().onColumn(2).overWindow(window_1)); ColumnVector expect_0 = ColumnVector.fromBoxedInts(3, 4, 4, 4, 3, 4, 4, 4, 3, 3, 5, 5, 5); ColumnVector expect_1 = ColumnVector.fromBoxedLongs(7L, 13L, 13L, 22L, 7L, 24L, 24L, 26L, 8L, 8L, 14L, 28L, 28L)) { assertColumnsAreEqual(expect_0, windowAggResults.getColumn(0)); @@ -4303,7 +4303,7 @@ void testRangeWindowingWithoutGroupByColumns() { .build();) { try (Table windowAggResults = sorted.groupBy() - .aggregateWindowsOverRanges(Aggregation.count().onColumn(1).overWindow(window)); + .aggregateWindowsOverRanges(RollingAggregation.count().onColumn(1).overWindow(window)); ColumnVector expect = 
ColumnVector.fromBoxedInts(3, 3, 6, 6, 6, 6, 7, 7, 6, 6, 5, 5, 3)) { assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); } @@ -4333,7 +4333,7 @@ void testRangeWindowingOrderByUnsupportedDataTypeExceptions() { assertThrows(IllegalArgumentException.class, () -> table .groupBy(0, 1) - .aggregateWindowsOverRanges(Aggregation.max().onColumn(2).overWindow(rangeBasedWindow))); + .aggregateWindowsOverRanges(RollingAggregation.max().onColumn(2).overWindow(rangeBasedWindow))); } } } @@ -4353,7 +4353,7 @@ void testInvalidWindowTypeExceptions() { .minPeriods(1) .window(one, one) .build()) { - assertThrows(IllegalArgumentException.class, () -> table.groupBy(0, 1).aggregateWindowsOverRanges(Aggregation.max().onColumn(3).overWindow(rowBasedWindow))); + assertThrows(IllegalArgumentException.class, () -> table.groupBy(0, 1).aggregateWindowsOverRanges(RollingAggregation.max().onColumn(3).overWindow(rowBasedWindow))); } try (WindowOptions rangeBasedWindow = WindowOptions.builder() @@ -4361,7 +4361,7 @@ void testInvalidWindowTypeExceptions() { .window(one, one) .orderByColumnIndex(2) .build()) { - assertThrows(IllegalArgumentException.class, () -> table.groupBy(0, 1).aggregateWindows(Aggregation.max().onColumn(3).overWindow(rangeBasedWindow))); + assertThrows(IllegalArgumentException.class, () -> table.groupBy(0, 1).aggregateWindows(RollingAggregation.max().onColumn(3).overWindow(rangeBasedWindow))); } } } @@ -4399,7 +4399,7 @@ void testRangeWindowingCountUnboundedPreceding() { .build();) { try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindowsOverRanges(Aggregation.count().onColumn(2).overWindow(window)); + .aggregateWindowsOverRanges(RollingAggregation.count().onColumn(2).overWindow(window)); ColumnVector expect = ColumnVector.fromBoxedInts(3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5)) { assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); } @@ -4475,11 +4475,11 @@ void testRangeWindowingCountUnboundedASCWithNullsFirst() { try (Table windowAggResults = 
sorted.groupBy(0, 1) .aggregateWindowsOverRanges( - Aggregation.count().onColumn(2).overWindow(unboundedPrecedingOneFollowing), - Aggregation.count().onColumn(2).overWindow(onePrecedingUnboundedFollowing), - Aggregation.count().onColumn(2).overWindow(unboundedPrecedingAndFollowing), - Aggregation.count().onColumn(2).overWindow(unboundedPrecedingAndCurrentRow), - Aggregation.count().onColumn(2).overWindow(currentRowAndUnboundedFollowing)); + RollingAggregation.count().onColumn(2).overWindow(unboundedPrecedingOneFollowing), + RollingAggregation.count().onColumn(2).overWindow(onePrecedingUnboundedFollowing), + RollingAggregation.count().onColumn(2).overWindow(unboundedPrecedingAndFollowing), + RollingAggregation.count().onColumn(2).overWindow(unboundedPrecedingAndCurrentRow), + RollingAggregation.count().onColumn(2).overWindow(currentRowAndUnboundedFollowing)); ColumnVector expect_0 = ColumnVector.fromBoxedInts(3, 3, 3, 5, 5, 6, 2, 2, 4, 4, 6, 6, 7); ColumnVector expect_1 = ColumnVector.fromBoxedInts(6, 6, 6, 3, 3, 1, 7, 7, 5, 5, 3, 3, 1); ColumnVector expect_2 = ColumnVector.fromBoxedInts(6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7); @@ -4570,11 +4570,11 @@ void testRangeWindowingCountUnboundedDESCWithNullsFirst() { try (Table windowAggResults = sorted.groupBy(0, 1) .aggregateWindowsOverRanges( - Aggregation.count().onColumn(2).overWindow(unboundedPrecedingOneFollowing), - Aggregation.count().onColumn(2).overWindow(onePrecedingUnboundedFollowing), - Aggregation.count().onColumn(2).overWindow(unboundedPrecedingAndFollowing), - Aggregation.count().onColumn(2).overWindow(unboundedPrecedingAndCurrentRow), - Aggregation.count().onColumn(2).overWindow(currentRowAndUnboundedFollowing)); + RollingAggregation.count().onColumn(2).overWindow(unboundedPrecedingOneFollowing), + RollingAggregation.count().onColumn(2).overWindow(onePrecedingUnboundedFollowing), + RollingAggregation.count().onColumn(2).overWindow(unboundedPrecedingAndFollowing), + 
RollingAggregation.count().onColumn(2).overWindow(unboundedPrecedingAndCurrentRow), + RollingAggregation.count().onColumn(2).overWindow(currentRowAndUnboundedFollowing)); ColumnVector expect_0 = ColumnVector.fromBoxedInts(3, 3, 3, 4, 6, 6, 2, 2, 3, 5, 5, 7, 7); ColumnVector expect_1 = ColumnVector.fromBoxedInts(6, 6, 6, 3, 2, 2, 7, 7, 5, 4, 4, 2, 2); ColumnVector expect_2 = ColumnVector.fromBoxedInts(6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7); @@ -4658,11 +4658,11 @@ void testRangeWindowingCountUnboundedASCWithNullsLast() { try (Table windowAggResults = sorted.groupBy(0, 1) .aggregateWindowsOverRanges( - Aggregation.count().onColumn(2).overWindow(unboundedPrecedingOneFollowing), - Aggregation.count().onColumn(2).overWindow(onePrecedingUnboundedFollowing), - Aggregation.count().onColumn(2).overWindow(unboundedPrecedingAndFollowing), - Aggregation.count().onColumn(2).overWindow(unboundedPrecedingAndCurrentRow), - Aggregation.count().onColumn(2).overWindow(currentRowAndUnboundedFollowing)); + RollingAggregation.count().onColumn(2).overWindow(unboundedPrecedingOneFollowing), + RollingAggregation.count().onColumn(2).overWindow(onePrecedingUnboundedFollowing), + RollingAggregation.count().onColumn(2).overWindow(unboundedPrecedingAndFollowing), + RollingAggregation.count().onColumn(2).overWindow(unboundedPrecedingAndCurrentRow), + RollingAggregation.count().onColumn(2).overWindow(currentRowAndUnboundedFollowing)); ColumnVector expect_0 = ColumnVector.fromBoxedInts(2, 2, 3, 6, 6, 6, 2, 2, 4, 4, 5, 7, 7); ColumnVector expect_1 = ColumnVector.fromBoxedInts(6, 6, 4, 3, 3, 3, 7, 7, 5, 5, 3, 2, 2); ColumnVector expect_2 = ColumnVector.fromBoxedInts(6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7); @@ -4752,11 +4752,11 @@ void testRangeWindowingCountUnboundedDESCWithNullsLast() { try (Table windowAggResults = sorted.groupBy(0, 1) .aggregateWindowsOverRanges( - Aggregation.count().onColumn(2).overWindow(unboundedPrecedingOneFollowing), - 
Aggregation.count().onColumn(2).overWindow(onePrecedingUnboundedFollowing), - Aggregation.count().onColumn(2).overWindow(unboundedPrecedingAndFollowing), - Aggregation.count().onColumn(2).overWindow(unboundedPrecedingAndCurrentRow), - Aggregation.count().onColumn(2).overWindow(currentRowAndUnboundedFollowing)); + RollingAggregation.count().onColumn(2).overWindow(unboundedPrecedingOneFollowing), + RollingAggregation.count().onColumn(2).overWindow(onePrecedingUnboundedFollowing), + RollingAggregation.count().onColumn(2).overWindow(unboundedPrecedingAndFollowing), + RollingAggregation.count().onColumn(2).overWindow(unboundedPrecedingAndCurrentRow), + RollingAggregation.count().onColumn(2).overWindow(currentRowAndUnboundedFollowing)); ColumnVector expect_0 = ColumnVector.fromBoxedInts(1, 3, 3, 6, 6, 6, 1, 3, 3, 5, 5, 7, 7); ColumnVector expect_1 = ColumnVector.fromBoxedInts(6, 5, 5, 3, 3, 3, 7, 6, 6, 4, 4, 2, 2); ColumnVector expect_2 = ColumnVector.fromBoxedInts(6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7); From 0c4cda4a223acaa3dcdd7dc1b7bb3b85747034cb Mon Sep 17 00:00:00 2001 From: "Robert (Bobby) Evans" Date: Fri, 30 Jul 2021 14:49:08 -0500 Subject: [PATCH 2/7] Grouped Scan --- .../rapids/cudf/GroupByScanAggregation.java | 118 ++++++++++++++++++ .../cudf/GroupByScanAggregationOnColumn.java | 64 ++++++++++ .../ai/rapids/cudf/RollingAggregation.java | 5 +- java/src/main/java/ai/rapids/cudf/Table.java | 6 +- .../test/java/ai/rapids/cudf/TableTest.java | 12 +- 5 files changed, 193 insertions(+), 12 deletions(-) create mode 100644 java/src/main/java/ai/rapids/cudf/GroupByScanAggregation.java create mode 100644 java/src/main/java/ai/rapids/cudf/GroupByScanAggregationOnColumn.java diff --git a/java/src/main/java/ai/rapids/cudf/GroupByScanAggregation.java b/java/src/main/java/ai/rapids/cudf/GroupByScanAggregation.java new file mode 100644 index 00000000000..97250a71486 --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/GroupByScanAggregation.java @@ -0,0 +1,118 @@ +/* + * + * 
Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package ai.rapids.cudf; + +/** + * An aggregation that can be used for a grouped scan. + */ +public class GroupByScanAggregation { + private final Aggregation wrapped; + + private GroupByScanAggregation(Aggregation wrapped) { + this.wrapped = wrapped; + } + + long createNativeInstance() { + return wrapped.createNativeInstance(); + } + + long getDefaultOutput() { + return wrapped.getDefaultOutput(); + } + + Aggregation getWrapped() { + return wrapped; + } + + /** + * Add a column to the Aggregation so it can be used on a specific column of data. + * @param columnIndex the index of the column to operate on. + */ + public GroupByScanAggregationOnColumn onColumn(int columnIndex) { + return new GroupByScanAggregationOnColumn(this, columnIndex); + } + + @Override + public int hashCode() { + return wrapped.hashCode(); + } + + @Override + public boolean equals(Object other) { + if (other == this) { + return true; + } else if (other instanceof GroupByScanAggregation) { + GroupByScanAggregation o = (GroupByScanAggregation) other; + return wrapped.equals(o.wrapped); + } + return false; + } + + /** + * Sum Aggregation + */ + public static GroupByScanAggregation sum() { + return new GroupByScanAggregation(Aggregation.sum()); + } + + + /** + * Product Aggregation. 
+ */ + public static GroupByScanAggregation product() { + return new GroupByScanAggregation(Aggregation.product()); + } + + /** + * Min Aggregation + */ + public static GroupByScanAggregation min() { + return new GroupByScanAggregation(Aggregation.min()); + } + + /** + * Max Aggregation + */ + public static GroupByScanAggregation max() { + return new GroupByScanAggregation(Aggregation.max()); + } + + /** + * Count number of elements. + * @param nullPolicy INCLUDE if nulls should be counted. EXCLUDE if only non-null values + * should be counted. + */ + public static GroupByScanAggregation count(NullPolicy nullPolicy) { + return new GroupByScanAggregation(Aggregation.count(nullPolicy)); + } + + /** + * Get the row's ranking. + */ + public static GroupByScanAggregation rank() { + return new GroupByScanAggregation(Aggregation.rank()); + } + + /** + * Get the row's dense ranking. + */ + public static GroupByScanAggregation denseRank() { + return new GroupByScanAggregation(Aggregation.denseRank()); + } +} diff --git a/java/src/main/java/ai/rapids/cudf/GroupByScanAggregationOnColumn.java b/java/src/main/java/ai/rapids/cudf/GroupByScanAggregationOnColumn.java new file mode 100644 index 00000000000..227cd58ae8c --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/GroupByScanAggregationOnColumn.java @@ -0,0 +1,64 @@ +/* + * + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +package ai.rapids.cudf; + +/** + * A GroupByScanAggregation for a specific column in a table. + */ +public class GroupByScanAggregationOnColumn { + protected final GroupByScanAggregation wrapped; + protected final int columnIndex; + + GroupByScanAggregationOnColumn(GroupByScanAggregation wrapped, int columnIndex) { + this.wrapped = wrapped; + this.columnIndex = columnIndex; + } + + public int getColumnIndex() { + return columnIndex; + } + + @Override + public int hashCode() { + return 31 * wrapped.hashCode() + columnIndex; + } + + @Override + public boolean equals(Object other) { + if (other == this) { + return true; + } else if (other instanceof GroupByScanAggregationOnColumn) { + GroupByScanAggregationOnColumn o = (GroupByScanAggregationOnColumn) other; + return wrapped.equals(o.wrapped) && columnIndex == o.columnIndex; + } + return false; + } + + long createNativeInstance() { + return wrapped.createNativeInstance(); + } + + long getDefaultOutput() { + return wrapped.getDefaultOutput(); + } + + GroupByScanAggregation getWrapped() { + return wrapped; + } +} diff --git a/java/src/main/java/ai/rapids/cudf/RollingAggregation.java b/java/src/main/java/ai/rapids/cudf/RollingAggregation.java index d9f026d79c5..b7e56606fb5 100644 --- a/java/src/main/java/ai/rapids/cudf/RollingAggregation.java +++ b/java/src/main/java/ai/rapids/cudf/RollingAggregation.java @@ -19,8 +19,7 @@ package ai.rapids.cudf; /** - * Used to tag an aggregation as something that is compatible with rolling window operations. - * Do not try to implement this yourself + * An aggregation that can be used on rolling windows. */ public class RollingAggregation { private final Aggregation wrapped; @@ -88,7 +87,7 @@ public static RollingAggregation max() { * Count number of valid, a.k.a. non-null, elements. 
*/ public static RollingAggregation count() { - return count(NullPolicy.EXCLUDE); + return new RollingAggregation(Aggregation.count()); } /** diff --git a/java/src/main/java/ai/rapids/cudf/Table.java b/java/src/main/java/ai/rapids/cudf/Table.java index 887c22e388a..746bef6e939 100644 --- a/java/src/main/java/ai/rapids/cudf/Table.java +++ b/java/src/main/java/ai/rapids/cudf/Table.java @@ -2808,7 +2808,7 @@ public Table aggregateWindowsOverRanges(AggregationOverWindow... windowAggregate } } - public Table scan(AggregationOnColumn... aggregates) { + public Table scan(GroupByScanAggregationOnColumn... aggregates) { assert aggregates != null; // To improve performance and memory we want to remove duplicate operations @@ -2821,9 +2821,9 @@ public Table scan(AggregationOnColumn... aggregates) { int keysLength = operation.indices.length; int totalOps = 0; for (int outputIndex = 0; outputIndex < aggregates.length; outputIndex++) { - AggregationOnColumn agg = aggregates[outputIndex]; + GroupByScanAggregationOnColumn agg = aggregates[outputIndex]; ColumnOps ops = groupedOps.computeIfAbsent(agg.getColumnIndex(), (idx) -> new ColumnOps()); - totalOps += ops.add(agg.getWrapped(), outputIndex + keysLength); + totalOps += ops.add(agg.getWrapped().getWrapped(), outputIndex + keysLength); } int[] aggColumnIndexes = new int[totalOps]; long[] aggOperationInstances = new long[totalOps]; diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index 3c18c7f3fcd..b6b8f197630 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -2902,12 +2902,12 @@ void testGroupByScan() { .withKeysSorted(true) .withKeysDescending(false, false) .build(), 0, 1) - .scan(Aggregation.sum().onColumn(2), - Aggregation.count(NullPolicy.INCLUDE).onColumn(2), - Aggregation.min().onColumn(2), - Aggregation.max().onColumn(2), - Aggregation.rank().onColumn(3), - 
Aggregation.denseRank().onColumn(3)); + .scan(GroupByScanAggregation.sum().onColumn(2), + GroupByScanAggregation.count(NullPolicy.INCLUDE).onColumn(2), + GroupByScanAggregation.min().onColumn(2), + GroupByScanAggregation.max().onColumn(2), + GroupByScanAggregation.rank().onColumn(3), + GroupByScanAggregation.denseRank().onColumn(3)); Table expected = new Table.TestBuilder() .column( "1", "1", "1", "1", "1", "1", "1", "2", "2", "2", "2") .column( 0, 1, 3, 3, 5, 5, 5, 5, 5, 5, 5) From 9a33623580faf1558920e0db688647fbd8ecefa3 Mon Sep 17 00:00:00 2001 From: "Robert (Bobby) Evans" Date: Fri, 30 Jul 2021 15:10:06 -0500 Subject: [PATCH 3/7] Scan Aggregation --- .../main/java/ai/rapids/cudf/ColumnView.java | 8 +- .../java/ai/rapids/cudf/ScanAggregation.java | 101 ++++++++++++++++++ .../java/ai/rapids/cudf/ColumnVectorTest.java | 98 ++++++++--------- 3 files changed, 150 insertions(+), 57 deletions(-) create mode 100644 java/src/main/java/ai/rapids/cudf/ScanAggregation.java diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index 1e2790d42c5..257ec6d930d 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -1419,7 +1419,7 @@ public final ColumnVector rollingWindow(RollingAggregation op, WindowOptions opt * This is just a convenience method for an inclusive scan with a SUM aggregation. */ public final ColumnVector prefixSum() { - return scan(Aggregation.sum()); + return scan(ScanAggregation.sum()); } /** @@ -1430,7 +1430,7 @@ public final ColumnVector prefixSum() { * null policy too. Currently none of those aggregations are supported so * it is undefined how they would interact with each other. 
*/ - public final ColumnVector scan(Aggregation aggregation, ScanType scanType, NullPolicy nullPolicy) { + public final ColumnVector scan(ScanAggregation aggregation, ScanType scanType, NullPolicy nullPolicy) { long nativeId = aggregation.createNativeInstance(); try { return new ColumnVector(scan(getNativeView(), nativeId, @@ -1445,7 +1445,7 @@ public final ColumnVector scan(Aggregation aggregation, ScanType scanType, NullP * @param aggregation the aggregation to perform * @param scanType should the scan be inclusive, include the current row, or exclusive. */ - public final ColumnVector scan(Aggregation aggregation, ScanType scanType) { + public final ColumnVector scan(ScanAggregation aggregation, ScanType scanType) { return scan(aggregation, scanType, NullPolicy.EXCLUDE); } @@ -1453,7 +1453,7 @@ public final ColumnVector scan(Aggregation aggregation, ScanType scanType) { * Computes an inclusive scan for a column that excludes nulls. * @param aggregation the aggregation to perform */ - public final ColumnVector scan(Aggregation aggregation) { + public final ColumnVector scan(ScanAggregation aggregation) { return scan(aggregation, ScanType.INCLUSIVE, NullPolicy.EXCLUDE); } diff --git a/java/src/main/java/ai/rapids/cudf/ScanAggregation.java b/java/src/main/java/ai/rapids/cudf/ScanAggregation.java new file mode 100644 index 00000000000..ab1da3c9e39 --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/ScanAggregation.java @@ -0,0 +1,101 @@ +/* + * + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package ai.rapids.cudf; + +/** + * An aggregation that can be used for a scan. + */ +public class ScanAggregation { + private final Aggregation wrapped; + + private ScanAggregation(Aggregation wrapped) { + this.wrapped = wrapped; + } + + long createNativeInstance() { + return wrapped.createNativeInstance(); + } + + long getDefaultOutput() { + return wrapped.getDefaultOutput(); + } + + Aggregation getWrapped() { + return wrapped; + } + + @Override + public int hashCode() { + return wrapped.hashCode(); + } + + @Override + public boolean equals(Object other) { + if (other == this) { + return true; + } else if (other instanceof ScanAggregation) { + ScanAggregation o = (ScanAggregation) other; + return wrapped.equals(o.wrapped); + } + return false; + } + + /** + * Sum Aggregation + */ + public static ScanAggregation sum() { + return new ScanAggregation(Aggregation.sum()); + } + + + /** + * Product Aggregation. + */ + public static ScanAggregation product() { + return new ScanAggregation(Aggregation.product()); + } + + /** + * Min Aggregation + */ + public static ScanAggregation min() { + return new ScanAggregation(Aggregation.min()); + } + + /** + * Max Aggregation + */ + public static ScanAggregation max() { + return new ScanAggregation(Aggregation.max()); + } + + /** + * Get the row's ranking. + */ + public static ScanAggregation rank() { + return new ScanAggregation(Aggregation.rank()); + } + + /** + * Get the row's dense ranking. 
+ */ + public static ScanAggregation denseRank() { + return new ScanAggregation(Aggregation.denseRank()); + } +} diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index 59efa14b3f9..4856071e296 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -2899,24 +2899,22 @@ void testPrefixSum() { @Test void testScanSum() { try (ColumnVector v1 = ColumnVector.fromBoxedInts(1, 2, null, 3, 5, 8, 10)) { - // Due to https://github.com/rapidsai/cudf/issues/8462 NullPolicy.INCLUDE - // tests have been disabled -// try (ColumnVector result = v1.scan(Aggregation.sum(), ScanType.INCLUSIVE, NullPolicy.INCLUDE); -// ColumnVector expected = ColumnVector.fromBoxedInts(1, 3, null, null, null, null, null)) { -// assertColumnsAreEqual(expected, result); -// } - - try (ColumnVector result = v1.scan(Aggregation.sum(), ScanType.INCLUSIVE, NullPolicy.EXCLUDE); + try (ColumnVector result = v1.scan(ScanAggregation.sum(), ScanType.INCLUSIVE, NullPolicy.INCLUDE); + ColumnVector expected = ColumnVector.fromBoxedInts(1, 3, null, null, null, null, null)) { + assertColumnsAreEqual(expected, result); + } + + try (ColumnVector result = v1.scan(ScanAggregation.sum(), ScanType.INCLUSIVE, NullPolicy.EXCLUDE); ColumnVector expected = ColumnVector.fromBoxedInts(1, 3, null, 6, 11, 19, 29)) { assertColumnsAreEqual(expected, result); } -// try (ColumnVector result = v1.scan(Aggregation.sum(), ScanType.EXCLUSIVE, NullPolicy.INCLUDE); -// ColumnVector expected = ColumnVector.fromBoxedInts(0, 1, 3, 3, 6, 11, 19)) { -// assertColumnsAreEqual(expected, result); -// } + try (ColumnVector result = v1.scan(ScanAggregation.sum(), ScanType.EXCLUSIVE, NullPolicy.INCLUDE); + ColumnVector expected = ColumnVector.fromBoxedInts(0, 1, 3, null, null, null, null)) { + assertColumnsAreEqual(expected, result); + } - try (ColumnVector result = v1.scan(Aggregation.sum(), 
ScanType.EXCLUSIVE, NullPolicy.EXCLUDE); + try (ColumnVector result = v1.scan(ScanAggregation.sum(), ScanType.EXCLUSIVE, NullPolicy.EXCLUDE); ColumnVector expected = ColumnVector.fromBoxedInts(0, 1, null, 3, 6, 11, 19)) { assertColumnsAreEqual(expected, result); } @@ -2925,25 +2923,23 @@ void testScanSum() { @Test void testScanMax() { - // Due to https://github.com/rapidsai/cudf/issues/8462 NullPolicy.INCLUDE - // tests have been disabled try (ColumnVector v1 = ColumnVector.fromBoxedInts(1, 2, null, 3, 5, 8, 10)) { -// try (ColumnVector result = v1.scan(Aggregation.max(), ScanType.INCLUSIVE, NullPolicy.INCLUDE); -// ColumnVector expected = ColumnVector.fromBoxedInts(1, 2, null, null, null, null, null)) { -// assertColumnsAreEqual(expected, result); -// } + try (ColumnVector result = v1.scan(ScanAggregation.max(), ScanType.INCLUSIVE, NullPolicy.INCLUDE); + ColumnVector expected = ColumnVector.fromBoxedInts(1, 2, null, null, null, null, null)) { + assertColumnsAreEqual(expected, result); + } - try (ColumnVector result = v1.scan(Aggregation.max(), ScanType.INCLUSIVE, NullPolicy.EXCLUDE); + try (ColumnVector result = v1.scan(ScanAggregation.max(), ScanType.INCLUSIVE, NullPolicy.EXCLUDE); ColumnVector expected = ColumnVector.fromBoxedInts(1, 2, null, 3, 5, 8, 10)) { assertColumnsAreEqual(expected, result); } -// try (ColumnVector result = v1.scan(Aggregation.max(), ScanType.EXCLUSIVE, NullPolicy.INCLUDE); -// ColumnVector expected = ColumnVector.fromBoxedInts(Integer.MIN_VALUE, 1, 2, 2, 3, 5, 8)) { -// assertColumnsAreEqual(expected, result); -// } + try (ColumnVector result = v1.scan(ScanAggregation.max(), ScanType.EXCLUSIVE, NullPolicy.INCLUDE); + ColumnVector expected = ColumnVector.fromBoxedInts(Integer.MIN_VALUE, 1, 2, null, null, null, null)) { + assertColumnsAreEqual(expected, result); + } - try (ColumnVector result = v1.scan(Aggregation.max(), ScanType.EXCLUSIVE, NullPolicy.EXCLUDE); + try (ColumnVector result = v1.scan(ScanAggregation.max(), ScanType.EXCLUSIVE, 
NullPolicy.EXCLUDE); ColumnVector expected = ColumnVector.fromBoxedInts(Integer.MIN_VALUE, 1, null, 2, 3, 5, 8)) { assertColumnsAreEqual(expected, result); } @@ -2952,25 +2948,23 @@ void testScanMax() { @Test void testScanMin() { - // Due to https://github.com/rapidsai/cudf/issues/8462 NullPolicy.INCLUDE - // tests have been disabled try (ColumnVector v1 = ColumnVector.fromBoxedInts(1, 2, null, 3, 5, 8, 10)) { -// try (ColumnVector result = v1.scan(Aggregation.min(), ScanType.INCLUSIVE, NullPolicy.INCLUDE); -// ColumnVector expected = ColumnVector.fromBoxedInts(1, 1, null, null, null, null, null)) { -// assertColumnsAreEqual(expected, result); -// } + try (ColumnVector result = v1.scan(ScanAggregation.min(), ScanType.INCLUSIVE, NullPolicy.INCLUDE); + ColumnVector expected = ColumnVector.fromBoxedInts(1, 1, null, null, null, null, null)) { + assertColumnsAreEqual(expected, result); + } - try (ColumnVector result = v1.scan(Aggregation.min(), ScanType.INCLUSIVE, NullPolicy.EXCLUDE); + try (ColumnVector result = v1.scan(ScanAggregation.min(), ScanType.INCLUSIVE, NullPolicy.EXCLUDE); ColumnVector expected = ColumnVector.fromBoxedInts(1, 1, null, 1, 1, 1, 1)) { assertColumnsAreEqual(expected, result); } -// try (ColumnVector result = v1.scan(Aggregation.min(), ScanType.EXCLUSIVE, NullPolicy.INCLUDE); -// ColumnVector expected = ColumnVector.fromBoxedInts(Integer.MAX_VALUE, 1, 1, 1, 1, 1, 1)) { -// assertColumnsAreEqual(expected, result); -// } + try (ColumnVector result = v1.scan(ScanAggregation.min(), ScanType.EXCLUSIVE, NullPolicy.INCLUDE); + ColumnVector expected = ColumnVector.fromBoxedInts(Integer.MAX_VALUE, 1, 1, null, null, null, null)) { + assertColumnsAreEqual(expected, result); + } - try (ColumnVector result = v1.scan(Aggregation.min(), ScanType.EXCLUSIVE, NullPolicy.EXCLUDE); + try (ColumnVector result = v1.scan(ScanAggregation.min(), ScanType.EXCLUSIVE, NullPolicy.EXCLUDE); ColumnVector expected = ColumnVector.fromBoxedInts(Integer.MAX_VALUE, 1, null, 1, 1, 
1, 1)) { assertColumnsAreEqual(expected, result); } @@ -2979,25 +2973,23 @@ void testScanMin() { @Test void testScanProduct() { - // Due to https://github.com/rapidsai/cudf/issues/8462 NullPolicy.INCLUDE - // tests have been disabled try (ColumnVector v1 = ColumnVector.fromBoxedInts(1, 2, null, 3, 5, 8, 10)) { -// try (ColumnVector result = v1.scan(Aggregation.product(), ScanType.INCLUSIVE, NullPolicy.INCLUDE); -// ColumnVector expected = ColumnVector.fromBoxedInts(1, 2, null, null, null, null, null)) { -// assertColumnsAreEqual(expected, result); -// } + try (ColumnVector result = v1.scan(ScanAggregation.product(), ScanType.INCLUSIVE, NullPolicy.INCLUDE); + ColumnVector expected = ColumnVector.fromBoxedInts(1, 2, null, null, null, null, null)) { + assertColumnsAreEqual(expected, result); + } - try (ColumnVector result = v1.scan(Aggregation.product(), ScanType.INCLUSIVE, NullPolicy.EXCLUDE); + try (ColumnVector result = v1.scan(ScanAggregation.product(), ScanType.INCLUSIVE, NullPolicy.EXCLUDE); ColumnVector expected = ColumnVector.fromBoxedInts(1, 2, null, 6, 30, 240, 2400)) { assertColumnsAreEqual(expected, result); } -// try (ColumnVector result = v1.scan(Aggregation.product(), ScanType.EXCLUSIVE, NullPolicy.INCLUDE); -// ColumnVector expected = ColumnVector.fromBoxedInts(1, 1, 2, 2, 6, 30, 240)) { -// assertColumnsAreEqual(expected, result); -// } + try (ColumnVector result = v1.scan(ScanAggregation.product(), ScanType.EXCLUSIVE, NullPolicy.INCLUDE); + ColumnVector expected = ColumnVector.fromBoxedInts(1, 1, 2, null, null, null, null)) { + assertColumnsAreEqual(expected, result); + } - try (ColumnVector result = v1.scan(Aggregation.product(), ScanType.EXCLUSIVE, NullPolicy.EXCLUDE); + try (ColumnVector result = v1.scan(ScanAggregation.product(), ScanType.EXCLUSIVE, NullPolicy.EXCLUDE); ColumnVector expected = ColumnVector.fromBoxedInts(1, 1, null, 2, 6, 30, 240)) { assertColumnsAreEqual(expected, result); } @@ -3011,13 +3003,13 @@ void testScanRank() { 
ColumnVector struct_order = ColumnVector.makeStruct(col1, col2); ColumnVector expected = ColumnVector.fromBoxedInts( 1, 1, 3, 4, 5, 6, 7, 7, 9, 9, 11, 12)) { - try (ColumnVector result = struct_order.scan(Aggregation.rank(), + try (ColumnVector result = struct_order.scan(ScanAggregation.rank(), ScanType.INCLUSIVE, NullPolicy.INCLUDE)) { assertColumnsAreEqual(expected, result); } // Exclude should have identical results - try (ColumnVector result = struct_order.scan(Aggregation.rank(), + try (ColumnVector result = struct_order.scan(ScanAggregation.rank(), ScanType.INCLUSIVE, NullPolicy.EXCLUDE) ) { assertColumnsAreEqual(expected, result); @@ -3034,13 +3026,13 @@ void testScanDenseRank() { ColumnVector struct_order = ColumnVector.makeStruct(col1, col2); ColumnVector expected = ColumnVector.fromBoxedInts( 1, 1, 2, 3, 4, 5, 6, 6, 7, 7, 8, 9)) { - try (ColumnVector result = struct_order.scan(Aggregation.denseRank(), + try (ColumnVector result = struct_order.scan(ScanAggregation.denseRank(), ScanType.INCLUSIVE, NullPolicy.INCLUDE)) { assertColumnsAreEqual(expected, result); } // Exclude should have identical results - try (ColumnVector result = struct_order.scan(Aggregation.denseRank(), + try (ColumnVector result = struct_order.scan(ScanAggregation.denseRank(), ScanType.INCLUSIVE, NullPolicy.EXCLUDE)) { assertColumnsAreEqual(expected, result); } From 707e4ed23fae7d65c7f3fbb1b4f49ff5a06f229f Mon Sep 17 00:00:00 2001 From: "Robert (Bobby) Evans" Date: Fri, 30 Jul 2021 15:35:44 -0500 Subject: [PATCH 4/7] ReductionAggregation --- .../main/java/ai/rapids/cudf/ColumnView.java | 28 +-- .../ai/rapids/cudf/ReductionAggregation.java | 212 ++++++++++++++++ .../java/ai/rapids/cudf/ScanAggregation.java | 1 - .../java/ai/rapids/cudf/ReductionTest.java | 230 +++++++++--------- 4 files changed, 341 insertions(+), 130 deletions(-) create mode 100644 java/src/main/java/ai/rapids/cudf/ReductionAggregation.java diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java 
b/java/src/main/java/ai/rapids/cudf/ColumnView.java index 257ec6d930d..55bd5ec5ff9 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -1135,7 +1135,7 @@ public Scalar sum() { * of the specified type. */ public Scalar sum(DType outType) { - return reduce(Aggregation.sum(), outType); + return reduce(ReductionAggregation.sum(), outType); } /** @@ -1143,7 +1143,7 @@ public Scalar sum(DType outType) { * of the same type as this column. */ public Scalar min() { - return reduce(Aggregation.min(), type); + return reduce(ReductionAggregation.min(), type); } /** @@ -1160,7 +1160,7 @@ public Scalar min(DType outType) { return tmp.min(outType); } } - return reduce(Aggregation.min(), outType); + return reduce(ReductionAggregation.min(), outType); } /** @@ -1168,7 +1168,7 @@ public Scalar min(DType outType) { * of the same type as this column. */ public Scalar max() { - return reduce(Aggregation.max(), type); + return reduce(ReductionAggregation.max(), type); } /** @@ -1185,7 +1185,7 @@ public Scalar max(DType outType) { return tmp.max(outType); } } - return reduce(Aggregation.max(), outType); + return reduce(ReductionAggregation.max(), outType); } /** @@ -1201,7 +1201,7 @@ public Scalar product() { * of the specified type. */ public Scalar product(DType outType) { - return reduce(Aggregation.product(), outType); + return reduce(ReductionAggregation.product(), outType); } /** @@ -1217,7 +1217,7 @@ public Scalar sumOfSquares() { * scalar of the specified type. */ public Scalar sumOfSquares(DType outType) { - return reduce(Aggregation.sumOfSquares(), outType); + return reduce(ReductionAggregation.sumOfSquares(), outType); } /** @@ -1241,7 +1241,7 @@ public Scalar mean() { * types are currently supported. 
*/ public Scalar mean(DType outType) { - return reduce(Aggregation.mean(), outType); + return reduce(ReductionAggregation.mean(), outType); } /** @@ -1265,7 +1265,7 @@ public Scalar variance() { * types are currently supported. */ public Scalar variance(DType outType) { - return reduce(Aggregation.variance(), outType); + return reduce(ReductionAggregation.variance(), outType); } /** @@ -1290,7 +1290,7 @@ public Scalar standardDeviation() { * types are currently supported. */ public Scalar standardDeviation(DType outType) { - return reduce(Aggregation.standardDeviation(), outType); + return reduce(ReductionAggregation.standardDeviation(), outType); } /** @@ -1309,7 +1309,7 @@ public Scalar any() { * Null values are skipped. */ public Scalar any(DType outType) { - return reduce(Aggregation.any(), outType); + return reduce(ReductionAggregation.any(), outType); } /** @@ -1330,7 +1330,7 @@ public Scalar all() { */ @Deprecated public Scalar all(DType outType) { - return reduce(Aggregation.all(), outType); + return reduce(ReductionAggregation.all(), outType); } /** @@ -1343,7 +1343,7 @@ public Scalar all(DType outType) { * empty or the reduction operation fails then the * {@link Scalar#isValid()} method of the result will return false. */ - public Scalar reduce(Aggregation aggregation) { + public Scalar reduce(ReductionAggregation aggregation) { return reduce(aggregation, type); } @@ -1360,7 +1360,7 @@ public Scalar reduce(Aggregation aggregation) { * empty or the reduction operation fails then the * {@link Scalar#isValid()} method of the result will return false. 
*/ - public Scalar reduce(Aggregation aggregation, DType outType) { + public Scalar reduce(ReductionAggregation aggregation, DType outType) { long nativeId = aggregation.createNativeInstance(); try { return new Scalar(outType, reduce(getNativeView(), nativeId, outType.typeId.getNativeId(), outType.getScale())); diff --git a/java/src/main/java/ai/rapids/cudf/ReductionAggregation.java b/java/src/main/java/ai/rapids/cudf/ReductionAggregation.java new file mode 100644 index 00000000000..ad96b93c400 --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/ReductionAggregation.java @@ -0,0 +1,212 @@ +/* + * + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package ai.rapids.cudf; + +/** + * An aggregation that can be used for a reduce. 
+ */ +public class ReductionAggregation { + private final Aggregation wrapped; + + private ReductionAggregation(Aggregation wrapped) { + this.wrapped = wrapped; + } + + long createNativeInstance() { + return wrapped.createNativeInstance(); + } + + long getDefaultOutput() { + return wrapped.getDefaultOutput(); + } + + Aggregation getWrapped() { + return wrapped; + } + + @Override + public int hashCode() { + return wrapped.hashCode(); + } + + @Override + public boolean equals(Object other) { + if (other == this) { + return true; + } else if (other instanceof ReductionAggregation) { + ReductionAggregation o = (ReductionAggregation) other; + return wrapped.equals(o.wrapped); + } + return false; + } + + /** + * Sum Aggregation + */ + public static ReductionAggregation sum() { + return new ReductionAggregation(Aggregation.sum()); + } + + /** + * Product Aggregation. + */ + public static ReductionAggregation product() { + return new ReductionAggregation(Aggregation.product()); + } + + /** + * Min Aggregation + */ + public static ReductionAggregation min() { + return new ReductionAggregation(Aggregation.min()); + } + + /** + * Max Aggregation + */ + public static ReductionAggregation max() { + return new ReductionAggregation(Aggregation.max()); + } + + /** + * Any reduction. Produces a true or 1, depending on the output type, + * if any of the elements in the range are true or non-zero, otherwise produces a false or 0. + * Null values are skipped. + */ + public static ReductionAggregation any() { + return new ReductionAggregation(Aggregation.any()); + } + + /** + * All reduction. Produces true or 1, depending on the output type, if all of the elements in + * the range are true or non-zero, otherwise produces a false or 0. + * Null values are skipped. + */ + public static ReductionAggregation all() { + return new ReductionAggregation(Aggregation.all()); + } + + + /** + * Sum of squares reduction. 
+ */ + public static ReductionAggregation sumOfSquares() { + return new ReductionAggregation(Aggregation.sumOfSquares()); + } + + /** + * Arithmetic mean reduction. + */ + public static ReductionAggregation mean() { + return new ReductionAggregation(Aggregation.mean()); + } + + + /** + * Variance aggregation with 1 as the delta degrees of freedom. + */ + public static ReductionAggregation variance() { + return new ReductionAggregation(Aggregation.variance()); + } + + /** + * Variance aggregation. + * @param ddof delta degrees of freedom. The divisor used in calculation of variance is + * N - ddof, where N is the population size. + */ + public static ReductionAggregation variance(int ddof) { + return new ReductionAggregation(Aggregation.variance(ddof)); + } + + /** + * Standard deviation aggregation with 1 as the delta degrees of freedom. + */ + public static ReductionAggregation standardDeviation() { + return new ReductionAggregation(Aggregation.standardDeviation()); + } + + /** + * Standard deviation aggregation. + * @param ddof delta degrees of freedom. The divisor used in calculation of std is + * N - ddof, where N is the population size. + */ + public static ReductionAggregation standardDeviation(int ddof) { + return new ReductionAggregation(Aggregation.standardDeviation(ddof)); + } + + + /** + * Median reduction. + */ + public static ReductionAggregation median() { + return new ReductionAggregation(Aggregation.median()); + } + + /** + * Aggregate to compute the specified quantiles. Uses linear interpolation by default. + */ + public static ReductionAggregation quantile(double ... quantiles) { + return new ReductionAggregation(Aggregation.quantile(quantiles)); + } + + /** + * Aggregate to compute various quantiles. + */ + public static ReductionAggregation quantile(QuantileMethod method, double ... quantiles) { + return new ReductionAggregation(Aggregation.quantile(method, quantiles)); + } + + + /** + * Number of unique, non-null, elements. 
+ */ + public static ReductionAggregation nunique() { + return new ReductionAggregation(Aggregation.nunique()); + } + + /** + * Number of unique elements. + * @param nullPolicy INCLUDE if nulls should be counted else EXCLUDE. If nulls are counted they + * compare as equal so multiple null values in a range would all only + * increase the count by 1. + */ + public static ReductionAggregation nunique(NullPolicy nullPolicy) { + return new ReductionAggregation(Aggregation.nunique(nullPolicy)); + } + + /** + * Get the nth, non-null, element in a group. + * @param offset the offset to look at. Negative numbers go from the end of the group. Any + * value outside of the group range results in a null. + */ + public static ReductionAggregation nth(int offset) { + return new ReductionAggregation(Aggregation.nth(offset)); + } + + /** + * Get the nth element in a group. + * @param offset the offset to look at. Negative numbers go from the end of the group. Any + * value outside of the group range results in a null. + * @param nullPolicy INCLUDE if nulls should be included in the aggregation or EXCLUDE if they + * should be skipped. + */ + public static ReductionAggregation nth(int offset, NullPolicy nullPolicy) { + return new ReductionAggregation(Aggregation.nth(offset, nullPolicy)); + } +} diff --git a/java/src/main/java/ai/rapids/cudf/ScanAggregation.java b/java/src/main/java/ai/rapids/cudf/ScanAggregation.java index ab1da3c9e39..bd19546e5ef 100644 --- a/java/src/main/java/ai/rapids/cudf/ScanAggregation.java +++ b/java/src/main/java/ai/rapids/cudf/ScanAggregation.java @@ -63,7 +63,6 @@ public static ScanAggregation sum() { return new ScanAggregation(Aggregation.sum()); } - /** * Product Aggregation. 
*/ diff --git a/java/src/test/java/ai/rapids/cudf/ReductionTest.java b/java/src/test/java/ai/rapids/cudf/ReductionTest.java index 17b9ec3556f..2b26597c8f7 100644 --- a/java/src/test/java/ai/rapids/cudf/ReductionTest.java +++ b/java/src/test/java/ai/rapids/cudf/ReductionTest.java @@ -43,17 +43,17 @@ class ReductionTest extends CudfTestBase { Aggregation.Kind.ANY, Aggregation.Kind.ALL); - private static Scalar buildExpectedScalar(Aggregation op, DType baseType, Object expectedObject) { + private static Scalar buildExpectedScalar(ReductionAggregation op, DType baseType, Object expectedObject) { if (expectedObject == null) { return Scalar.fromNull(baseType); } - if (FLOAT_REDUCTIONS.contains(op.kind)) { + if (FLOAT_REDUCTIONS.contains(op.getWrapped().kind)) { if (baseType.equals(DType.FLOAT32)) { return Scalar.fromFloat((Float) expectedObject); } return Scalar.fromDouble((Double) expectedObject); } - if (BOOL_REDUCTIONS.contains(op.kind)) { + if (BOOL_REDUCTIONS.contains(op.getWrapped().kind)) { return Scalar.fromBool((Boolean) expectedObject); } switch (baseType.typeId) { @@ -88,165 +88,165 @@ private static Scalar buildExpectedScalar(Aggregation op, DType baseType, Object private static Stream createBooleanParams() { Boolean[] vals = new Boolean[]{true, true, null, false, true, false, null}; return Stream.of( - Arguments.of(Aggregation.sum(), new Boolean[0], null, 0.), - Arguments.of(Aggregation.sum(), new Boolean[]{null, null, null}, null, 0.), - Arguments.of(Aggregation.sum(), vals, true, 0.), - Arguments.of(Aggregation.min(), vals, false, 0.), - Arguments.of(Aggregation.max(), vals, true, 0.), - Arguments.of(Aggregation.product(), vals, false, 0.), - Arguments.of(Aggregation.sumOfSquares(), vals, true, 0.), - Arguments.of(Aggregation.mean(), vals, 0.6, DELTAD), - Arguments.of(Aggregation.standardDeviation(), vals, 0.5477225575051662, DELTAD), - Arguments.of(Aggregation.variance(), vals, 0.3, DELTAD), - Arguments.of(Aggregation.any(), vals, true, 0.), - 
Arguments.of(Aggregation.all(), vals, false, 0.) + Arguments.of(ReductionAggregation.sum(), new Boolean[0], null, 0.), + Arguments.of(ReductionAggregation.sum(), new Boolean[]{null, null, null}, null, 0.), + Arguments.of(ReductionAggregation.sum(), vals, true, 0.), + Arguments.of(ReductionAggregation.min(), vals, false, 0.), + Arguments.of(ReductionAggregation.max(), vals, true, 0.), + Arguments.of(ReductionAggregation.product(), vals, false, 0.), + Arguments.of(ReductionAggregation.sumOfSquares(), vals, true, 0.), + Arguments.of(ReductionAggregation.mean(), vals, 0.6, DELTAD), + Arguments.of(ReductionAggregation.standardDeviation(), vals, 0.5477225575051662, DELTAD), + Arguments.of(ReductionAggregation.variance(), vals, 0.3, DELTAD), + Arguments.of(ReductionAggregation.any(), vals, true, 0.), + Arguments.of(ReductionAggregation.all(), vals, false, 0.) ); } private static Stream createByteParams() { Byte[] vals = new Byte[]{-1, 7, 123, null, 50, 60, 100}; return Stream.of( - Arguments.of(Aggregation.sum(), new Byte[0], null, 0.), - Arguments.of(Aggregation.sum(), new Byte[]{null, null, null}, null, 0.), - Arguments.of(Aggregation.sum(), vals, (byte) 83, 0.), - Arguments.of(Aggregation.min(), vals, (byte) -1, 0.), - Arguments.of(Aggregation.max(), vals, (byte) 123, 0.), - Arguments.of(Aggregation.product(), vals, (byte) 160, 0.), - Arguments.of(Aggregation.sumOfSquares(), vals, (byte) 47, 0.), - Arguments.of(Aggregation.mean(), vals, 56.5, DELTAD), - Arguments.of(Aggregation.standardDeviation(), vals, 49.24530434467839, DELTAD), - Arguments.of(Aggregation.variance(), vals, 2425.1, DELTAD), - Arguments.of(Aggregation.any(), vals, true, 0.), - Arguments.of(Aggregation.all(), vals, true, 0.) 
+ Arguments.of(ReductionAggregation.sum(), new Byte[0], null, 0.), + Arguments.of(ReductionAggregation.sum(), new Byte[]{null, null, null}, null, 0.), + Arguments.of(ReductionAggregation.sum(), vals, (byte) 83, 0.), + Arguments.of(ReductionAggregation.min(), vals, (byte) -1, 0.), + Arguments.of(ReductionAggregation.max(), vals, (byte) 123, 0.), + Arguments.of(ReductionAggregation.product(), vals, (byte) 160, 0.), + Arguments.of(ReductionAggregation.sumOfSquares(), vals, (byte) 47, 0.), + Arguments.of(ReductionAggregation.mean(), vals, 56.5, DELTAD), + Arguments.of(ReductionAggregation.standardDeviation(), vals, 49.24530434467839, DELTAD), + Arguments.of(ReductionAggregation.variance(), vals, 2425.1, DELTAD), + Arguments.of(ReductionAggregation.any(), vals, true, 0.), + Arguments.of(ReductionAggregation.all(), vals, true, 0.) ); } private static Stream createShortParams() { Short[] vals = new Short[]{-1, 7, 123, null, 50, 60, 100}; return Stream.of( - Arguments.of(Aggregation.sum(), new Short[0], null, 0.), - Arguments.of(Aggregation.sum(), new Short[]{null, null, null}, null, 0.), - Arguments.of(Aggregation.sum(), vals, (short) 339, 0.), - Arguments.of(Aggregation.min(), vals, (short) -1, 0.), - Arguments.of(Aggregation.max(), vals, (short) 123, 0.), - Arguments.of(Aggregation.product(), vals, (short) -22624, 0.), - Arguments.of(Aggregation.sumOfSquares(), vals, (short) 31279, 0.), - Arguments.of(Aggregation.mean(), vals, 56.5, DELTAD), - Arguments.of(Aggregation.standardDeviation(), vals, 49.24530434467839, DELTAD), - Arguments.of(Aggregation.variance(), vals, 2425.1, DELTAD), - Arguments.of(Aggregation.any(), vals, true, 0.), - Arguments.of(Aggregation.all(), vals, true, 0.) 
+ Arguments.of(ReductionAggregation.sum(), new Short[0], null, 0.), + Arguments.of(ReductionAggregation.sum(), new Short[]{null, null, null}, null, 0.), + Arguments.of(ReductionAggregation.sum(), vals, (short) 339, 0.), + Arguments.of(ReductionAggregation.min(), vals, (short) -1, 0.), + Arguments.of(ReductionAggregation.max(), vals, (short) 123, 0.), + Arguments.of(ReductionAggregation.product(), vals, (short) -22624, 0.), + Arguments.of(ReductionAggregation.sumOfSquares(), vals, (short) 31279, 0.), + Arguments.of(ReductionAggregation.mean(), vals, 56.5, DELTAD), + Arguments.of(ReductionAggregation.standardDeviation(), vals, 49.24530434467839, DELTAD), + Arguments.of(ReductionAggregation.variance(), vals, 2425.1, DELTAD), + Arguments.of(ReductionAggregation.any(), vals, true, 0.), + Arguments.of(ReductionAggregation.all(), vals, true, 0.) ); } private static Stream createIntParams() { Integer[] vals = new Integer[]{-1, 7, 123, null, 50, 60, 100}; return Stream.of( - Arguments.of(Aggregation.sum(), new Integer[0], null, 0.), - Arguments.of(Aggregation.sum(), new Integer[]{null, null, null}, null, 0.), - Arguments.of(Aggregation.sum(), vals, 339, 0.), - Arguments.of(Aggregation.min(), vals, -1, 0.), - Arguments.of(Aggregation.max(), vals, 123, 0.), - Arguments.of(Aggregation.product(), vals, -258300000, 0.), - Arguments.of(Aggregation.sumOfSquares(), vals, 31279, 0.), - Arguments.of(Aggregation.mean(), vals, 56.5, DELTAD), - Arguments.of(Aggregation.standardDeviation(), vals, 49.24530434467839, DELTAD), - Arguments.of(Aggregation.variance(), vals, 2425.1, DELTAD), - Arguments.of(Aggregation.any(), vals, true, 0.), - Arguments.of(Aggregation.all(), vals, true, 0.) 
+ Arguments.of(ReductionAggregation.sum(), new Integer[0], null, 0.), + Arguments.of(ReductionAggregation.sum(), new Integer[]{null, null, null}, null, 0.), + Arguments.of(ReductionAggregation.sum(), vals, 339, 0.), + Arguments.of(ReductionAggregation.min(), vals, -1, 0.), + Arguments.of(ReductionAggregation.max(), vals, 123, 0.), + Arguments.of(ReductionAggregation.product(), vals, -258300000, 0.), + Arguments.of(ReductionAggregation.sumOfSquares(), vals, 31279, 0.), + Arguments.of(ReductionAggregation.mean(), vals, 56.5, DELTAD), + Arguments.of(ReductionAggregation.standardDeviation(), vals, 49.24530434467839, DELTAD), + Arguments.of(ReductionAggregation.variance(), vals, 2425.1, DELTAD), + Arguments.of(ReductionAggregation.any(), vals, true, 0.), + Arguments.of(ReductionAggregation.all(), vals, true, 0.) ); } private static Stream createLongParams() { Long[] vals = new Long[]{-1L, 7L, 123L, null, 50L, 60L, 100L}; return Stream.of( - Arguments.of(Aggregation.sum(), new Long[0], null, 0.), - Arguments.of(Aggregation.sum(), new Long[]{null, null, null}, null, 0.), - Arguments.of(Aggregation.sum(), vals, 339L, 0.), - Arguments.of(Aggregation.min(), vals, -1L, 0.), - Arguments.of(Aggregation.max(), vals, 123L, 0.), - Arguments.of(Aggregation.product(), vals, -258300000L, 0.), - Arguments.of(Aggregation.sumOfSquares(), vals, 31279L, 0.), - Arguments.of(Aggregation.mean(), vals, 56.5, DELTAD), - Arguments.of(Aggregation.standardDeviation(), vals, 49.24530434467839, DELTAD), - Arguments.of(Aggregation.variance(), vals, 2425.1, DELTAD), - Arguments.of(Aggregation.any(), vals, true, 0.), - Arguments.of(Aggregation.all(), vals, true, 0.), - Arguments.of(Aggregation.quantile(0.5), vals, 55.0, DELTAD), - Arguments.of(Aggregation.quantile(0.9), vals, 111.5, DELTAD) + Arguments.of(ReductionAggregation.sum(), new Long[0], null, 0.), + Arguments.of(ReductionAggregation.sum(), new Long[]{null, null, null}, null, 0.), + Arguments.of(ReductionAggregation.sum(), vals, 339L, 0.), + 
Arguments.of(ReductionAggregation.min(), vals, -1L, 0.), + Arguments.of(ReductionAggregation.max(), vals, 123L, 0.), + Arguments.of(ReductionAggregation.product(), vals, -258300000L, 0.), + Arguments.of(ReductionAggregation.sumOfSquares(), vals, 31279L, 0.), + Arguments.of(ReductionAggregation.mean(), vals, 56.5, DELTAD), + Arguments.of(ReductionAggregation.standardDeviation(), vals, 49.24530434467839, DELTAD), + Arguments.of(ReductionAggregation.variance(), vals, 2425.1, DELTAD), + Arguments.of(ReductionAggregation.any(), vals, true, 0.), + Arguments.of(ReductionAggregation.all(), vals, true, 0.), + Arguments.of(ReductionAggregation.quantile(0.5), vals, 55.0, DELTAD), + Arguments.of(ReductionAggregation.quantile(0.9), vals, 111.5, DELTAD) ); } private static Stream createFloatParams() { Float[] vals = new Float[]{-1f, 7f, 123f, null, 50f, 60f, 100f}; return Stream.of( - Arguments.of(Aggregation.sum(), new Float[0], null, 0f), - Arguments.of(Aggregation.sum(), new Float[]{null, null, null}, null, 0f), - Arguments.of(Aggregation.sum(), vals, 339f, 0f), - Arguments.of(Aggregation.min(), vals, -1f, 0f), - Arguments.of(Aggregation.max(), vals, 123f, 0f), - Arguments.of(Aggregation.product(), vals, -258300000f, 0f), - Arguments.of(Aggregation.sumOfSquares(), vals, 31279f, 0f), - Arguments.of(Aggregation.mean(), vals, 56.5f, DELTAF), - Arguments.of(Aggregation.standardDeviation(), vals, 49.24530434467839f, DELTAF), - Arguments.of(Aggregation.variance(), vals, 2425.1f, DELTAF), - Arguments.of(Aggregation.any(), vals, true, 0f), - Arguments.of(Aggregation.all(), vals, true, 0f) + Arguments.of(ReductionAggregation.sum(), new Float[0], null, 0f), + Arguments.of(ReductionAggregation.sum(), new Float[]{null, null, null}, null, 0f), + Arguments.of(ReductionAggregation.sum(), vals, 339f, 0f), + Arguments.of(ReductionAggregation.min(), vals, -1f, 0f), + Arguments.of(ReductionAggregation.max(), vals, 123f, 0f), + Arguments.of(ReductionAggregation.product(), vals, -258300000f, 0f), 
+ Arguments.of(ReductionAggregation.sumOfSquares(), vals, 31279f, 0f), + Arguments.of(ReductionAggregation.mean(), vals, 56.5f, DELTAF), + Arguments.of(ReductionAggregation.standardDeviation(), vals, 49.24530434467839f, DELTAF), + Arguments.of(ReductionAggregation.variance(), vals, 2425.1f, DELTAF), + Arguments.of(ReductionAggregation.any(), vals, true, 0f), + Arguments.of(ReductionAggregation.all(), vals, true, 0f) ); } private static Stream createDoubleParams() { Double[] vals = new Double[]{-1., 7., 123., null, 50., 60., 100.}; return Stream.of( - Arguments.of(Aggregation.sum(), new Double[0], null, 0.), - Arguments.of(Aggregation.sum(), new Double[]{null, null, null}, null, 0.), - Arguments.of(Aggregation.sum(), vals, 339., 0.), - Arguments.of(Aggregation.min(), vals, -1., 0.), - Arguments.of(Aggregation.max(), vals, 123., 0.), - Arguments.of(Aggregation.product(), vals, -258300000., 0.), - Arguments.of(Aggregation.sumOfSquares(), vals, 31279., 0.), - Arguments.of(Aggregation.mean(), vals, 56.5, DELTAD), - Arguments.of(Aggregation.standardDeviation(), vals, 49.24530434467839, DELTAD), - Arguments.of(Aggregation.variance(), vals, 2425.1, DELTAD), - Arguments.of(Aggregation.any(), vals, true, 0.), - Arguments.of(Aggregation.all(), vals, true, 0.), - Arguments.of(Aggregation.quantile(0.5), vals, 55.0, DELTAD), - Arguments.of(Aggregation.quantile(0.9), vals, 111.5, DELTAD) + Arguments.of(ReductionAggregation.sum(), new Double[0], null, 0.), + Arguments.of(ReductionAggregation.sum(), new Double[]{null, null, null}, null, 0.), + Arguments.of(ReductionAggregation.sum(), vals, 339., 0.), + Arguments.of(ReductionAggregation.min(), vals, -1., 0.), + Arguments.of(ReductionAggregation.max(), vals, 123., 0.), + Arguments.of(ReductionAggregation.product(), vals, -258300000., 0.), + Arguments.of(ReductionAggregation.sumOfSquares(), vals, 31279., 0.), + Arguments.of(ReductionAggregation.mean(), vals, 56.5, DELTAD), + Arguments.of(ReductionAggregation.standardDeviation(), vals, 
49.24530434467839, DELTAD), + Arguments.of(ReductionAggregation.variance(), vals, 2425.1, DELTAD), + Arguments.of(ReductionAggregation.any(), vals, true, 0.), + Arguments.of(ReductionAggregation.all(), vals, true, 0.), + Arguments.of(ReductionAggregation.quantile(0.5), vals, 55.0, DELTAD), + Arguments.of(ReductionAggregation.quantile(0.9), vals, 111.5, DELTAD) ); } private static Stream createTimestampDaysParams() { Integer[] vals = new Integer[]{-1, 7, 123, null, 50, 60, 100}; return Stream.of( - Arguments.of(Aggregation.max(), new Integer[0], null), - Arguments.of(Aggregation.max(), new Integer[]{null, null, null}, null), - Arguments.of(Aggregation.max(), vals, 123), - Arguments.of(Aggregation.min(), vals, -1) + Arguments.of(ReductionAggregation.max(), new Integer[0], null), + Arguments.of(ReductionAggregation.max(), new Integer[]{null, null, null}, null), + Arguments.of(ReductionAggregation.max(), vals, 123), + Arguments.of(ReductionAggregation.min(), vals, -1) ); } private static Stream createTimestampResolutionParams() { Long[] vals = new Long[]{-1L, 7L, 123L, null, 50L, 60L, 100L}; return Stream.of( - Arguments.of(Aggregation.max(), new Long[0], null), - Arguments.of(Aggregation.max(), new Long[]{null, null, null}, null), - Arguments.of(Aggregation.min(), vals, -1L), - Arguments.of(Aggregation.max(), vals, 123L) + Arguments.of(ReductionAggregation.max(), new Long[0], null), + Arguments.of(ReductionAggregation.max(), new Long[]{null, null, null}, null), + Arguments.of(ReductionAggregation.min(), vals, -1L), + Arguments.of(ReductionAggregation.max(), vals, 123L) ); } - private static void assertEqualsDelta(Aggregation op, Scalar expected, Scalar result, + private static void assertEqualsDelta(ReductionAggregation op, Scalar expected, Scalar result, Double percentage) { - if (FLOAT_REDUCTIONS.contains(op.kind)) { + if (FLOAT_REDUCTIONS.contains(op.getWrapped().kind)) { assertEqualsWithinPercentage(expected.getDouble(), result.getDouble(), percentage); } else { 
assertEquals(expected, result); } } - private static void assertEqualsDelta(Aggregation op, Scalar expected, Scalar result, + private static void assertEqualsDelta(ReductionAggregation op, Scalar expected, Scalar result, Float percentage) { - if (FLOAT_REDUCTIONS.contains(op.kind)) { + if (FLOAT_REDUCTIONS.contains(op.getWrapped().kind)) { assertEqualsWithinPercentage(expected.getFloat(), result.getFloat(), percentage); } else { assertEquals(expected, result); @@ -255,7 +255,7 @@ private static void assertEqualsDelta(Aggregation op, Scalar expected, Scalar re @ParameterizedTest @MethodSource("createBooleanParams") - void testBoolean(Aggregation op, Boolean[] values, Object expectedObject, Double delta) { + void testBoolean(ReductionAggregation op, Boolean[] values, Object expectedObject, Double delta) { try (Scalar expected = buildExpectedScalar(op, DType.BOOL8, expectedObject); ColumnVector v = ColumnVector.fromBoxedBooleans(values); Scalar result = v.reduce(op, expected.getType())) { @@ -265,7 +265,7 @@ void testBoolean(Aggregation op, Boolean[] values, Object expectedObject, Double @ParameterizedTest @MethodSource("createByteParams") - void testByte(Aggregation op, Byte[] values, Object expectedObject, Double delta) { + void testByte(ReductionAggregation op, Byte[] values, Object expectedObject, Double delta) { try (Scalar expected = buildExpectedScalar(op, DType.INT8, expectedObject); ColumnVector v = ColumnVector.fromBoxedBytes(values); Scalar result = v.reduce(op, expected.getType())) { @@ -275,7 +275,7 @@ void testByte(Aggregation op, Byte[] values, Object expectedObject, Double delta @ParameterizedTest @MethodSource("createShortParams") - void testShort(Aggregation op, Short[] values, Object expectedObject, Double delta) { + void testShort(ReductionAggregation op, Short[] values, Object expectedObject, Double delta) { try (Scalar expected = buildExpectedScalar(op, DType.INT16, expectedObject); ColumnVector v = ColumnVector.fromBoxedShorts(values); Scalar 
result = v.reduce(op, expected.getType())) { @@ -285,7 +285,7 @@ void testShort(Aggregation op, Short[] values, Object expectedObject, Double del @ParameterizedTest @MethodSource("createIntParams") - void testInt(Aggregation op, Integer[] values, Object expectedObject, Double delta) { + void testInt(ReductionAggregation op, Integer[] values, Object expectedObject, Double delta) { try (Scalar expected = buildExpectedScalar(op, DType.INT32, expectedObject); ColumnVector v = ColumnVector.fromBoxedInts(values); Scalar result = v.reduce(op, expected.getType())) { @@ -295,7 +295,7 @@ void testInt(Aggregation op, Integer[] values, Object expectedObject, Double del @ParameterizedTest @MethodSource("createLongParams") - void testLong(Aggregation op, Long[] values, Object expectedObject, Double delta) { + void testLong(ReductionAggregation op, Long[] values, Object expectedObject, Double delta) { try (Scalar expected = buildExpectedScalar(op, DType.INT64, expectedObject); ColumnVector v = ColumnVector.fromBoxedLongs(values); Scalar result = v.reduce(op, expected.getType())) { @@ -305,7 +305,7 @@ void testLong(Aggregation op, Long[] values, Object expectedObject, Double delta @ParameterizedTest @MethodSource("createFloatParams") - void testFloat(Aggregation op, Float[] values, Object expectedObject, Float delta) { + void testFloat(ReductionAggregation op, Float[] values, Object expectedObject, Float delta) { try (Scalar expected = buildExpectedScalar(op, DType.FLOAT32, expectedObject); ColumnVector v = ColumnVector.fromBoxedFloats(values); Scalar result = v.reduce(op, expected.getType())) { @@ -315,7 +315,7 @@ void testFloat(Aggregation op, Float[] values, Object expectedObject, Float delt @ParameterizedTest @MethodSource("createDoubleParams") - void testDouble(Aggregation op, Double[] values, Object expectedObject, Double delta) { + void testDouble(ReductionAggregation op, Double[] values, Object expectedObject, Double delta) { try (Scalar expected = buildExpectedScalar(op, 
DType.FLOAT64, expectedObject); ColumnVector v = ColumnVector.fromBoxedDoubles(values); Scalar result = v.reduce(op, expected.getType())) { @@ -325,7 +325,7 @@ void testDouble(Aggregation op, Double[] values, Object expectedObject, Double d @ParameterizedTest @MethodSource("createTimestampDaysParams") - void testTimestampDays(Aggregation op, Integer[] values, Object expectedObject) { + void testTimestampDays(ReductionAggregation op, Integer[] values, Object expectedObject) { try (Scalar expected = buildExpectedScalar(op, DType.TIMESTAMP_DAYS, expectedObject); ColumnVector v = ColumnVector.timestampDaysFromBoxedInts(values); Scalar result = v.reduce(op, expected.getType())) { @@ -335,7 +335,7 @@ void testTimestampDays(Aggregation op, Integer[] values, Object expectedObject) @ParameterizedTest @MethodSource("createTimestampResolutionParams") - void testTimestampSeconds(Aggregation op, Long[] values, Object expectedObject) { + void testTimestampSeconds(ReductionAggregation op, Long[] values, Object expectedObject) { try (Scalar expected = buildExpectedScalar(op, DType.TIMESTAMP_SECONDS, expectedObject); ColumnVector v = ColumnVector.timestampSecondsFromBoxedLongs(values); Scalar result = v.reduce(op, expected.getType())) { @@ -345,7 +345,7 @@ void testTimestampSeconds(Aggregation op, Long[] values, Object expectedObject) @ParameterizedTest @MethodSource("createTimestampResolutionParams") - void testTimestampMilliseconds(Aggregation op, Long[] values, Object expectedObject) { + void testTimestampMilliseconds(ReductionAggregation op, Long[] values, Object expectedObject) { try (Scalar expected = buildExpectedScalar(op, DType.TIMESTAMP_MILLISECONDS, expectedObject); ColumnVector v = ColumnVector.timestampMilliSecondsFromBoxedLongs(values); Scalar result = v.reduce(op, expected.getType())) { @@ -355,7 +355,7 @@ void testTimestampMilliseconds(Aggregation op, Long[] values, Object expectedObj @ParameterizedTest @MethodSource("createTimestampResolutionParams") - void 
testTimestampMicroseconds(Aggregation op, Long[] values, Object expectedObject) { + void testTimestampMicroseconds(ReductionAggregation op, Long[] values, Object expectedObject) { try (Scalar expected = buildExpectedScalar(op, DType.TIMESTAMP_MICROSECONDS, expectedObject); ColumnVector v = ColumnVector.timestampMicroSecondsFromBoxedLongs(values); Scalar result = v.reduce(op, expected.getType())) { @@ -365,7 +365,7 @@ void testTimestampMicroseconds(Aggregation op, Long[] values, Object expectedObj @ParameterizedTest @MethodSource("createTimestampResolutionParams") - void testTimestampNanoseconds(Aggregation op, Long[] values, Object expectedObject) { + void testTimestampNanoseconds(ReductionAggregation op, Long[] values, Object expectedObject) { try (Scalar expected = buildExpectedScalar(op, DType.TIMESTAMP_NANOSECONDS, expectedObject); ColumnVector v = ColumnVector.timestampNanoSecondsFromBoxedLongs(values); Scalar result = v.reduce(op, expected.getType())) { From bfdd972d8b24ba34124205427841eb76aa85097b Mon Sep 17 00:00:00 2001 From: "Robert (Bobby) Evans" Date: Fri, 30 Jul 2021 17:17:12 -0500 Subject: [PATCH 5/7] Group by aggregation --- .../main/java/ai/rapids/cudf/Aggregation.java | 125 ++++++++---------- ...n.java => GroupByAggregationOnColumn.java} | 14 +- java/src/main/java/ai/rapids/cudf/Table.java | 6 +- .../test/java/ai/rapids/cudf/TableTest.java | 102 +++++++------- 4 files changed, 119 insertions(+), 128 deletions(-) rename java/src/main/java/ai/rapids/cudf/{AggregationOnColumn.java => GroupByAggregationOnColumn.java} (74%) diff --git a/java/src/main/java/ai/rapids/cudf/Aggregation.java b/java/src/main/java/ai/rapids/cudf/Aggregation.java index 62ca27c732e..0480dc465ab 100644 --- a/java/src/main/java/ai/rapids/cudf/Aggregation.java +++ b/java/src/main/java/ai/rapids/cudf/Aggregation.java @@ -24,7 +24,7 @@ * Represents an aggregation operation. Please note that not all aggregations work, or even make * sense in all types of aggregation operations.
*/ -public abstract class Aggregation { +abstract class Aggregation { static { NativeDepsLoader.loadNativeDeps(); } @@ -102,7 +102,7 @@ public boolean equals(Object other) { } } - public static final class NthAggregation extends Aggregation { + static final class NthAggregation extends Aggregation { private final int offset; private final NullPolicy nullPolicy; @@ -275,7 +275,7 @@ long getDefaultOutput() { } } - public static final class CollectListAggregation extends Aggregation { + static final class CollectListAggregation extends Aggregation { private final NullPolicy nullPolicy; private CollectListAggregation(NullPolicy nullPolicy) { @@ -305,7 +305,7 @@ public boolean equals(Object other) { } } - public static final class CollectSetAggregation extends Aggregation { + static final class CollectSetAggregation extends Aggregation { private final NullPolicy nullPolicy; private final NullEquality nullEquality; private final NaNEquality nanEquality; @@ -346,7 +346,7 @@ public boolean equals(Object other) { } } - public static final class MergeSetsAggregation extends Aggregation { + static final class MergeSetsAggregation extends Aggregation { private final NullEquality nullEquality; private final NaNEquality nanEquality; @@ -386,14 +386,6 @@ protected Aggregation(Kind kind) { this.kind = kind; } - /** - * Add a column to the Aggregation so it can be used on a specific column of data. - * @param columnIndex the index of the column to operate on. - */ - public AggregationOnColumn onColumn(int columnIndex) { - return new AggregationOnColumn(this, columnIndex); - } - /** * Get the native view of a ColumnVector that provides default values to be used for some window * aggregations when there is not enough data to do the computation. 
This really only happens @@ -431,7 +423,7 @@ static void close(long[] ptrs) { static native void close(long ptr); - public static class SumAggregation extends NoParamAggregation { + static class SumAggregation extends NoParamAggregation { private SumAggregation() { super(Kind.SUM); } @@ -440,11 +432,11 @@ private SumAggregation() { /** * Sum reduction. */ - public static SumAggregation sum() { + static SumAggregation sum() { return new SumAggregation(); } - public static class ProductAggregation extends NoParamAggregation { + static class ProductAggregation extends NoParamAggregation { private ProductAggregation() { super(Kind.PRODUCT); } @@ -453,11 +445,11 @@ private ProductAggregation() { /** * Product reduction. */ - public static ProductAggregation product() { + static ProductAggregation product() { return new ProductAggregation(); } - public static class MinAggregation extends NoParamAggregation { + static class MinAggregation extends NoParamAggregation { private MinAggregation() { super(Kind.MIN); } @@ -466,11 +458,11 @@ private MinAggregation() { /** * Min reduction. */ - public static MinAggregation min() { + static MinAggregation min() { return new MinAggregation(); } - public static class MaxAggregation extends NoParamAggregation { + static class MaxAggregation extends NoParamAggregation { private MaxAggregation() { super(Kind.MAX); } @@ -479,11 +471,11 @@ private MaxAggregation() { /** * Max reduction. */ - public static MaxAggregation max() { + static MaxAggregation max() { return new MaxAggregation(); } - public static class CountAggregation extends CountLikeAggregation { + static class CountAggregation extends CountLikeAggregation { private CountAggregation(NullPolicy nullPolicy) { super(Kind.COUNT, nullPolicy); } @@ -492,7 +484,7 @@ private CountAggregation(NullPolicy nullPolicy) { /** * Count number of valid, a.k.a. non-null, elements. 
*/ - public static CountAggregation count() { + static CountAggregation count() { return count(NullPolicy.EXCLUDE); } @@ -501,11 +493,11 @@ public static CountAggregation count() { * @param nullPolicy INCLUDE if nulls should be counted. EXCLUDE if only non-null values * should be counted. */ - public static CountAggregation count(NullPolicy nullPolicy) { + static CountAggregation count(NullPolicy nullPolicy) { return new CountAggregation(nullPolicy); } - public static class AnyAggregation extends NoParamAggregation { + static class AnyAggregation extends NoParamAggregation { private AnyAggregation() { super(Kind.ANY); } @@ -516,11 +508,11 @@ private AnyAggregation() { * if any of the elements in the range are true or non-zero, otherwise produces a false or 0. * Null values are skipped. */ - public static AnyAggregation any() { + static AnyAggregation any() { return new AnyAggregation(); } - public static class AllAggregation extends NoParamAggregation { + static class AllAggregation extends NoParamAggregation { private AllAggregation() { super(Kind.ALL); } @@ -531,12 +523,11 @@ private AllAggregation() { * the range are true or non-zero, otherwise produces a false or 0. * Null values are skipped. */ - public static AllAggregation all() { + static AllAggregation all() { return new AllAggregation(); } - - public static class SumOfSquaresAggregation extends NoParamAggregation { + static class SumOfSquaresAggregation extends NoParamAggregation { private SumOfSquaresAggregation() { super(Kind.SUM_OF_SQUARES); } @@ -545,11 +536,11 @@ private SumOfSquaresAggregation() { /** * Sum of squares reduction. 
*/ - public static SumOfSquaresAggregation sumOfSquares() { + static SumOfSquaresAggregation sumOfSquares() { return new SumOfSquaresAggregation(); } - public static class MeanAggregation extends NoParamAggregation { + static class MeanAggregation extends NoParamAggregation { private MeanAggregation() { super(Kind.MEAN); } @@ -558,11 +549,11 @@ private MeanAggregation() { /** * Arithmetic mean reduction. */ - public static MeanAggregation mean() { + static MeanAggregation mean() { return new MeanAggregation(); } - public static class M2Aggregation extends NoParamAggregation { + static class M2Aggregation extends NoParamAggregation { private M2Aggregation() { super(Kind.M2); } @@ -571,11 +562,11 @@ private M2Aggregation() { /** * Sum of square of differences from mean. */ - public static M2Aggregation M2() { + static M2Aggregation M2() { return new M2Aggregation(); } - public static class VarianceAggregation extends DdofAggregation { + static class VarianceAggregation extends DdofAggregation { private VarianceAggregation(int ddof) { super(Kind.VARIANCE, ddof); } @@ -584,7 +575,7 @@ private VarianceAggregation(int ddof) { /** * Variance aggregation with 1 as the delta degrees of freedom. */ - public static VarianceAggregation variance() { + static VarianceAggregation variance() { return variance(1); } @@ -593,12 +584,12 @@ public static VarianceAggregation variance() { * @param ddof delta degrees of freedom. The divisor used in calculation of variance is * N - ddof, where N is the population size. 
*/ - public static VarianceAggregation variance(int ddof) { + static VarianceAggregation variance(int ddof) { return new VarianceAggregation(ddof); } - public static class StandardDeviationAggregation extends DdofAggregation { + static class StandardDeviationAggregation extends DdofAggregation { private StandardDeviationAggregation(int ddof) { super(Kind.STD, ddof); } @@ -607,7 +598,7 @@ private StandardDeviationAggregation(int ddof) { /** * Standard deviation aggregation with 1 as the delta degrees of freedom. */ - public static StandardDeviationAggregation standardDeviation() { + static StandardDeviationAggregation standardDeviation() { return standardDeviation(1); } @@ -616,11 +607,11 @@ public static StandardDeviationAggregation standardDeviation() { * @param ddof delta degrees of freedom. The divisor used in calculation of std is * N - ddof, where N is the population size. */ - public static StandardDeviationAggregation standardDeviation(int ddof) { + static StandardDeviationAggregation standardDeviation(int ddof) { return new StandardDeviationAggregation(ddof); } - public static class MedianAggregation extends NoParamAggregation { + static class MedianAggregation extends NoParamAggregation { private MedianAggregation() { super(Kind.MEDIAN); } @@ -629,25 +620,25 @@ private MedianAggregation() { /** * Median reduction. */ - public static MedianAggregation median() { + static MedianAggregation median() { return new MedianAggregation(); } /** * Aggregate to compute the specified quantiles. Uses linear interpolation by default. */ - public static QuantileAggregation quantile(double ... quantiles) { + static QuantileAggregation quantile(double ... quantiles) { return quantile(QuantileMethod.LINEAR, quantiles); } /** * Aggregate to compute various quantiles. */ - public static QuantileAggregation quantile(QuantileMethod method, double ... quantiles) { + static QuantileAggregation quantile(QuantileMethod method, double ... 
quantiles) { return new QuantileAggregation(method, quantiles); } - public static class ArgMaxAggregation extends NoParamAggregation { + static class ArgMaxAggregation extends NoParamAggregation { private ArgMaxAggregation() { super(Kind.ARGMAX); } @@ -659,11 +650,11 @@ private ArgMaxAggregation() { * prior to doing the aggregation. This would result in an index into the sorted data being * returned. */ - public static ArgMaxAggregation argMax() { + static ArgMaxAggregation argMax() { return new ArgMaxAggregation(); } - public static class ArgMinAggregation extends NoParamAggregation { + static class ArgMinAggregation extends NoParamAggregation { private ArgMinAggregation() { super(Kind.ARGMIN); } @@ -675,11 +666,11 @@ private ArgMinAggregation() { * prior to doing the aggregation. This would result in an index into the sorted data being * returned. */ - public static ArgMinAggregation argMin() { + static ArgMinAggregation argMin() { return new ArgMinAggregation(); } - public static class NuniqueAggregation extends CountLikeAggregation { + static class NuniqueAggregation extends CountLikeAggregation { private NuniqueAggregation(NullPolicy nullPolicy) { super(Kind.NUNIQUE, nullPolicy); } @@ -688,7 +679,7 @@ private NuniqueAggregation(NullPolicy nullPolicy) { /** * Number of unique, non-null, elements. */ - public static NuniqueAggregation nunique() { + static NuniqueAggregation nunique() { return nunique(NullPolicy.EXCLUDE); } @@ -698,7 +689,7 @@ public static NuniqueAggregation nunique() { * compare as equal so multiple null values in a range would all only * increase the count by 1. */ - public static NuniqueAggregation nunique(NullPolicy nullPolicy) { + static NuniqueAggregation nunique(NullPolicy nullPolicy) { return new NuniqueAggregation(nullPolicy); } @@ -707,7 +698,7 @@ public static NuniqueAggregation nunique(NullPolicy nullPolicy) { * @param offset the offset to look at. Negative numbers go from the end of the group. 
Any * value outside of the group range results in a null. */ - public static NthAggregation nth(int offset) { + static NthAggregation nth(int offset) { return nth(offset, NullPolicy.INCLUDE); } @@ -718,7 +709,7 @@ public static NthAggregation nth(int offset) { * @param nullPolicy INCLUDE if nulls should be included in the aggregation or EXCLUDE if they * should be skipped. */ - public static NthAggregation nth(int offset, NullPolicy nullPolicy) { + static NthAggregation nth(int offset, NullPolicy nullPolicy) { return new NthAggregation(offset, nullPolicy); } @@ -735,7 +726,7 @@ static RowNumberAggregation rowNumber() { return new RowNumberAggregation(); } - public static class RankAggregation extends NoParamAggregation { + static class RankAggregation extends NoParamAggregation { private RankAggregation() { super(Kind.RANK); } @@ -744,11 +735,11 @@ private RankAggregation() { /** * Get the row's ranking. */ - public static RankAggregation rank() { + static RankAggregation rank() { return new RankAggregation(); } - public static class DenseRankAggregation extends NoParamAggregation { + static class DenseRankAggregation extends NoParamAggregation { private DenseRankAggregation() { super(Kind.DENSE_RANK); } @@ -757,14 +748,14 @@ private DenseRankAggregation() { /** * Get the row's dense ranking. */ - public static DenseRankAggregation denseRank() { + static DenseRankAggregation denseRank() { return new DenseRankAggregation(); } /** * Collect the values into a list. Nulls will be skipped. */ - public static CollectListAggregation collectList() { + static CollectListAggregation collectList() { return collectList(NullPolicy.EXCLUDE); } @@ -773,7 +764,7 @@ public static CollectListAggregation collectList() { * * @param nullPolicy Indicates whether to include/exclude nulls during collection. 
*/ - public static CollectListAggregation collectList(NullPolicy nullPolicy) { + static CollectListAggregation collectList(NullPolicy nullPolicy) { return new CollectListAggregation(nullPolicy); } @@ -781,7 +772,7 @@ public static CollectListAggregation collectList(NullPolicy nullPolicy) { * Collect the values into a set. All null values will be excluded, and all nan values are regarded as * unique instances. */ - public static CollectSetAggregation collectSet() { + static CollectSetAggregation collectSet() { return collectSet(NullPolicy.EXCLUDE, NullEquality.UNEQUAL, NaNEquality.UNEQUAL); } @@ -792,11 +783,11 @@ public static CollectSetAggregation collectSet() { * @param nullEquality Flag to specify whether null entries within each list should be considered equal. * @param nanEquality Flag to specify whether NaN values in floating point column should be considered equal. */ - public static CollectSetAggregation collectSet(NullPolicy nullPolicy, NullEquality nullEquality, NaNEquality nanEquality) { + static CollectSetAggregation collectSet(NullPolicy nullPolicy, NullEquality nullEquality, NaNEquality nanEquality) { return new CollectSetAggregation(nullPolicy, nullEquality, nanEquality); } - public static final class MergeListsAggregation extends NoParamAggregation { + static final class MergeListsAggregation extends NoParamAggregation { private MergeListsAggregation() { super(Kind.MERGE_LISTS); } @@ -806,7 +797,7 @@ private MergeListsAggregation() { * Merge the partial lists produced by multiple CollectListAggregations. * NOTICE: The partial lists to be merged should NOT include any null list element (but can include null list entries). */ - public static MergeListsAggregation mergeLists() { + static MergeListsAggregation mergeLists() { return new MergeListsAggregation(); } @@ -814,7 +805,7 @@ public static MergeListsAggregation mergeLists() { * Merge the partial sets produced by multiple CollectSetAggregations. 
Each null/nan value will be regarded as * a unique instance. */ - public static MergeSetsAggregation mergeSets() { + static MergeSetsAggregation mergeSets() { return mergeSets(NullEquality.UNEQUAL, NaNEquality.UNEQUAL); } @@ -824,7 +815,7 @@ public static MergeSetsAggregation mergeSets() { * @param nullEquality Flag to specify whether null entries within each list should be considered equal. * @param nanEquality Flag to specify whether NaN values in floating point column should be considered equal. */ - public static MergeSetsAggregation mergeSets(NullEquality nullEquality, NaNEquality nanEquality) { + static MergeSetsAggregation mergeSets(NullEquality nullEquality, NaNEquality nanEquality) { return new MergeSetsAggregation(nullEquality, nanEquality); } @@ -869,7 +860,7 @@ private MergeM2Aggregation() { /** * Merge the partial M2 values produced by multiple instances of M2Aggregation. */ - public static MergeM2Aggregation mergeM2() { + static MergeM2Aggregation mergeM2() { return new MergeM2Aggregation(); } diff --git a/java/src/main/java/ai/rapids/cudf/AggregationOnColumn.java b/java/src/main/java/ai/rapids/cudf/GroupByAggregationOnColumn.java similarity index 74% rename from java/src/main/java/ai/rapids/cudf/AggregationOnColumn.java rename to java/src/main/java/ai/rapids/cudf/GroupByAggregationOnColumn.java index 2d9364b9705..43c8dbe888e 100644 --- a/java/src/main/java/ai/rapids/cudf/AggregationOnColumn.java +++ b/java/src/main/java/ai/rapids/cudf/GroupByAggregationOnColumn.java @@ -19,13 +19,13 @@ package ai.rapids.cudf; /** - * An Aggregation for a specific column in a table. + * A GroupByAggregation for a specific column in a table. 
*/ -public class AggregationOnColumn { - protected final Aggregation wrapped; +public class GroupByAggregationOnColumn { + protected final GroupByAggregation wrapped; protected final int columnIndex; - AggregationOnColumn(Aggregation wrapped, int columnIndex) { + GroupByAggregationOnColumn(GroupByAggregation wrapped, int columnIndex) { this.wrapped = wrapped; this.columnIndex = columnIndex; } @@ -34,7 +34,7 @@ public int getColumnIndex() { return columnIndex; } - Aggregation getWrapped() { + GroupByAggregation getWrapped() { return wrapped; } @@ -47,8 +47,8 @@ public int hashCode() { public boolean equals(Object other) { if (other == this) { return true; - } else if (other instanceof AggregationOnColumn) { - AggregationOnColumn o = (AggregationOnColumn) other; + } else if (other instanceof GroupByAggregationOnColumn) { + GroupByAggregationOnColumn o = (GroupByAggregationOnColumn) other; return wrapped.equals(o.wrapped) && columnIndex == o.columnIndex; } return false; diff --git a/java/src/main/java/ai/rapids/cudf/Table.java b/java/src/main/java/ai/rapids/cudf/Table.java index 746bef6e939..360bb5c7467 100644 --- a/java/src/main/java/ai/rapids/cudf/Table.java +++ b/java/src/main/java/ai/rapids/cudf/Table.java @@ -2456,7 +2456,7 @@ public static final class GroupByOperation { * 1, 2 * 2, 1 ==> aggregated count */ - public Table aggregate(AggregationOnColumn... aggregates) { + public Table aggregate(GroupByAggregationOnColumn... aggregates) { assert aggregates != null; // To improve performance and memory we want to remove duplicate operations @@ -2469,9 +2469,9 @@ public Table aggregate(AggregationOnColumn... 
aggregates) { int keysLength = operation.indices.length; int totalOps = 0; for (int outputIndex = 0; outputIndex < aggregates.length; outputIndex++) { - AggregationOnColumn agg = aggregates[outputIndex]; + GroupByAggregationOnColumn agg = aggregates[outputIndex]; ColumnOps ops = groupedOps.computeIfAbsent(agg.getColumnIndex(), (idx) -> new ColumnOps()); - totalOps += ops.add(agg.getWrapped(), outputIndex + keysLength); + totalOps += ops.add(agg.getWrapped().getWrapped(), outputIndex + keysLength); } int[] aggColumnIndexes = new int[totalOps]; long[] aggOperationInstances = new long[totalOps]; diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index b6b8f197630..1b2ed1ad0b8 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -2957,7 +2957,7 @@ void testGroupByUniqueCount() { .build()) { try (Table t3 = t1 .groupBy(0, 1) - .aggregate(Aggregation.nunique().onColumn(0)); + .aggregate(GroupByAggregation.nunique().onColumn(0)); Table sorted = t3.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(2)); Table expected = new Table.TestBuilder() .column( "1", "1", "1", "1") @@ -2978,7 +2978,7 @@ void testGroupByUniqueCountNulls() { .build()) { try (Table t3 = t1 .groupBy(0, 1) - .aggregate(Aggregation.nunique(NullPolicy.INCLUDE).onColumn(0)); + .aggregate(GroupByAggregation.nunique(NullPolicy.INCLUDE).onColumn(0)); Table sorted = t3.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(2)); Table expected = new Table.TestBuilder() .column( "1", "1", "1", "1") @@ -2997,7 +2997,7 @@ void testGroupByCount() { .column(12.0, 14.0, 13.0, 17.0, 17.0, 17.0) .build()) { try (Table t3 = t1.groupBy(0, 1) - .aggregate(Aggregation.count().onColumn(0)); + .aggregate(GroupByAggregation.count().onColumn(0)); HostColumnVector aggOut1 = t3.getColumn(2).copyToHost()) { // verify t3 assertEquals(4, t3.getRowCount()); @@ -4784,9 +4784,9 @@ void 
testGroupByCountWithNulls() { .column( 1, 1, 1, null, 1, 1) .build()) { try (Table tmp = t1.groupBy(0).aggregate( - Aggregation.count().onColumn(1), - Aggregation.count().onColumn(2), - Aggregation.count().onColumn(3)); + GroupByAggregation.count().onColumn(1), + GroupByAggregation.count().onColumn(2), + GroupByAggregation.count().onColumn(3)); Table t3 = tmp.orderBy(OrderByArg.asc(0, true)); HostColumnVector groupCol = t3.getColumn(0).copyToHost(); HostColumnVector countCol = t3.getColumn(1).copyToHost(); @@ -4824,10 +4824,10 @@ void testGroupByCountWithNullsIncluded() { .column( 1, 1, 1, null, 1, 1) .build()) { try (Table tmp = t1.groupBy(0).aggregate( - Aggregation.count(NullPolicy.INCLUDE).onColumn(1), - Aggregation.count(NullPolicy.INCLUDE).onColumn(2), - Aggregation.count(NullPolicy.INCLUDE).onColumn(3), - Aggregation.count().onColumn(3)); + GroupByAggregation.count(NullPolicy.INCLUDE).onColumn(1), + GroupByAggregation.count(NullPolicy.INCLUDE).onColumn(2), + GroupByAggregation.count(NullPolicy.INCLUDE).onColumn(3), + GroupByAggregation.count().onColumn(3)); Table t3 = tmp.orderBy(OrderByArg.asc(0, true)); HostColumnVector groupCol = t3.getColumn(0).copyToHost(); HostColumnVector countCol = t3.getColumn(1).copyToHost(); @@ -4875,9 +4875,9 @@ void testGroupByCountWithCollapsingNulls() { .build(); try (Table tmp = t1.groupBy(options, 0).aggregate( - Aggregation.count().onColumn(1), - Aggregation.count().onColumn(2), - Aggregation.count().onColumn(3)); + GroupByAggregation.count().onColumn(1), + GroupByAggregation.count().onColumn(2), + GroupByAggregation.count().onColumn(3)); Table t3 = tmp.orderBy(OrderByArg.asc(0, true)); HostColumnVector groupCol = t3.getColumn(0).copyToHost(); HostColumnVector countCol = t3.getColumn(1).copyToHost(); @@ -4908,7 +4908,7 @@ void testGroupByMax() { .column( 1, 3, 3, 5, 5, 0) .column(12.0, 14.0, 13.0, 17.0, 17.0, 17.0) .build()) { - try (Table t3 = t1.groupBy(0, 1).aggregate(Aggregation.max().onColumn(2)); + try (Table t3 = 
t1.groupBy(0, 1).aggregate(GroupByAggregation.max().onColumn(2)); HostColumnVector aggOut1 = t3.getColumn(2).copyToHost()) { // verify t3 assertEquals(4, t3.getRowCount()); @@ -4943,7 +4943,7 @@ void testGroupByArgMax() { .column(17.0, 14.0, 14.0, 17.0, 17.1, 17.0) .build()) { try (Table t3 = t1.groupBy(0, 1) - .aggregate(Aggregation.argMax().onColumn(2)); + .aggregate(GroupByAggregation.argMax().onColumn(2)); Table sorted = t3 .orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(2)); Table expected = new Table.TestBuilder() @@ -4965,7 +4965,7 @@ void testGroupByArgMin() { .column(17.0, 14.0, 14.0, 17.0, 17.1, 17.0) .build()) { try (Table t3 = t1.groupBy(0, 1) - .aggregate(Aggregation.argMin().onColumn(2)); + .aggregate(GroupByAggregation.argMin().onColumn(2)); Table sorted = t3 .orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(2)); Table expected = new Table.TestBuilder() @@ -4983,7 +4983,7 @@ void testGroupByMinBool() { try (Table t1 = new Table.TestBuilder() .column(true, null, false, true, null, null) .column( 1, 1, 2, 2, 3, 3).build(); - Table other = t1.groupBy(1).aggregate(Aggregation.min().onColumn(0)); + Table other = t1.groupBy(1).aggregate(GroupByAggregation.min().onColumn(0)); Table ordered = other.orderBy(OrderByArg.asc(0)); Table expected = new Table.TestBuilder() .column(1, 2, 3) @@ -4998,7 +4998,7 @@ void testGroupByMaxBool() { try (Table t1 = new Table.TestBuilder() .column(false, null, false, true, null, null) .column( 1, 1, 2, 2, 3, 3).build(); - Table other = t1.groupBy(1).aggregate(Aggregation.max().onColumn(0)); + Table other = t1.groupBy(1).aggregate(GroupByAggregation.max().onColumn(0)); Table ordered = other.orderBy(OrderByArg.asc(0)); Table expected = new Table.TestBuilder() .column(1, 2, 3) @@ -5025,12 +5025,12 @@ void testGroupByDuplicateAggregates() { .column( 1, 2, 2, 1).build()) { try (Table t3 = t1.groupBy(0, 1) .aggregate( - Aggregation.max().onColumn(2), - Aggregation.min().onColumn(2), - 
Aggregation.min().onColumn(2), - Aggregation.max().onColumn(2), - Aggregation.min().onColumn(2), - Aggregation.count().onColumn(1)); + GroupByAggregation.max().onColumn(2), + GroupByAggregation.min().onColumn(2), + GroupByAggregation.min().onColumn(2), + GroupByAggregation.max().onColumn(2), + GroupByAggregation.min().onColumn(2), + GroupByAggregation.count().onColumn(1)); Table t4 = t3.orderBy(OrderByArg.asc(2))) { // verify t4 assertEquals(4, t4.getRowCount()); @@ -5053,7 +5053,7 @@ void testGroupByMin() { .column( 1, 3, 3, 5, 5, 0) .column( 12, 14, 13, 17, 17, 17) .build()) { - try (Table t3 = t1.groupBy(0, 1).aggregate(Aggregation.min().onColumn(2)); + try (Table t3 = t1.groupBy(0, 1).aggregate(GroupByAggregation.min().onColumn(2)); HostColumnVector aggOut0 = t3.getColumn(2).copyToHost()) { // verify t3 assertEquals(4, t3.getRowCount()); @@ -5088,7 +5088,7 @@ void testGroupBySum() { .column( 1, 3, 3, 5, 5, 0) .column(12.0, 14.0, 13.0, 17.0, 17.0, 17.0) .build()) { - try (Table t3 = t1.groupBy(0, 1).aggregate(Aggregation.sum().onColumn(2)); + try (Table t3 = t1.groupBy(0, 1).aggregate(GroupByAggregation.sum().onColumn(2)); HostColumnVector aggOut1 = t3.getColumn(2).copyToHost()) { // verify t3 assertEquals(4, t3.getRowCount()); @@ -5121,7 +5121,7 @@ void testGroupByM2() { try (Table input = new Table.TestBuilder().column(1, 2, 3, 1, 2, 2, 1, 3, 3, 2) .column(0, 1, -2, 3, -4, -5, -6, 7, -8, 9) .build(); - Table results = input.groupBy(0).aggregate(Aggregation.M2() + Table results = input.groupBy(0).aggregate(GroupByAggregation.M2() .onColumn(1)); Table expected = new Table.TestBuilder().column(1, 2, 3) .column(42.0, 122.75, 114.0) @@ -5134,7 +5134,7 @@ void testGroupByM2() { try (Table input = new Table.TestBuilder().column(1, 2, 5, 3, 4, 5, 2, 3, 2, 5) .column(0, null, null, 2, 3, null, 5, 6, 7, null) .build(); - Table results = input.groupBy(0).aggregate(Aggregation.M2() + Table results = input.groupBy(0).aggregate(GroupByAggregation.M2() .onColumn(1)); Table 
expected = new Table.TestBuilder().column(1, 2, 3, 4, 5) .column(0.0, 2.0, 8.0, 0.0, null) @@ -5146,7 +5146,7 @@ void testGroupByM2() { try (Table input = new Table.TestBuilder().column(4, 3, 1, 2, 3, 1, 2, 2, 1, null, 3, 2, 4, 4) .column(null, null, 0.0, 1.0, 2.0, 3.0, 4.0, Double.NaN, 6.0, 7.0, 8.0, 9.0, 10.0, Double.NaN) .build(); - Table results = input.groupBy(0).aggregate(Aggregation.M2() + Table results = input.groupBy(0).aggregate(GroupByAggregation.M2() .onColumn(1)); Table expected = new Table.TestBuilder().column(1, 2, 3, 4, null) .column(18.0, Double.NaN, 18.0, Double.NaN, 0.0) @@ -5179,7 +5179,7 @@ void testGroupByM2() { Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY) .build(); - Table results = input.groupBy(0).aggregate(Aggregation.M2() + Table results = input.groupBy(0).aggregate(GroupByAggregation.M2() .onColumn(1)); Table expected = new Table.TestBuilder().column(1, 2, 3, 4, 5) .column(Double.NaN, Double.NaN, Double.NaN, Double.NaN, 12.5) @@ -5237,7 +5237,7 @@ void testGroupByMergeM2() { partialResults3, partialResults4); Table finalResults = concatenatedResults.groupBy(0).aggregate( - Aggregation.mergeM2().onColumn(1)) + GroupByAggregation.mergeM2().onColumn(1)) ) { assertTablesAreEqual(expected, finalResults); } @@ -5255,7 +5255,7 @@ void testGroupByFirstExcludeNulls() { .column(13, 14) .build(); Table found = input.groupBy(0).aggregate( - Aggregation.nth(0, NullPolicy.EXCLUDE).onColumn(1))) { + GroupByAggregation.nth(0, NullPolicy.EXCLUDE).onColumn(1))) { assertTablesAreEqual(expected, found); } } @@ -5271,7 +5271,7 @@ void testGroupByLastExcludeNulls() { .column(12, 15) .build(); Table found = input.groupBy(0).aggregate( - Aggregation.nth(-1, NullPolicy.EXCLUDE).onColumn(1))) { + GroupByAggregation.nth(-1, NullPolicy.EXCLUDE).onColumn(1))) { assertTablesAreEqual(expected, found); } } @@ -5287,7 +5287,7 @@ void testGroupByFirstIncludeNulls() { .column(null, 14) .build(); Table found = input.groupBy(0).aggregate( - Aggregation.nth(0, 
NullPolicy.INCLUDE).onColumn(1))) { + GroupByAggregation.nth(0, NullPolicy.INCLUDE).onColumn(1))) { assertTablesAreEqual(expected, found); } } @@ -5303,7 +5303,7 @@ void testGroupByLastIncludeNulls() { .column(12, null) .build(); Table found = input.groupBy(0).aggregate( - Aggregation.nth(-1, NullPolicy.INCLUDE).onColumn(1))) { + GroupByAggregation.nth(-1, NullPolicy.INCLUDE).onColumn(1))) { assertTablesAreEqual(expected, found); } } @@ -5314,7 +5314,7 @@ void testGroupByAvg() { .column( 1, 3, 3, 5, 5, 0) .column(12, 14, 13, 1, 17, 17) .build()) { - try (Table t3 = t1.groupBy(0, 1).aggregate(Aggregation.mean().onColumn(2)); + try (Table t3 = t1.groupBy(0, 1).aggregate(GroupByAggregation.mean().onColumn(2)); HostColumnVector aggOut1 = t3.getColumn(2).copyToHost()) { // verify t3 assertEquals(4, t3.getRowCount()); @@ -5349,11 +5349,11 @@ void testMultiAgg() { .column( 3, 1, 7, -1, 9, 0) .build()) { try (Table t2 = t1.groupBy(0, 1).aggregate( - Aggregation.count().onColumn(0), - Aggregation.max().onColumn(3), - Aggregation.min().onColumn(2), - Aggregation.mean().onColumn(2), - Aggregation.sum().onColumn(2)); + GroupByAggregation.count().onColumn(0), + GroupByAggregation.max().onColumn(3), + GroupByAggregation.min().onColumn(2), + GroupByAggregation.mean().onColumn(2), + GroupByAggregation.sum().onColumn(2)); HostColumnVector countOut = t2.getColumn(2).copyToHost(); HostColumnVector maxOut = t2.getColumn(3).copyToHost(); HostColumnVector minOut = t2.getColumn(4).copyToHost(); @@ -5419,7 +5419,7 @@ void testSumWithStrings() { .column(5289L, 5203L, 5303L, 5206L) .build(); Table result = t.groupBy(0).aggregate( - Aggregation.sum().onColumn(1)); + GroupByAggregation.sum().onColumn(1)); Table expected = new Table.TestBuilder() .column("1-URGENT", "3-MEDIUM") .column(5289L + 5303L, 5203L + 5206L) @@ -5517,7 +5517,7 @@ void testGroupByCollectListIncludeNulls() { Arrays.asList(0)) .build(); Table found = input.groupBy(0).aggregate( - 
Aggregation.collectList(NullPolicy.INCLUDE).onColumn(1))) { + GroupByAggregation.collectList(NullPolicy.INCLUDE).onColumn(1))) { assertTablesAreEqual(expected, found); } } @@ -5563,8 +5563,8 @@ void testGroupByMergeLists() { Arrays.asList(new StructData(333, "s333"), new StructData(222, "s222"), new StructData(111, "s111")), Arrays.asList(new StructData(222, "s222"), new StructData(444, "s444"))) .build(); - Table retListOfInts = input.groupBy(0).aggregate(Aggregation.mergeLists().onColumn(1)); - Table retListOfStructs = input.groupBy(0).aggregate(Aggregation.mergeLists().onColumn(2))) { + Table retListOfInts = input.groupBy(0).aggregate(GroupByAggregation.mergeLists().onColumn(1)); + Table retListOfStructs = input.groupBy(0).aggregate(GroupByAggregation.mergeLists().onColumn(2))) { assertTablesAreEqual(expectedListOfInts, retListOfInts); assertTablesAreEqual(expectedListOfStructs, retListOfStructs); } @@ -5573,7 +5573,7 @@ void testGroupByMergeLists() { @Test void testGroupByCollectSetIncludeNulls() { // test with null unequal and nan unequal - Aggregation collectSet = Aggregation.collectSet(NullPolicy.INCLUDE, + GroupByAggregation collectSet = GroupByAggregation.collectSet(NullPolicy.INCLUDE, NullEquality.UNEQUAL, NaNEquality.UNEQUAL); try (Table input = new Table.TestBuilder() .column(1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4) @@ -5589,7 +5589,7 @@ void testGroupByCollectSetIncludeNulls() { assertTablesAreEqual(expected, found); } // test with null equal and nan unequal - collectSet = Aggregation.collectSet(NullPolicy.INCLUDE, + collectSet = GroupByAggregation.collectSet(NullPolicy.INCLUDE, NullEquality.EQUAL, NaNEquality.UNEQUAL); try (Table input = new Table.TestBuilder() .column(1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4) @@ -5610,7 +5610,7 @@ void testGroupByCollectSetIncludeNulls() { assertTablesAreEqual(expected, found); } // test with null equal and nan equal - collectSet = Aggregation.collectSet(NullPolicy.INCLUDE, + collectSet = 
GroupByAggregation.collectSet(NullPolicy.INCLUDE, NullEquality.EQUAL, NaNEquality.ALL_EQUAL); try (Table input = new Table.TestBuilder() .column(1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4) @@ -5671,10 +5671,10 @@ void testGroupByMergeSets() { Arrays.asList(1e-3, 1e3, Double.NaN), Arrays.asList()) .build(); - Table retListOfInts = input.groupBy(0).aggregate(Aggregation.mergeSets().onColumn(1)); - Table retListOfDoubles = input.groupBy(0).aggregate(Aggregation.mergeSets().onColumn(2)); + Table retListOfInts = input.groupBy(0).aggregate(GroupByAggregation.mergeSets().onColumn(1)); + Table retListOfDoubles = input.groupBy(0).aggregate(GroupByAggregation.mergeSets().onColumn(2)); Table retListOfDoublesNaNEq = input.groupBy(0).aggregate( - Aggregation.mergeSets(NullEquality.UNEQUAL, NaNEquality.ALL_EQUAL).onColumn(2))) { + GroupByAggregation.mergeSets(NullEquality.UNEQUAL, NaNEquality.ALL_EQUAL).onColumn(2))) { assertTablesAreEqual(expectedListOfInts, retListOfInts); assertTablesAreEqual(expectedListOfDoubles, retListOfDoubles); assertTablesAreEqual(expectedListOfDoublesNaNEq, retListOfDoublesNaNEq); From 023cfdd0402e605f73daa151d7e91b047040fecb Mon Sep 17 00:00:00 2001 From: "Robert (Bobby) Evans" Date: Mon, 2 Aug 2021 07:11:36 -0500 Subject: [PATCH 6/7] Group by aggregation file missed --- .../ai/rapids/cudf/GroupByAggregation.java | 296 ++++++++++++++++++ 1 file changed, 296 insertions(+) create mode 100644 java/src/main/java/ai/rapids/cudf/GroupByAggregation.java diff --git a/java/src/main/java/ai/rapids/cudf/GroupByAggregation.java b/java/src/main/java/ai/rapids/cudf/GroupByAggregation.java new file mode 100644 index 00000000000..948faf65c3c --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/GroupByAggregation.java @@ -0,0 +1,296 @@ +/* + * + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package ai.rapids.cudf; + +/** + * An aggregation that can be used for a reduce. + */ +public class GroupByAggregation { + private final Aggregation wrapped; + + private GroupByAggregation(Aggregation wrapped) { + this.wrapped = wrapped; + } + + Aggregation getWrapped() { + return wrapped; + } + + + /** + * Add a column to the Aggregation so it can be used on a specific column of data. + * @param columnIndex the index of the column to operate on. + */ + public GroupByAggregationOnColumn onColumn(int columnIndex) { + return new GroupByAggregationOnColumn(this, columnIndex); + } + + @Override + public int hashCode() { + return wrapped.hashCode(); + } + + @Override + public boolean equals(Object other) { + if (other == this) { + return true; + } else if (other instanceof GroupByAggregation) { + GroupByAggregation o = (GroupByAggregation) other; + return wrapped.equals(o.wrapped); + } + return false; + } + + /** + * Count number of valid, a.k.a. non-null, elements. + */ + public static GroupByAggregation count() { + return new GroupByAggregation(Aggregation.count()); + } + + /** + * Count number of elements. + * @param nullPolicy INCLUDE if nulls should be counted. EXCLUDE if only non-null values + * should be counted. + */ + public static GroupByAggregation count(NullPolicy nullPolicy) { + return new GroupByAggregation(Aggregation.count(nullPolicy)); + } + + /** + * Sum Aggregation + */ + public static GroupByAggregation sum() { + return new GroupByAggregation(Aggregation.sum()); + } + + /** + * Product Aggregation. 
+ */ + public static GroupByAggregation product() { + return new GroupByAggregation(Aggregation.product()); + } + + + /** + * Index of max element. Please note that when using this aggregation if the + * data is not already sorted by the grouping keys it may be automatically sorted + * prior to doing the aggregation. This would result in an index into the sorted data being + * returned. + */ + public static GroupByAggregation argMax() { + return new GroupByAggregation(Aggregation.argMax()); + } + + /** + * Index of min element. Please note that when using this aggregation if the + * data is not already sorted by the grouping keys it may be automatically sorted + * prior to doing the aggregation. This would result in an index into the sorted data being + * returned. + */ + public static GroupByAggregation argMin() { + return new GroupByAggregation(Aggregation.argMin()); + } + + /** + * Min Aggregation + */ + public static GroupByAggregation min() { + return new GroupByAggregation(Aggregation.min()); + } + + /** + * Max Aggregation + */ + public static GroupByAggregation max() { + return new GroupByAggregation(Aggregation.max()); + } + + /** + * Arithmetic mean aggregation. + */ + public static GroupByAggregation mean() { + return new GroupByAggregation(Aggregation.mean()); + } + + /** + * Sum of squares of differences from the mean. + */ + public static GroupByAggregation M2() { + return new GroupByAggregation(Aggregation.M2()); + } + + /** + * Variance aggregation with 1 as the delta degrees of freedom. + */ + public static GroupByAggregation variance() { + return new GroupByAggregation(Aggregation.variance()); + } + + /** + * Variance aggregation. + * @param ddof delta degrees of freedom. The divisor used in calculation of variance is + * N - ddof, where N is the population size.
+ */ + public static GroupByAggregation variance(int ddof) { + return new GroupByAggregation(Aggregation.variance(ddof)); + } + + /** + * Standard deviation aggregation with 1 as the delta degrees of freedom. + */ + public static GroupByAggregation standardDeviation() { + return new GroupByAggregation(Aggregation.standardDeviation()); + } + + /** + * Standard deviation aggregation. + * @param ddof delta degrees of freedom. The divisor used in calculation of std is + * N - ddof, where N is the population size. + */ + public static GroupByAggregation standardDeviation(int ddof) { + return new GroupByAggregation(Aggregation.standardDeviation(ddof)); + } + + /** + * Aggregate to compute the specified quantiles. Uses linear interpolation by default. + */ + public static GroupByAggregation quantile(double ... quantiles) { + return new GroupByAggregation(Aggregation.quantile(quantiles)); + } + + /** + * Aggregate to compute various quantiles. + */ + public static GroupByAggregation quantile(QuantileMethod method, double ... quantiles) { + return new GroupByAggregation(Aggregation.quantile(method, quantiles)); + } + + /** + * Median aggregation. + */ + public static GroupByAggregation median() { + return new GroupByAggregation(Aggregation.median()); + } + + /** + * Number of unique, non-null, elements. + */ + public static GroupByAggregation nunique() { + return new GroupByAggregation(Aggregation.nunique()); + } + + /** + * Number of unique elements. + * @param nullPolicy INCLUDE if nulls should be counted else EXCLUDE. If nulls are counted they + * compare as equal so multiple null values in a group would all only + * increase the count by 1. + */ + public static GroupByAggregation nunique(NullPolicy nullPolicy) { + return new GroupByAggregation(Aggregation.nunique(nullPolicy)); + } + + /** + * Get the nth, non-null, element in a group. + * @param offset the offset to look at. Negative numbers go from the end of the group.
Any + * value outside of the group range results in a null. + */ + public static GroupByAggregation nth(int offset) { + return new GroupByAggregation(Aggregation.nth(offset)); + } + + /** + * Get the nth element in a group. + * @param offset the offset to look at. Negative numbers go from the end of the group. Any + * value outside of the group range results in a null. + * @param nullPolicy INCLUDE if nulls should be included in the aggregation or EXCLUDE if they + * should be skipped. + */ + public static GroupByAggregation nth(int offset, NullPolicy nullPolicy) { + return new GroupByAggregation(Aggregation.nth(offset, nullPolicy)); + } + + /** + * Collect the values into a list. Nulls will be skipped. + */ + public static GroupByAggregation collectList() { + return new GroupByAggregation(Aggregation.collectList()); + } + + /** + * Collect the values into a list. + * + * @param nullPolicy Indicates whether to include/exclude nulls during collection. + */ + public static GroupByAggregation collectList(NullPolicy nullPolicy) { + return new GroupByAggregation(Aggregation.collectList(nullPolicy)); + } + + /** + * Collect the values into a set. All null values will be excluded, and all nan values are regarded as + * unique instances. + */ + public static GroupByAggregation collectSet() { + return new GroupByAggregation(Aggregation.collectSet()); + } + + /** + * Collect the values into a set. + * + * @param nullPolicy Indicates whether to include/exclude nulls during collection. + * @param nullEquality Flag to specify whether null entries within each list should be considered equal. + * @param nanEquality Flag to specify whether NaN values in floating point column should be considered equal. 
+ */ + public static GroupByAggregation collectSet(NullPolicy nullPolicy, NullEquality nullEquality, NaNEquality nanEquality) { + return new GroupByAggregation(Aggregation.collectSet(nullPolicy, nullEquality, nanEquality)); + } + + /** + * Merge the partial lists produced by multiple CollectListAggregations. + * NOTICE: The partial lists to be merged should NOT include any null list element (but can include null list entries). + */ + public static GroupByAggregation mergeLists() { + return new GroupByAggregation(Aggregation.mergeLists()); + } + + /** + * Merge the partial sets produced by multiple CollectSetAggregations. Each null/nan value will be regarded as + * a unique instance. + */ + public static GroupByAggregation mergeSets() { + return new GroupByAggregation(Aggregation.mergeSets()); + } + + /** + * Merge the partial sets produced by multiple CollectSetAggregations. + * + * @param nullEquality Flag to specify whether null entries within each list should be considered equal. + * @param nanEquality Flag to specify whether NaN values in floating point column should be considered equal. + */ + public static GroupByAggregation mergeSets(NullEquality nullEquality, NaNEquality nanEquality) { + return new GroupByAggregation(Aggregation.mergeSets(nullEquality, nanEquality)); + } + + /** + * Merge the partial M2 values produced by multiple instances of M2Aggregation. 
+ */ + public static GroupByAggregation mergeM2() { + return new GroupByAggregation(Aggregation.mergeM2()); + } +} From bfe2338337a6a700f15644a91c17d8b7e8c72eba Mon Sep 17 00:00:00 2001 From: "Robert (Bobby) Evans" Date: Mon, 2 Aug 2021 13:08:05 -0500 Subject: [PATCH 7/7] Addressed review comments --- .../main/java/ai/rapids/cudf/Aggregation.java | 44 +++++++++---------- .../ai/rapids/cudf/AggregationOverWindow.java | 2 +- .../ai/rapids/cudf/GroupByAggregation.java | 2 +- .../cudf/GroupByAggregationOnColumn.java | 2 +- .../rapids/cudf/GroupByScanAggregation.java | 2 +- .../cudf/GroupByScanAggregationOnColumn.java | 2 +- .../ai/rapids/cudf/ReductionAggregation.java | 2 +- .../ai/rapids/cudf/RollingAggregation.java | 2 +- .../cudf/RollingAggregationOnColumn.java | 2 +- .../java/ai/rapids/cudf/ScanAggregation.java | 2 +- 10 files changed, 31 insertions(+), 31 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/Aggregation.java b/java/src/main/java/ai/rapids/cudf/Aggregation.java index 0480dc465ab..734d9cb5694 100644 --- a/java/src/main/java/ai/rapids/cudf/Aggregation.java +++ b/java/src/main/java/ai/rapids/cudf/Aggregation.java @@ -194,7 +194,7 @@ public boolean equals(Object other) { } } - private static class QuantileAggregation extends Aggregation { + private static final class QuantileAggregation extends Aggregation { private final QuantileMethod method; private final double[] quantiles; @@ -423,7 +423,7 @@ static void close(long[] ptrs) { static native void close(long ptr); - static class SumAggregation extends NoParamAggregation { + static final class SumAggregation extends NoParamAggregation { private SumAggregation() { super(Kind.SUM); } @@ -436,7 +436,7 @@ static SumAggregation sum() { return new SumAggregation(); } - static class ProductAggregation extends NoParamAggregation { + static final class ProductAggregation extends NoParamAggregation { private ProductAggregation() { super(Kind.PRODUCT); } @@ -449,7 +449,7 @@ static ProductAggregation 
product() { return new ProductAggregation(); } - static class MinAggregation extends NoParamAggregation { + static final class MinAggregation extends NoParamAggregation { private MinAggregation() { super(Kind.MIN); } @@ -462,7 +462,7 @@ static MinAggregation min() { return new MinAggregation(); } - static class MaxAggregation extends NoParamAggregation { + static final class MaxAggregation extends NoParamAggregation { private MaxAggregation() { super(Kind.MAX); } @@ -475,7 +475,7 @@ static MaxAggregation max() { return new MaxAggregation(); } - static class CountAggregation extends CountLikeAggregation { + static final class CountAggregation extends CountLikeAggregation { private CountAggregation(NullPolicy nullPolicy) { super(Kind.COUNT, nullPolicy); } @@ -497,7 +497,7 @@ static CountAggregation count(NullPolicy nullPolicy) { return new CountAggregation(nullPolicy); } - static class AnyAggregation extends NoParamAggregation { + static final class AnyAggregation extends NoParamAggregation { private AnyAggregation() { super(Kind.ANY); } @@ -512,7 +512,7 @@ static AnyAggregation any() { return new AnyAggregation(); } - static class AllAggregation extends NoParamAggregation { + static final class AllAggregation extends NoParamAggregation { private AllAggregation() { super(Kind.ALL); } @@ -527,7 +527,7 @@ static AllAggregation all() { return new AllAggregation(); } - static class SumOfSquaresAggregation extends NoParamAggregation { + static final class SumOfSquaresAggregation extends NoParamAggregation { private SumOfSquaresAggregation() { super(Kind.SUM_OF_SQUARES); } @@ -540,7 +540,7 @@ static SumOfSquaresAggregation sumOfSquares() { return new SumOfSquaresAggregation(); } - static class MeanAggregation extends NoParamAggregation { + static final class MeanAggregation extends NoParamAggregation { private MeanAggregation() { super(Kind.MEAN); } @@ -553,7 +553,7 @@ static MeanAggregation mean() { return new MeanAggregation(); } - static class M2Aggregation extends 
NoParamAggregation { + static final class M2Aggregation extends NoParamAggregation { private M2Aggregation() { super(Kind.M2); } @@ -566,7 +566,7 @@ static M2Aggregation M2() { return new M2Aggregation(); } - static class VarianceAggregation extends DdofAggregation { + static final class VarianceAggregation extends DdofAggregation { private VarianceAggregation(int ddof) { super(Kind.VARIANCE, ddof); } @@ -589,7 +589,7 @@ static VarianceAggregation variance(int ddof) { } - static class StandardDeviationAggregation extends DdofAggregation { + static final class StandardDeviationAggregation extends DdofAggregation { private StandardDeviationAggregation(int ddof) { super(Kind.STD, ddof); } @@ -611,7 +611,7 @@ static StandardDeviationAggregation standardDeviation(int ddof) { return new StandardDeviationAggregation(ddof); } - static class MedianAggregation extends NoParamAggregation { + static final class MedianAggregation extends NoParamAggregation { private MedianAggregation() { super(Kind.MEDIAN); } @@ -638,7 +638,7 @@ static QuantileAggregation quantile(QuantileMethod method, double ... 
quantiles) return new QuantileAggregation(method, quantiles); } - static class ArgMaxAggregation extends NoParamAggregation { + static final class ArgMaxAggregation extends NoParamAggregation { private ArgMaxAggregation() { super(Kind.ARGMAX); } @@ -654,7 +654,7 @@ static ArgMaxAggregation argMax() { return new ArgMaxAggregation(); } - static class ArgMinAggregation extends NoParamAggregation { + static final class ArgMinAggregation extends NoParamAggregation { private ArgMinAggregation() { super(Kind.ARGMIN); } @@ -670,7 +670,7 @@ static ArgMinAggregation argMin() { return new ArgMinAggregation(); } - static class NuniqueAggregation extends CountLikeAggregation { + static final class NuniqueAggregation extends CountLikeAggregation { private NuniqueAggregation(NullPolicy nullPolicy) { super(Kind.NUNIQUE, nullPolicy); } @@ -713,7 +713,7 @@ static NthAggregation nth(int offset, NullPolicy nullPolicy) { return new NthAggregation(offset, nullPolicy); } - static class RowNumberAggregation extends NoParamAggregation { + static final class RowNumberAggregation extends NoParamAggregation { private RowNumberAggregation() { super(Kind.ROW_NUMBER); } @@ -726,7 +726,7 @@ static RowNumberAggregation rowNumber() { return new RowNumberAggregation(); } - static class RankAggregation extends NoParamAggregation { + static final class RankAggregation extends NoParamAggregation { private RankAggregation() { super(Kind.RANK); } @@ -739,7 +739,7 @@ static RankAggregation rank() { return new RankAggregation(); } - static class DenseRankAggregation extends NoParamAggregation { + static final class DenseRankAggregation extends NoParamAggregation { private DenseRankAggregation() { super(Kind.DENSE_RANK); } @@ -819,7 +819,7 @@ static MergeSetsAggregation mergeSets(NullEquality nullEquality, NaNEquality nan return new MergeSetsAggregation(nullEquality, nanEquality); } - static class LeadAggregation extends LeadLagAggregation { + static final class LeadAggregation extends LeadLagAggregation { 
private LeadAggregation(int offset, ColumnVector defaultOutput) { super(Kind.LEAD, offset, defaultOutput); } @@ -835,7 +835,7 @@ static LeadAggregation lead(int offset, ColumnVector defaultOutput) { return new LeadAggregation(offset, defaultOutput); } - static class LagAggregation extends LeadLagAggregation { + static final class LagAggregation extends LeadLagAggregation { private LagAggregation(int offset, ColumnVector defaultOutput) { super(Kind.LAG, offset, defaultOutput); } diff --git a/java/src/main/java/ai/rapids/cudf/AggregationOverWindow.java b/java/src/main/java/ai/rapids/cudf/AggregationOverWindow.java index 9a82eae65bf..d5544e01e7e 100644 --- a/java/src/main/java/ai/rapids/cudf/AggregationOverWindow.java +++ b/java/src/main/java/ai/rapids/cudf/AggregationOverWindow.java @@ -22,7 +22,7 @@ * An Aggregation instance that also holds a column number and window metadata so the aggregation * can be done over a specific window. */ -public class AggregationOverWindow { +public final class AggregationOverWindow { private final RollingAggregationOnColumn wrapped; protected final WindowOptions windowOptions; diff --git a/java/src/main/java/ai/rapids/cudf/GroupByAggregation.java b/java/src/main/java/ai/rapids/cudf/GroupByAggregation.java index 948faf65c3c..dd2adf8bee8 100644 --- a/java/src/main/java/ai/rapids/cudf/GroupByAggregation.java +++ b/java/src/main/java/ai/rapids/cudf/GroupByAggregation.java @@ -21,7 +21,7 @@ /** * An aggregation that can be used for a reduce. 
*/ -public class GroupByAggregation { +public final class GroupByAggregation { private final Aggregation wrapped; private GroupByAggregation(Aggregation wrapped) { diff --git a/java/src/main/java/ai/rapids/cudf/GroupByAggregationOnColumn.java b/java/src/main/java/ai/rapids/cudf/GroupByAggregationOnColumn.java index 43c8dbe888e..c50cf3728f0 100644 --- a/java/src/main/java/ai/rapids/cudf/GroupByAggregationOnColumn.java +++ b/java/src/main/java/ai/rapids/cudf/GroupByAggregationOnColumn.java @@ -21,7 +21,7 @@ /** * A GroupByAggregation for a specific column in a table. */ -public class GroupByAggregationOnColumn { +public final class GroupByAggregationOnColumn { protected final GroupByAggregation wrapped; protected final int columnIndex; diff --git a/java/src/main/java/ai/rapids/cudf/GroupByScanAggregation.java b/java/src/main/java/ai/rapids/cudf/GroupByScanAggregation.java index 97250a71486..219b6dde05d 100644 --- a/java/src/main/java/ai/rapids/cudf/GroupByScanAggregation.java +++ b/java/src/main/java/ai/rapids/cudf/GroupByScanAggregation.java @@ -21,7 +21,7 @@ /** * An aggregation that can be used for a grouped scan. */ -public class GroupByScanAggregation { +public final class GroupByScanAggregation { private final Aggregation wrapped; private GroupByScanAggregation(Aggregation wrapped) { diff --git a/java/src/main/java/ai/rapids/cudf/GroupByScanAggregationOnColumn.java b/java/src/main/java/ai/rapids/cudf/GroupByScanAggregationOnColumn.java index 227cd58ae8c..75e4936e5b9 100644 --- a/java/src/main/java/ai/rapids/cudf/GroupByScanAggregationOnColumn.java +++ b/java/src/main/java/ai/rapids/cudf/GroupByScanAggregationOnColumn.java @@ -21,7 +21,7 @@ /** * A GroupByScanAggregation for a specific column in a table. 
*/ -public class GroupByScanAggregationOnColumn { +public final class GroupByScanAggregationOnColumn { protected final GroupByScanAggregation wrapped; protected final int columnIndex; diff --git a/java/src/main/java/ai/rapids/cudf/ReductionAggregation.java b/java/src/main/java/ai/rapids/cudf/ReductionAggregation.java index ad96b93c400..7eff85dcd0d 100644 --- a/java/src/main/java/ai/rapids/cudf/ReductionAggregation.java +++ b/java/src/main/java/ai/rapids/cudf/ReductionAggregation.java @@ -21,7 +21,7 @@ /** * An aggregation that can be used for a reduce. */ -public class ReductionAggregation { +public final class ReductionAggregation { private final Aggregation wrapped; private ReductionAggregation(Aggregation wrapped) { diff --git a/java/src/main/java/ai/rapids/cudf/RollingAggregation.java b/java/src/main/java/ai/rapids/cudf/RollingAggregation.java index b7e56606fb5..07983f77aad 100644 --- a/java/src/main/java/ai/rapids/cudf/RollingAggregation.java +++ b/java/src/main/java/ai/rapids/cudf/RollingAggregation.java @@ -21,7 +21,7 @@ /** * An aggregation that can be used on rolling windows. */ -public class RollingAggregation { +public final class RollingAggregation { private final Aggregation wrapped; private RollingAggregation(Aggregation wrapped) { diff --git a/java/src/main/java/ai/rapids/cudf/RollingAggregationOnColumn.java b/java/src/main/java/ai/rapids/cudf/RollingAggregationOnColumn.java index 7fde7c30b3f..a6b1484aa71 100644 --- a/java/src/main/java/ai/rapids/cudf/RollingAggregationOnColumn.java +++ b/java/src/main/java/ai/rapids/cudf/RollingAggregationOnColumn.java @@ -21,7 +21,7 @@ /** * A RollingAggregation for a specific column in a table. 
*/ -public class RollingAggregationOnColumn { +public final class RollingAggregationOnColumn { protected final RollingAggregation wrapped; protected final int columnIndex; diff --git a/java/src/main/java/ai/rapids/cudf/ScanAggregation.java b/java/src/main/java/ai/rapids/cudf/ScanAggregation.java index bd19546e5ef..08489562adc 100644 --- a/java/src/main/java/ai/rapids/cudf/ScanAggregation.java +++ b/java/src/main/java/ai/rapids/cudf/ScanAggregation.java @@ -21,7 +21,7 @@ /** * An aggregation that can be used for a scan. */ -public class ScanAggregation { +public final class ScanAggregation { private final Aggregation wrapped; private ScanAggregation(Aggregation wrapped) {