diff --git a/java/src/main/java/ai/rapids/cudf/Aggregation.java b/java/src/main/java/ai/rapids/cudf/Aggregation.java index 49c6d2b6ffc..734d9cb5694 100644 --- a/java/src/main/java/ai/rapids/cudf/Aggregation.java +++ b/java/src/main/java/ai/rapids/cudf/Aggregation.java @@ -24,7 +24,7 @@ * Represents an aggregation operation. Please note that not all aggregations work, or even make * sense in all types of aggregation operations. */ -public abstract class Aggregation { +abstract class Aggregation { static { NativeDepsLoader.loadNativeDeps(); } @@ -65,7 +65,7 @@ enum Kind { M2(26), MERGE_M2(27), RANK(28), - DENSE_RANK(29);; + DENSE_RANK(29); final int nativeId; @@ -102,7 +102,7 @@ public boolean equals(Object other) { } } - public static final class NthAggregation extends Aggregation { + static final class NthAggregation extends Aggregation { private final int offset; private final NullPolicy nullPolicy; @@ -194,7 +194,7 @@ public boolean equals(Object other) { } } - private static class QuantileAggregation extends Aggregation { + private static final class QuantileAggregation extends Aggregation { private final QuantileMethod method; private final double[] quantiles; @@ -275,8 +275,7 @@ long getDefaultOutput() { } } - public static final class CollectListAggregation extends Aggregation - implements RollingAggregation { + static final class CollectListAggregation extends Aggregation { private final NullPolicy nullPolicy; private CollectListAggregation(NullPolicy nullPolicy) { @@ -306,8 +305,7 @@ public boolean equals(Object other) { } } - public static final class CollectSetAggregation extends Aggregation - implements RollingAggregation { + static final class CollectSetAggregation extends Aggregation { private final NullPolicy nullPolicy; private final NullEquality nullEquality; private final NaNEquality nanEquality; @@ -348,7 +346,7 @@ public boolean equals(Object other) { } } - public static final class MergeSetsAggregation extends Aggregation { + static final class 
MergeSetsAggregation extends Aggregation { private final NullEquality nullEquality; private final NaNEquality nanEquality; @@ -388,14 +386,6 @@ protected Aggregation(Kind kind) { this.kind = kind; } - /** - * Add a column to the Aggregation so it can be used on a specific column of data. - * @param columnIndex the index of the column to operate on. - */ - public AggregationOnColumn onColumn(int columnIndex) { - return new AggregationOnColumn((T)this, columnIndex); - } - /** * Get the native view of a ColumnVector that provides default values to be used for some window * aggregations when there is not enough data to do the computation. This really only happens @@ -433,8 +423,7 @@ static void close(long[] ptrs) { static native void close(long ptr); - public static class SumAggregation extends NoParamAggregation - implements RollingAggregation { + static final class SumAggregation extends NoParamAggregation { private SumAggregation() { super(Kind.SUM); } @@ -443,11 +432,11 @@ private SumAggregation() { /** * Sum reduction. */ - public static SumAggregation sum() { + static SumAggregation sum() { return new SumAggregation(); } - public static class ProductAggregation extends NoParamAggregation { + static final class ProductAggregation extends NoParamAggregation { private ProductAggregation() { super(Kind.PRODUCT); } @@ -456,12 +445,11 @@ private ProductAggregation() { /** * Product reduction. */ - public static ProductAggregation product() { + static ProductAggregation product() { return new ProductAggregation(); } - public static class MinAggregation extends NoParamAggregation - implements RollingAggregation { + static final class MinAggregation extends NoParamAggregation { private MinAggregation() { super(Kind.MIN); } @@ -470,12 +458,11 @@ private MinAggregation() { /** * Min reduction. 
*/ - public static MinAggregation min() { + static MinAggregation min() { return new MinAggregation(); } - public static class MaxAggregation extends NoParamAggregation - implements RollingAggregation { + static final class MaxAggregation extends NoParamAggregation { private MaxAggregation() { super(Kind.MAX); } @@ -484,12 +471,11 @@ private MaxAggregation() { /** * Max reduction. */ - public static MaxAggregation max() { + static MaxAggregation max() { return new MaxAggregation(); } - public static class CountAggregation extends CountLikeAggregation - implements RollingAggregation { + static final class CountAggregation extends CountLikeAggregation { private CountAggregation(NullPolicy nullPolicy) { super(Kind.COUNT, nullPolicy); } @@ -498,7 +484,7 @@ private CountAggregation(NullPolicy nullPolicy) { /** * Count number of valid, a.k.a. non-null, elements. */ - public static CountAggregation count() { + static CountAggregation count() { return count(NullPolicy.EXCLUDE); } @@ -507,11 +493,11 @@ public static CountAggregation count() { * @param nullPolicy INCLUDE if nulls should be counted. EXCLUDE if only non-null values * should be counted. */ - public static CountAggregation count(NullPolicy nullPolicy) { + static CountAggregation count(NullPolicy nullPolicy) { return new CountAggregation(nullPolicy); } - public static class AnyAggregation extends NoParamAggregation { + static final class AnyAggregation extends NoParamAggregation { private AnyAggregation() { super(Kind.ANY); } @@ -522,11 +508,11 @@ private AnyAggregation() { * if any of the elements in the range are true or non-zero, otherwise produces a false or 0. * Null values are skipped. 
*/ - public static AnyAggregation any() { + static AnyAggregation any() { return new AnyAggregation(); } - public static class AllAggregation extends NoParamAggregation { + static final class AllAggregation extends NoParamAggregation { private AllAggregation() { super(Kind.ALL); } @@ -537,12 +523,11 @@ private AllAggregation() { * the range are true or non-zero, otherwise produces a false or 0. * Null values are skipped. */ - public static AllAggregation all() { + static AllAggregation all() { return new AllAggregation(); } - - public static class SumOfSquaresAggregation extends NoParamAggregation { + static final class SumOfSquaresAggregation extends NoParamAggregation { private SumOfSquaresAggregation() { super(Kind.SUM_OF_SQUARES); } @@ -551,12 +536,11 @@ private SumOfSquaresAggregation() { /** * Sum of squares reduction. */ - public static SumOfSquaresAggregation sumOfSquares() { + static SumOfSquaresAggregation sumOfSquares() { return new SumOfSquaresAggregation(); } - public static class MeanAggregation extends NoParamAggregation - implements RollingAggregation{ + static final class MeanAggregation extends NoParamAggregation { private MeanAggregation() { super(Kind.MEAN); } @@ -565,11 +549,11 @@ private MeanAggregation() { /** * Arithmetic mean reduction. */ - public static MeanAggregation mean() { + static MeanAggregation mean() { return new MeanAggregation(); } - public static class M2Aggregation extends NoParamAggregation { + static final class M2Aggregation extends NoParamAggregation { private M2Aggregation() { super(Kind.M2); } @@ -578,11 +562,11 @@ private M2Aggregation() { /** * Sum of square of differences from mean. 
*/ - public static M2Aggregation M2() { + static M2Aggregation M2() { return new M2Aggregation(); } - public static class VarianceAggregation extends DdofAggregation { + static final class VarianceAggregation extends DdofAggregation { private VarianceAggregation(int ddof) { super(Kind.VARIANCE, ddof); } @@ -591,7 +575,7 @@ private VarianceAggregation(int ddof) { /** * Variance aggregation with 1 as the delta degrees of freedom. */ - public static VarianceAggregation variance() { + static VarianceAggregation variance() { return variance(1); } @@ -600,12 +584,12 @@ public static VarianceAggregation variance() { * @param ddof delta degrees of freedom. The divisor used in calculation of variance is * N - ddof, where N is the population size. */ - public static VarianceAggregation variance(int ddof) { + static VarianceAggregation variance(int ddof) { return new VarianceAggregation(ddof); } - public static class StandardDeviationAggregation extends DdofAggregation { + static final class StandardDeviationAggregation extends DdofAggregation { private StandardDeviationAggregation(int ddof) { super(Kind.STD, ddof); } @@ -614,7 +598,7 @@ private StandardDeviationAggregation(int ddof) { /** * Standard deviation aggregation with 1 as the delta degrees of freedom. */ - public static StandardDeviationAggregation standardDeviation() { + static StandardDeviationAggregation standardDeviation() { return standardDeviation(1); } @@ -623,11 +607,11 @@ public static StandardDeviationAggregation standardDeviation() { * @param ddof delta degrees of freedom. The divisor used in calculation of std is * N - ddof, where N is the population size. 
*/ - public static StandardDeviationAggregation standardDeviation(int ddof) { + static StandardDeviationAggregation standardDeviation(int ddof) { return new StandardDeviationAggregation(ddof); } - public static class MedianAggregation extends NoParamAggregation { + static final class MedianAggregation extends NoParamAggregation { private MedianAggregation() { super(Kind.MEDIAN); } @@ -636,26 +620,25 @@ private MedianAggregation() { /** * Median reduction. */ - public static MedianAggregation median() { + static MedianAggregation median() { return new MedianAggregation(); } /** * Aggregate to compute the specified quantiles. Uses linear interpolation by default. */ - public static QuantileAggregation quantile(double ... quantiles) { + static QuantileAggregation quantile(double ... quantiles) { return quantile(QuantileMethod.LINEAR, quantiles); } /** * Aggregate to compute various quantiles. */ - public static QuantileAggregation quantile(QuantileMethod method, double ... quantiles) { + static QuantileAggregation quantile(QuantileMethod method, double ... quantiles) { return new QuantileAggregation(method, quantiles); } - public static class ArgMaxAggregation extends NoParamAggregation - implements RollingAggregation{ + static final class ArgMaxAggregation extends NoParamAggregation { private ArgMaxAggregation() { super(Kind.ARGMAX); } @@ -667,12 +650,11 @@ private ArgMaxAggregation() { * prior to doing the aggregation. This would result in an index into the sorted data being * returned. */ - public static ArgMaxAggregation argMax() { + static ArgMaxAggregation argMax() { return new ArgMaxAggregation(); } - public static class ArgMinAggregation extends NoParamAggregation - implements RollingAggregation{ + static final class ArgMinAggregation extends NoParamAggregation { private ArgMinAggregation() { super(Kind.ARGMIN); } @@ -684,11 +666,11 @@ private ArgMinAggregation() { * prior to doing the aggregation. 
This would result in an index into the sorted data being * returned. */ - public static ArgMinAggregation argMin() { + static ArgMinAggregation argMin() { return new ArgMinAggregation(); } - public static class NuniqueAggregation extends CountLikeAggregation { + static final class NuniqueAggregation extends CountLikeAggregation { private NuniqueAggregation(NullPolicy nullPolicy) { super(Kind.NUNIQUE, nullPolicy); } @@ -697,7 +679,7 @@ private NuniqueAggregation(NullPolicy nullPolicy) { /** * Number of unique, non-null, elements. */ - public static NuniqueAggregation nunique() { + static NuniqueAggregation nunique() { return nunique(NullPolicy.EXCLUDE); } @@ -707,7 +689,7 @@ public static NuniqueAggregation nunique() { * compare as equal so multiple null values in a range would all only * increase the count by 1. */ - public static NuniqueAggregation nunique(NullPolicy nullPolicy) { + static NuniqueAggregation nunique(NullPolicy nullPolicy) { return new NuniqueAggregation(nullPolicy); } @@ -716,7 +698,7 @@ public static NuniqueAggregation nunique(NullPolicy nullPolicy) { * @param offset the offset to look at. Negative numbers go from the end of the group. Any * value outside of the group range results in a null. */ - public static NthAggregation nth(int offset) { + static NthAggregation nth(int offset) { return nth(offset, NullPolicy.INCLUDE); } @@ -727,12 +709,11 @@ public static NthAggregation nth(int offset) { * @param nullPolicy INCLUDE if nulls should be included in the aggregation or EXCLUDE if they * should be skipped. 
*/ - public static NthAggregation nth(int offset, NullPolicy nullPolicy) { + static NthAggregation nth(int offset, NullPolicy nullPolicy) { return new NthAggregation(offset, nullPolicy); } - public static class RowNumberAggregation extends NoParamAggregation - implements RollingAggregation{ + static final class RowNumberAggregation extends NoParamAggregation { private RowNumberAggregation() { super(Kind.ROW_NUMBER); } @@ -741,12 +722,11 @@ private RowNumberAggregation() { /** * Get the row number, only makes sense for a window operations. */ - public static RowNumberAggregation rowNumber() { + static RowNumberAggregation rowNumber() { return new RowNumberAggregation(); } - public static class RankAggregation extends NoParamAggregation - implements RollingAggregation{ + static final class RankAggregation extends NoParamAggregation { private RankAggregation() { super(Kind.RANK); } @@ -755,12 +735,11 @@ private RankAggregation() { /** * Get the row's ranking. */ - public static RankAggregation rank() { + static RankAggregation rank() { return new RankAggregation(); } - public static class DenseRankAggregation extends NoParamAggregation - implements RollingAggregation{ + static final class DenseRankAggregation extends NoParamAggregation { private DenseRankAggregation() { super(Kind.DENSE_RANK); } @@ -769,14 +748,14 @@ private DenseRankAggregation() { /** * Get the row's dense ranking. */ - public static DenseRankAggregation denseRank() { + static DenseRankAggregation denseRank() { return new DenseRankAggregation(); } /** * Collect the values into a list. Nulls will be skipped. */ - public static CollectListAggregation collectList() { + static CollectListAggregation collectList() { return collectList(NullPolicy.EXCLUDE); } @@ -785,7 +764,7 @@ public static CollectListAggregation collectList() { * * @param nullPolicy Indicates whether to include/exclude nulls during collection. 
*/ - public static CollectListAggregation collectList(NullPolicy nullPolicy) { + static CollectListAggregation collectList(NullPolicy nullPolicy) { return new CollectListAggregation(nullPolicy); } @@ -793,7 +772,7 @@ public static CollectListAggregation collectList(NullPolicy nullPolicy) { * Collect the values into a set. All null values will be excluded, and all nan values are regarded as * unique instances. */ - public static CollectSetAggregation collectSet() { + static CollectSetAggregation collectSet() { return collectSet(NullPolicy.EXCLUDE, NullEquality.UNEQUAL, NaNEquality.UNEQUAL); } @@ -804,11 +783,11 @@ public static CollectSetAggregation collectSet() { * @param nullEquality Flag to specify whether null entries within each list should be considered equal. * @param nanEquality Flag to specify whether NaN values in floating point column should be considered equal. */ - public static CollectSetAggregation collectSet(NullPolicy nullPolicy, NullEquality nullEquality, NaNEquality nanEquality) { + static CollectSetAggregation collectSet(NullPolicy nullPolicy, NullEquality nullEquality, NaNEquality nanEquality) { return new CollectSetAggregation(nullPolicy, nullEquality, nanEquality); } - public static final class MergeListsAggregation extends NoParamAggregation { + static final class MergeListsAggregation extends NoParamAggregation { private MergeListsAggregation() { super(Kind.MERGE_LISTS); } @@ -818,7 +797,7 @@ private MergeListsAggregation() { * Merge the partial lists produced by multiple CollectListAggregations. * NOTICE: The partial lists to be merged should NOT include any null list element (but can include null list entries). */ - public static MergeListsAggregation mergeLists() { + static MergeListsAggregation mergeLists() { return new MergeListsAggregation(); } @@ -826,7 +805,7 @@ public static MergeListsAggregation mergeLists() { * Merge the partial sets produced by multiple CollectSetAggregations. 
Each null/nan value will be regarded as * a unique instance. */ - public static MergeSetsAggregation mergeSets() { + static MergeSetsAggregation mergeSets() { return mergeSets(NullEquality.UNEQUAL, NaNEquality.UNEQUAL); } @@ -836,58 +815,39 @@ public static MergeSetsAggregation mergeSets() { * @param nullEquality Flag to specify whether null entries within each list should be considered equal. * @param nanEquality Flag to specify whether NaN values in floating point column should be considered equal. */ - public static MergeSetsAggregation mergeSets(NullEquality nullEquality, NaNEquality nanEquality) { + static MergeSetsAggregation mergeSets(NullEquality nullEquality, NaNEquality nanEquality) { return new MergeSetsAggregation(nullEquality, nanEquality); } - public static class LeadAggregation extends LeadLagAggregation - implements RollingAggregation { + static final class LeadAggregation extends LeadLagAggregation { private LeadAggregation(int offset, ColumnVector defaultOutput) { super(Kind.LEAD, offset, defaultOutput); } } - /** - * In a rolling window return the value offset entries ahead or null if it is outside of the - * window. - */ - public static LeadAggregation lead(int offset) { - return lead(offset, null); - } - /** * In a rolling window return the value offset entries ahead or the corresponding value from * defaultOutput if it is outside of the window. Note that this does not take any ownership of - * defaultOutput and the caller mush ensure that defaultOutput remains valid during the life + * defaultOutput and the caller must ensure that defaultOutput remains valid during the life * time of this aggregation operation. 
*/ - public static LeadAggregation lead(int offset, ColumnVector defaultOutput) { + static LeadAggregation lead(int offset, ColumnVector defaultOutput) { return new LeadAggregation(offset, defaultOutput); } - public static class LagAggregation extends LeadLagAggregation - implements RollingAggregation{ + static final class LagAggregation extends LeadLagAggregation { private LagAggregation(int offset, ColumnVector defaultOutput) { super(Kind.LAG, offset, defaultOutput); } } - - /** - * In a rolling window return the value offset entries behind or null if it is outside of the - * window. - */ - public static LagAggregation lag(int offset) { - return lag(offset, null); - } - /** * In a rolling window return the value offset entries behind or the corresponding value from * defaultOutput if it is outside of the window. Note that this does not take any ownership of - * defaultOutput and the caller mush ensure that defaultOutput remains valid during the life + * defaultOutput and the caller must ensure that defaultOutput remains valid during the life * time of this aggregation operation. */ - public static LagAggregation lag(int offset, ColumnVector defaultOutput) { + static LagAggregation lag(int offset, ColumnVector defaultOutput) { return new LagAggregation(offset, defaultOutput); } @@ -900,7 +860,7 @@ private MergeM2Aggregation() { /** * Merge the partial M2 values produced by multiple instances of M2Aggregation. */ - public static MergeM2Aggregation mergeM2() { + static MergeM2Aggregation mergeM2() { return new MergeM2Aggregation(); } diff --git a/java/src/main/java/ai/rapids/cudf/AggregationOverWindow.java b/java/src/main/java/ai/rapids/cudf/AggregationOverWindow.java index abce287c9b0..d5544e01e7e 100644 --- a/java/src/main/java/ai/rapids/cudf/AggregationOverWindow.java +++ b/java/src/main/java/ai/rapids/cudf/AggregationOverWindow.java @@ -22,12 +22,12 @@ * An Aggregation instance that also holds a column number and window metadata so the aggregation * can be done over a specific window. 
*/ -public class AggregationOverWindow> - extends AggregationOnColumn { +public final class AggregationOverWindow { + private final RollingAggregationOnColumn wrapped; protected final WindowOptions windowOptions; - AggregationOverWindow(T wrapped, int columnIndex, WindowOptions windowOptions) { - super(wrapped, columnIndex); + AggregationOverWindow(RollingAggregationOnColumn wrapped, WindowOptions windowOptions) { + this.wrapped = wrapped; this.windowOptions = windowOptions; if (windowOptions == null) { @@ -43,23 +43,7 @@ public WindowOptions getWindowOptions() { return windowOptions; } - @Override - public AggregationOnColumn onColumn(int columnIndex) { - if (columnIndex == getColumnIndex()) { - return this; // NOOP - } else { - return new AggregationOverWindow(this.wrapped, columnIndex, windowOptions); - } - } - - @Override - public AggregationOverWindow overWindow(WindowOptions windowOptions) { - if (this.windowOptions.equals(windowOptions)) { - return this; - } - return new AggregationOverWindow(wrapped, columnIndex, windowOptions); - } - @Override public int hashCode() { + /* Hash the same fields equals() compares; with no superclass, super.hashCode() would be Object's identity hash. */ - return 31 * super.hashCode() + windowOptions.hashCode(); + return 31 * wrapped.hashCode() + windowOptions.hashCode(); @@ -69,10 +53,22 @@ public int hashCode() { public boolean equals(Object other) { if (other == this) { return true; - } else if (other instanceof AggregationOnColumn) { - AggregationOnColumn o = (AggregationOnColumn) other; - return wrapped.equals(o.wrapped) && columnIndex == o.columnIndex; + } else if (other instanceof AggregationOverWindow) { + AggregationOverWindow o = (AggregationOverWindow) other; + return wrapped.equals(o.wrapped) && windowOptions.equals(o.windowOptions); } return false; } + + int getColumnIndex() { + return wrapped.getColumnIndex(); + } + + long createNativeInstance() { + return wrapped.createNativeInstance(); + } + + long getDefaultOutput() { + return wrapped.getDefaultOutput(); + } } diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index 
4a1ed3a178e..55bd5ec5ff9 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -1135,7 +1135,7 @@ public Scalar sum() { * of the specified type. */ public Scalar sum(DType outType) { - return reduce(Aggregation.sum(), outType); + return reduce(ReductionAggregation.sum(), outType); } /** @@ -1143,7 +1143,7 @@ public Scalar sum(DType outType) { * of the same type as this column. */ public Scalar min() { - return reduce(Aggregation.min(), type); + return reduce(ReductionAggregation.min(), type); } /** @@ -1160,7 +1160,7 @@ public Scalar min(DType outType) { return tmp.min(outType); } } - return reduce(Aggregation.min(), outType); + return reduce(ReductionAggregation.min(), outType); } /** @@ -1168,7 +1168,7 @@ public Scalar min(DType outType) { * of the same type as this column. */ public Scalar max() { - return reduce(Aggregation.max(), type); + return reduce(ReductionAggregation.max(), type); } /** @@ -1185,7 +1185,7 @@ public Scalar max(DType outType) { return tmp.max(outType); } } - return reduce(Aggregation.max(), outType); + return reduce(ReductionAggregation.max(), outType); } /** @@ -1201,7 +1201,7 @@ public Scalar product() { * of the specified type. */ public Scalar product(DType outType) { - return reduce(Aggregation.product(), outType); + return reduce(ReductionAggregation.product(), outType); } /** @@ -1217,7 +1217,7 @@ public Scalar sumOfSquares() { * scalar of the specified type. */ public Scalar sumOfSquares(DType outType) { - return reduce(Aggregation.sumOfSquares(), outType); + return reduce(ReductionAggregation.sumOfSquares(), outType); } /** @@ -1241,7 +1241,7 @@ public Scalar mean() { * types are currently supported. */ public Scalar mean(DType outType) { - return reduce(Aggregation.mean(), outType); + return reduce(ReductionAggregation.mean(), outType); } /** @@ -1265,7 +1265,7 @@ public Scalar variance() { * types are currently supported. 
*/ public Scalar variance(DType outType) { - return reduce(Aggregation.variance(), outType); + return reduce(ReductionAggregation.variance(), outType); } /** @@ -1290,7 +1290,7 @@ public Scalar standardDeviation() { * types are currently supported. */ public Scalar standardDeviation(DType outType) { - return reduce(Aggregation.standardDeviation(), outType); + return reduce(ReductionAggregation.standardDeviation(), outType); } /** @@ -1309,7 +1309,7 @@ public Scalar any() { * Null values are skipped. */ public Scalar any(DType outType) { - return reduce(Aggregation.any(), outType); + return reduce(ReductionAggregation.any(), outType); } /** @@ -1330,7 +1330,7 @@ public Scalar all() { */ @Deprecated public Scalar all(DType outType) { - return reduce(Aggregation.all(), outType); + return reduce(ReductionAggregation.all(), outType); } /** @@ -1343,7 +1343,7 @@ public Scalar all(DType outType) { * empty or the reduction operation fails then the * {@link Scalar#isValid()} method of the result will return false. */ - public Scalar reduce(Aggregation aggregation) { + public Scalar reduce(ReductionAggregation aggregation) { return reduce(aggregation, type); } @@ -1360,7 +1360,7 @@ public Scalar reduce(Aggregation aggregation) { * empty or the reduction operation fails then the * {@link Scalar#isValid()} method of the result will return false. */ - public Scalar reduce(Aggregation aggregation, DType outType) { + public Scalar reduce(ReductionAggregation aggregation, DType outType) { long nativeId = aggregation.createNativeInstance(); try { return new Scalar(outType, reduce(getNativeView(), nativeId, outType.typeId.getNativeId(), outType.getScale())); @@ -1390,20 +1390,19 @@ public final ColumnVector quantile(QuantileMethod method, double[] quantiles) { * @throws IllegalArgumentException if unsupported window specification * (i.e. other than {@link WindowOptions.FrameType#ROWS} is used. 
*/ public final ColumnVector rollingWindow(RollingAggregation op, WindowOptions options) { - Aggregation agg = op.getBaseAggregation(); // Check that only row-based windows are used. if (!options.getFrameType().equals(WindowOptions.FrameType.ROWS)) { throw new IllegalArgumentException("Expected ROWS-based window specification. Unexpected window type: " + options.getFrameType()); } - long nativePtr = agg.createNativeInstance(); + long nativePtr = op.createNativeInstance(); try { Scalar p = options.getPrecedingScalar(); Scalar f = options.getFollowingScalar(); return new ColumnVector( rollingWindow(this.getNativeView(), - agg.getDefaultOutput(), + op.getDefaultOutput(), options.getMinPeriods(), nativePtr, p == null || !p.isValid() ? 0 : p.getInt(), @@ -1420,7 +1419,7 @@ public final ColumnVector rollingWindow(RollingAggregation op, WindowOptions opt * This is just a convenience method for an inclusive scan with a SUM aggregation. */ public final ColumnVector prefixSum() { - return scan(Aggregation.sum()); + return scan(ScanAggregation.sum()); } /** @@ -1431,7 +1430,7 @@ public final ColumnVector prefixSum() { * null policy too. Currently none of those aggregations are supported so * it is undefined how they would interact with each other. */ - public final ColumnVector scan(Aggregation aggregation, ScanType scanType, NullPolicy nullPolicy) { + public final ColumnVector scan(ScanAggregation aggregation, ScanType scanType, NullPolicy nullPolicy) { long nativeId = aggregation.createNativeInstance(); try { return new ColumnVector(scan(getNativeView(), nativeId, @@ -1446,7 +1445,7 @@ public final ColumnVector scan(Aggregation aggregation, ScanType scanType, NullP * @param aggregation the aggregation to perform * @param scanType should the scan be inclusive, include the current row, or exclusive. 
*/ - public final ColumnVector scan(Aggregation aggregation, ScanType scanType) { + public final ColumnVector scan(ScanAggregation aggregation, ScanType scanType) { return scan(aggregation, scanType, NullPolicy.EXCLUDE); } @@ -1454,7 +1453,7 @@ public final ColumnVector scan(Aggregation aggregation, ScanType scanType) { * Computes an inclusive scan for a column that excludes nulls. * @param aggregation the aggregation to perform */ - public final ColumnVector scan(Aggregation aggregation) { + public final ColumnVector scan(ScanAggregation aggregation) { return scan(aggregation, ScanType.INCLUSIVE, NullPolicy.EXCLUDE); } diff --git a/java/src/main/java/ai/rapids/cudf/GroupByAggregation.java b/java/src/main/java/ai/rapids/cudf/GroupByAggregation.java new file mode 100644 index 00000000000..dd2adf8bee8 --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/GroupByAggregation.java @@ -0,0 +1,296 @@ +/* + * + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package ai.rapids.cudf; + +/** + * An aggregation that can be used for a group by. + */ +public final class GroupByAggregation { + private final Aggregation wrapped; + + private GroupByAggregation(Aggregation wrapped) { + this.wrapped = wrapped; + } + + Aggregation getWrapped() { + return wrapped; + } + + + /** + * Add a column to the Aggregation so it can be used on a specific column of data. + * @param columnIndex the index of the column to operate on. 
+ */ + public GroupByAggregationOnColumn onColumn(int columnIndex) { + return new GroupByAggregationOnColumn(this, columnIndex); + } + + @Override + public int hashCode() { + return wrapped.hashCode(); + } + + @Override + public boolean equals(Object other) { + if (other == this) { + return true; + } else if (other instanceof GroupByAggregation) { + GroupByAggregation o = (GroupByAggregation) other; + return wrapped.equals(o.wrapped); + } + return false; + } + + /** + * Count number of valid, a.k.a. non-null, elements. + */ + public static GroupByAggregation count() { + return new GroupByAggregation(Aggregation.count()); + } + + /** + * Count number of elements. + * @param nullPolicy INCLUDE if nulls should be counted. EXCLUDE if only non-null values + * should be counted. + */ + public static GroupByAggregation count(NullPolicy nullPolicy) { + return new GroupByAggregation(Aggregation.count(nullPolicy)); + } + + /** + * Sum Aggregation + */ + public static GroupByAggregation sum() { + return new GroupByAggregation(Aggregation.sum()); + } + + /** + * Product Aggregation. + */ + public static GroupByAggregation product() { + return new GroupByAggregation(Aggregation.product()); + } + + + /** + * Index of max element. Please note that when using this aggregation if the + * data is not already sorted by the grouping keys it may be automatically sorted + * prior to doing the aggregation. This would result in an index into the sorted data being + * returned. + */ + public static GroupByAggregation argMax() { + return new GroupByAggregation(Aggregation.argMax()); + } + + /** + * Index of min element. Please note that when using this aggregation if the + * data is not already sorted by the grouping keys it may be automatically sorted + * prior to doing the aggregation. This would result in an index into the sorted data being + * returned. 
+ */ + public static GroupByAggregation argMin() { + return new GroupByAggregation(Aggregation.argMin()); + } + + /** + * Min Aggregation + */ + public static GroupByAggregation min() { + return new GroupByAggregation(Aggregation.min()); + } + + /** + * Max Aggregation + */ + public static GroupByAggregation max() { + return new GroupByAggregation(Aggregation.max()); + } + + /** + * Arithmetic mean reduction. + */ + public static GroupByAggregation mean() { + return new GroupByAggregation(Aggregation.mean()); + } + + /** + * Sum of square of differences from mean. + */ + public static GroupByAggregation M2() { + return new GroupByAggregation(Aggregation.M2()); + } + + /** + * Variance aggregation with 1 as the delta degrees of freedom. + */ + public static GroupByAggregation variance() { + return new GroupByAggregation(Aggregation.variance()); + } + + /** + * Variance aggregation. + * @param ddof delta degrees of freedom. The divisor used in calculation of variance is + * N - ddof, where N is the population size. + */ + public static GroupByAggregation variance(int ddof) { + return new GroupByAggregation(Aggregation.variance(ddof)); + } + + /** + * Standard deviation aggregation with 1 as the delta degrees of freedom. + */ + public static GroupByAggregation standardDeviation() { + return new GroupByAggregation(Aggregation.standardDeviation()); + } + + /** + * Standard deviation aggregation. + * @param ddof delta degrees of freedom. The divisor used in calculation of std is + * N - ddof, where N is the population size. + */ + public static GroupByAggregation standardDeviation(int ddof) { + return new GroupByAggregation(Aggregation.standardDeviation(ddof)); + } + + /** + * Aggregate to compute the specified quantiles. Uses linear interpolation by default. + */ + public static GroupByAggregation quantile(double ... quantiles) { + return new GroupByAggregation(Aggregation.quantile(quantiles)); + } + + /** + * Aggregate to compute various quantiles. 
+ */ + public static GroupByAggregation quantile(QuantileMethod method, double ... quantiles) { + return new GroupByAggregation(Aggregation.quantile(method, quantiles)); + } + + /** + * Median reduction. + */ + public static GroupByAggregation median() { + return new GroupByAggregation(Aggregation.median()); + } + + /** + * Number of unique, non-null, elements. + */ + public static GroupByAggregation nunique() { + return new GroupByAggregation(Aggregation.nunique()); + } + + /** + * Number of unique elements. + * @param nullPolicy INCLUDE if nulls should be counted else EXCLUDE. If nulls are counted they + * compare as equal so multiple null values in a range would all only + * increase the count by 1. + */ + public static GroupByAggregation nunique(NullPolicy nullPolicy) { + return new GroupByAggregation(Aggregation.nunique(nullPolicy)); + } + + /** + * Get the nth, non-null, element in a group. + * @param offset the offset to look at. Negative numbers go from the end of the group. Any + * value outside of the group range results in a null. + */ + public static GroupByAggregation nth(int offset) { + return new GroupByAggregation(Aggregation.nth(offset)); + } + + /** + * Get the nth element in a group. + * @param offset the offset to look at. Negative numbers go from the end of the group. Any + * value outside of the group range results in a null. + * @param nullPolicy INCLUDE if nulls should be included in the aggregation or EXCLUDE if they + * should be skipped. + */ + public static GroupByAggregation nth(int offset, NullPolicy nullPolicy) { + return new GroupByAggregation(Aggregation.nth(offset, nullPolicy)); + } + + /** + * Collect the values into a list. Nulls will be skipped. + */ + public static GroupByAggregation collectList() { + return new GroupByAggregation(Aggregation.collectList()); + } + + /** + * Collect the values into a list. + * + * @param nullPolicy Indicates whether to include/exclude nulls during collection. 
+ */ + public static GroupByAggregation collectList(NullPolicy nullPolicy) { + return new GroupByAggregation(Aggregation.collectList(nullPolicy)); + } + + /** + * Collect the values into a set. All null values will be excluded, and all nan values are regarded as + * unique instances. + */ + public static GroupByAggregation collectSet() { + return new GroupByAggregation(Aggregation.collectSet()); + } + + /** + * Collect the values into a set. + * + * @param nullPolicy Indicates whether to include/exclude nulls during collection. + * @param nullEquality Flag to specify whether null entries within each list should be considered equal. + * @param nanEquality Flag to specify whether NaN values in floating point column should be considered equal. + */ + public static GroupByAggregation collectSet(NullPolicy nullPolicy, NullEquality nullEquality, NaNEquality nanEquality) { + return new GroupByAggregation(Aggregation.collectSet(nullPolicy, nullEquality, nanEquality)); + } + + /** + * Merge the partial lists produced by multiple CollectListAggregations. + * NOTICE: The partial lists to be merged should NOT include any null list element (but can include null list entries). + */ + public static GroupByAggregation mergeLists() { + return new GroupByAggregation(Aggregation.mergeLists()); + } + + /** + * Merge the partial sets produced by multiple CollectSetAggregations. Each null/nan value will be regarded as + * a unique instance. + */ + public static GroupByAggregation mergeSets() { + return new GroupByAggregation(Aggregation.mergeSets()); + } + + /** + * Merge the partial sets produced by multiple CollectSetAggregations. + * + * @param nullEquality Flag to specify whether null entries within each list should be considered equal. + * @param nanEquality Flag to specify whether NaN values in floating point column should be considered equal. 
+ */ + public static GroupByAggregation mergeSets(NullEquality nullEquality, NaNEquality nanEquality) { + return new GroupByAggregation(Aggregation.mergeSets(nullEquality, nanEquality)); + } + + /** + * Merge the partial M2 values produced by multiple instances of M2Aggregation. + */ + public static GroupByAggregation mergeM2() { + return new GroupByAggregation(Aggregation.mergeM2()); + } +} diff --git a/java/src/main/java/ai/rapids/cudf/GroupByAggregationOnColumn.java b/java/src/main/java/ai/rapids/cudf/GroupByAggregationOnColumn.java new file mode 100644 index 00000000000..c50cf3728f0 --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/GroupByAggregationOnColumn.java @@ -0,0 +1,56 @@ +/* + * + * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package ai.rapids.cudf; + +/** + * A GroupByAggregation for a specific column in a table. 
+ */ +public final class GroupByAggregationOnColumn { + protected final GroupByAggregation wrapped; + protected final int columnIndex; + + GroupByAggregationOnColumn(GroupByAggregation wrapped, int columnIndex) { + this.wrapped = wrapped; + this.columnIndex = columnIndex; + } + + public int getColumnIndex() { + return columnIndex; + } + + GroupByAggregation getWrapped() { + return wrapped; + } + + @Override + public int hashCode() { + return 31 * wrapped.hashCode() + columnIndex; + } + + @Override + public boolean equals(Object other) { + if (other == this) { + return true; + } else if (other instanceof GroupByAggregationOnColumn) { + GroupByAggregationOnColumn o = (GroupByAggregationOnColumn) other; + return wrapped.equals(o.wrapped) && columnIndex == o.columnIndex; + } + return false; + } +} diff --git a/java/src/main/java/ai/rapids/cudf/GroupByScanAggregation.java b/java/src/main/java/ai/rapids/cudf/GroupByScanAggregation.java new file mode 100644 index 00000000000..219b6dde05d --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/GroupByScanAggregation.java @@ -0,0 +1,118 @@ +/* + * + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package ai.rapids.cudf; + +/** + * An aggregation that can be used for a grouped scan. 
+ */ +public final class GroupByScanAggregation { + private final Aggregation wrapped; + + private GroupByScanAggregation(Aggregation wrapped) { + this.wrapped = wrapped; + } + + long createNativeInstance() { + return wrapped.createNativeInstance(); + } + + long getDefaultOutput() { + return wrapped.getDefaultOutput(); + } + + Aggregation getWrapped() { + return wrapped; + } + + /** + * Add a column to the Aggregation so it can be used on a specific column of data. + * @param columnIndex the index of the column to operate on. + */ + public GroupByScanAggregationOnColumn onColumn(int columnIndex) { + return new GroupByScanAggregationOnColumn(this, columnIndex); + } + + @Override + public int hashCode() { + return wrapped.hashCode(); + } + + @Override + public boolean equals(Object other) { + if (other == this) { + return true; + } else if (other instanceof GroupByScanAggregation) { + GroupByScanAggregation o = (GroupByScanAggregation) other; + return wrapped.equals(o.wrapped); + } + return false; + } + + /** + * Sum Aggregation + */ + public static GroupByScanAggregation sum() { + return new GroupByScanAggregation(Aggregation.sum()); + } + + + /** + * Product Aggregation. + */ + public static GroupByScanAggregation product() { + return new GroupByScanAggregation(Aggregation.product()); + } + + /** + * Min Aggregation + */ + public static GroupByScanAggregation min() { + return new GroupByScanAggregation(Aggregation.min()); + } + + /** + * Max Aggregation + */ + public static GroupByScanAggregation max() { + return new GroupByScanAggregation(Aggregation.max()); + } + + /** + * Count number of elements. + * @param nullPolicy INCLUDE if nulls should be counted. EXCLUDE if only non-null values + * should be counted. + */ + public static GroupByScanAggregation count(NullPolicy nullPolicy) { + return new GroupByScanAggregation(Aggregation.count(nullPolicy)); + } + + /** + * Get the row's ranking. 
+ */ + public static GroupByScanAggregation rank() { + return new GroupByScanAggregation(Aggregation.rank()); + } + + /** + * Get the row's dense ranking. + */ + public static GroupByScanAggregation denseRank() { + return new GroupByScanAggregation(Aggregation.denseRank()); + } +} diff --git a/java/src/main/java/ai/rapids/cudf/GroupByScanAggregationOnColumn.java b/java/src/main/java/ai/rapids/cudf/GroupByScanAggregationOnColumn.java new file mode 100644 index 00000000000..75e4936e5b9 --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/GroupByScanAggregationOnColumn.java @@ -0,0 +1,64 @@ +/* + * + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package ai.rapids.cudf; + +/** + * A GroupByScanAggregation for a specific column in a table. 
+ */ +public final class GroupByScanAggregationOnColumn { + protected final GroupByScanAggregation wrapped; + protected final int columnIndex; + + GroupByScanAggregationOnColumn(GroupByScanAggregation wrapped, int columnIndex) { + this.wrapped = wrapped; + this.columnIndex = columnIndex; + } + + public int getColumnIndex() { + return columnIndex; + } + + @Override + public int hashCode() { + return 31 * wrapped.hashCode() + columnIndex; + } + + @Override + public boolean equals(Object other) { + if (other == this) { + return true; + } else if (other instanceof GroupByScanAggregationOnColumn) { + GroupByScanAggregationOnColumn o = (GroupByScanAggregationOnColumn) other; + return wrapped.equals(o.wrapped) && columnIndex == o.columnIndex; + } + return false; + } + + long createNativeInstance() { + return wrapped.createNativeInstance(); + } + + long getDefaultOutput() { + return wrapped.getDefaultOutput(); + } + + GroupByScanAggregation getWrapped() { + return wrapped; + } +} diff --git a/java/src/main/java/ai/rapids/cudf/ReductionAggregation.java b/java/src/main/java/ai/rapids/cudf/ReductionAggregation.java new file mode 100644 index 00000000000..7eff85dcd0d --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/ReductionAggregation.java @@ -0,0 +1,212 @@ +/* + * + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package ai.rapids.cudf; + +/** + * An aggregation that can be used for a reduce. 
+ */ +public final class ReductionAggregation { + private final Aggregation wrapped; + + private ReductionAggregation(Aggregation wrapped) { + this.wrapped = wrapped; + } + + long createNativeInstance() { + return wrapped.createNativeInstance(); + } + + long getDefaultOutput() { + return wrapped.getDefaultOutput(); + } + + Aggregation getWrapped() { + return wrapped; + } + + @Override + public int hashCode() { + return wrapped.hashCode(); + } + + @Override + public boolean equals(Object other) { + if (other == this) { + return true; + } else if (other instanceof ReductionAggregation) { + ReductionAggregation o = (ReductionAggregation) other; + return wrapped.equals(o.wrapped); + } + return false; + } + + /** + * Sum Aggregation + */ + public static ReductionAggregation sum() { + return new ReductionAggregation(Aggregation.sum()); + } + + /** + * Product Aggregation. + */ + public static ReductionAggregation product() { + return new ReductionAggregation(Aggregation.product()); + } + + /** + * Min Aggregation + */ + public static ReductionAggregation min() { + return new ReductionAggregation(Aggregation.min()); + } + + /** + * Max Aggregation + */ + public static ReductionAggregation max() { + return new ReductionAggregation(Aggregation.max()); + } + + /** + * Any reduction. Produces a true or 1, depending on the output type, + * if any of the elements in the range are true or non-zero, otherwise produces a false or 0. + * Null values are skipped. + */ + public static ReductionAggregation any() { + return new ReductionAggregation(Aggregation.any()); + } + + /** + * All reduction. Produces true or 1, depending on the output type, if all of the elements in + * the range are true or non-zero, otherwise produces a false or 0. + * Null values are skipped. + */ + public static ReductionAggregation all() { + return new ReductionAggregation(Aggregation.all()); + } + + + /** + * Sum of squares reduction. 
+ */ + public static ReductionAggregation sumOfSquares() { + return new ReductionAggregation(Aggregation.sumOfSquares()); + } + + /** + * Arithmetic mean reduction. + */ + public static ReductionAggregation mean() { + return new ReductionAggregation(Aggregation.mean()); + } + + + /** + * Variance aggregation with 1 as the delta degrees of freedom. + */ + public static ReductionAggregation variance() { + return new ReductionAggregation(Aggregation.variance()); + } + + /** + * Variance aggregation. + * @param ddof delta degrees of freedom. The divisor used in calculation of variance is + * N - ddof, where N is the population size. + */ + public static ReductionAggregation variance(int ddof) { + return new ReductionAggregation(Aggregation.variance(ddof)); + } + + /** + * Standard deviation aggregation with 1 as the delta degrees of freedom. + */ + public static ReductionAggregation standardDeviation() { + return new ReductionAggregation(Aggregation.standardDeviation()); + } + + /** + * Standard deviation aggregation. + * @param ddof delta degrees of freedom. The divisor used in calculation of std is + * N - ddof, where N is the population size. + */ + public static ReductionAggregation standardDeviation(int ddof) { + return new ReductionAggregation(Aggregation.standardDeviation(ddof)); + } + + + /** + * Median reduction. + */ + public static ReductionAggregation median() { + return new ReductionAggregation(Aggregation.median()); + } + + /** + * Aggregate to compute the specified quantiles. Uses linear interpolation by default. + */ + public static ReductionAggregation quantile(double ... quantiles) { + return new ReductionAggregation(Aggregation.quantile(quantiles)); + } + + /** + * Aggregate to compute various quantiles. + */ + public static ReductionAggregation quantile(QuantileMethod method, double ... quantiles) { + return new ReductionAggregation(Aggregation.quantile(method, quantiles)); + } + + + /** + * Number of unique, non-null, elements. 
+ */ + public static ReductionAggregation nunique() { + return new ReductionAggregation(Aggregation.nunique()); + } + + /** + * Number of unique elements. + * @param nullPolicy INCLUDE if nulls should be counted else EXCLUDE. If nulls are counted they + * compare as equal so multiple null values in a range would all only + * increase the count by 1. + */ + public static ReductionAggregation nunique(NullPolicy nullPolicy) { + return new ReductionAggregation(Aggregation.nunique(nullPolicy)); + } + + /** + * Get the nth, non-null, element in a group. + * @param offset the offset to look at. Negative numbers go from the end of the group. Any + * value outside of the group range results in a null. + */ + public static ReductionAggregation nth(int offset) { + return new ReductionAggregation(Aggregation.nth(offset)); + } + + /** + * Get the nth element in a group. + * @param offset the offset to look at. Negative numbers go from the end of the group. Any + * value outside of the group range results in a null. + * @param nullPolicy INCLUDE if nulls should be included in the aggregation or EXCLUDE if they + * should be skipped. + */ + public static ReductionAggregation nth(int offset, NullPolicy nullPolicy) { + return new ReductionAggregation(Aggregation.nth(offset, nullPolicy)); + } +} diff --git a/java/src/main/java/ai/rapids/cudf/RollingAggregation.java b/java/src/main/java/ai/rapids/cudf/RollingAggregation.java index 9b80924463a..07983f77aad 100644 --- a/java/src/main/java/ai/rapids/cudf/RollingAggregation.java +++ b/java/src/main/java/ai/rapids/cudf/RollingAggregation.java @@ -19,11 +19,189 @@ package ai.rapids.cudf; /** - * Used to tag an aggregation as something that is compatible with rolling window operations. - * Do not try to implement this yourself + * An aggregation that can be used on rolling windows. 
*/ -public interface RollingAggregation { - default T getBaseAggregation() { - return (T)this; +public final class RollingAggregation { + private final Aggregation wrapped; + + private RollingAggregation(Aggregation wrapped) { + this.wrapped = wrapped; + } + + long createNativeInstance() { + return wrapped.createNativeInstance(); + } + + long getDefaultOutput() { + return wrapped.getDefaultOutput(); + } + + /** + * Add a column to the Aggregation so it can be used on a specific column of data. + * @param columnIndex the index of the column to operate on. + */ + public RollingAggregationOnColumn onColumn(int columnIndex) { + return new RollingAggregationOnColumn(this, columnIndex); + } + + @Override + public int hashCode() { + return wrapped.hashCode(); + } + + @Override + public boolean equals(Object other) { + if (other == this) { + return true; + } else if (other instanceof RollingAggregation) { + RollingAggregation o = (RollingAggregation) other; + return wrapped.equals(o.wrapped); + } + return false; + } + + /** + * Rolling Window Sum + */ + public static RollingAggregation sum() { + return new RollingAggregation(Aggregation.sum()); + } + + + /** + * Rolling Window Min + */ + public static RollingAggregation min() { + return new RollingAggregation(Aggregation.min()); + } + + /** + * Rolling Window Max + */ + public static RollingAggregation max() { + return new RollingAggregation(Aggregation.max()); + } + + + /** + * Count number of valid, a.k.a. non-null, elements. + */ + public static RollingAggregation count() { + return new RollingAggregation(Aggregation.count()); + } + + /** + * Count number of elements. + * @param nullPolicy INCLUDE if nulls should be counted. EXCLUDE if only non-null values + * should be counted. 
+ */ + public static RollingAggregation count(NullPolicy nullPolicy) { + return new RollingAggregation(Aggregation.count(nullPolicy)); + } + + /** + * Arithmetic Mean + */ + public static RollingAggregation mean() { + return new RollingAggregation(Aggregation.mean()); + } + + + /** + * Index of max element. + */ + public static RollingAggregation argMax() { + return new RollingAggregation(Aggregation.argMax()); + } + + /** + * Index of min element. + */ + public static RollingAggregation argMin() { + return new RollingAggregation(Aggregation.argMin()); + } + + + /** + * Get the row number. + */ + public static RollingAggregation rowNumber() { + return new RollingAggregation(Aggregation.rowNumber()); + } + + + /** + * In a rolling window return the value offset entries ahead or null if it is outside of the + * window. + */ + public static RollingAggregation lead(int offset) { + return lead(offset, null); + } + + /** + * In a rolling window return the value offset entries ahead or the corresponding value from + * defaultOutput if it is outside of the window. Note that this does not take any ownership of + * defaultOutput and the caller must ensure that defaultOutput remains valid during the + * lifetime of this aggregation operation. + */ + public static RollingAggregation lead(int offset, ColumnVector defaultOutput) { + return new RollingAggregation(Aggregation.lead(offset, defaultOutput)); + } + + + + /** + * In a rolling window return the value offset entries behind or null if it is outside of the + * window. + */ + public static RollingAggregation lag(int offset) { + return lag(offset, null); + } + + /** + * In a rolling window return the value offset entries behind or the corresponding value from + * defaultOutput if it is outside of the window. Note that this does not take any ownership of + * defaultOutput and the caller must ensure that defaultOutput remains valid during the + * lifetime of this aggregation operation.
+ */ + public static RollingAggregation lag(int offset, ColumnVector defaultOutput) { + return new RollingAggregation(Aggregation.lag(offset, defaultOutput)); + } + + + /** + * Collect the values into a list. Nulls will be skipped. + */ + public static RollingAggregation collectList() { + return new RollingAggregation(Aggregation.collectList()); + } + + /** + * Collect the values into a list. + * + * @param nullPolicy Indicates whether to include/exclude nulls during collection. + */ + public static RollingAggregation collectList(NullPolicy nullPolicy) { + return new RollingAggregation(Aggregation.collectList(nullPolicy)); + } + + + /** + * Collect the values into a set. All null values will be excluded, and all nan values are regarded as + * unique instances. + */ + public static RollingAggregation collectSet() { + return new RollingAggregation(Aggregation.collectSet()); + } + + /** + * Collect the values into a set. + * + * @param nullPolicy Indicates whether to include/exclude nulls during collection. + * @param nullEquality Flag to specify whether null entries within each list should be considered equal. + * @param nanEquality Flag to specify whether NaN values in floating point column should be considered equal. 
+ */ + public static RollingAggregation collectSet(NullPolicy nullPolicy, NullEquality nullEquality, NaNEquality nanEquality) { + return new RollingAggregation(Aggregation.collectSet(nullPolicy, nullEquality, nanEquality)); } } diff --git a/java/src/main/java/ai/rapids/cudf/AggregationOnColumn.java b/java/src/main/java/ai/rapids/cudf/RollingAggregationOnColumn.java similarity index 55% rename from java/src/main/java/ai/rapids/cudf/AggregationOnColumn.java rename to java/src/main/java/ai/rapids/cudf/RollingAggregationOnColumn.java index bb1404e5a07..a6b1484aa71 100644 --- a/java/src/main/java/ai/rapids/cudf/AggregationOnColumn.java +++ b/java/src/main/java/ai/rapids/cudf/RollingAggregationOnColumn.java @@ -1,6 +1,6 @@ /* * - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,47 +19,24 @@ package ai.rapids.cudf; /** - * An Aggregation instance that also holds a column number so the aggregation can be done on - * a specific column of data in a table. + * A RollingAggregation for a specific column in a table. */ -public class AggregationOnColumn extends Aggregation { - protected final T wrapped; +public final class RollingAggregationOnColumn { + protected final RollingAggregation wrapped; protected final int columnIndex; - AggregationOnColumn(T wrapped, int columnIndex) { - super(wrapped.kind); + RollingAggregationOnColumn(RollingAggregation wrapped, int columnIndex) { this.wrapped = wrapped; this.columnIndex = columnIndex; } - @Override - public AggregationOnColumn onColumn(int columnIndex) { - if (columnIndex == getColumnIndex()) { - return this; // NOOP - } else { - return new AggregationOnColumn(this.wrapped, columnIndex); - } - } - - /** - * Do the aggregation over a given Window. 
- */ - public > AggregationOverWindow overWindow(WindowOptions windowOptions) { - return new AggregationOverWindow(wrapped, columnIndex, windowOptions); - } - public int getColumnIndex() { return columnIndex; } - @Override - long createNativeInstance() { - return wrapped.createNativeInstance(); - } - @Override - long getDefaultOutput() { - return wrapped.getDefaultOutput(); + public AggregationOverWindow overWindow(WindowOptions windowOptions) { + return new AggregationOverWindow(this, windowOptions); } @Override @@ -71,10 +48,18 @@ public int hashCode() { public boolean equals(Object other) { if (other == this) { return true; - } else if (other instanceof AggregationOnColumn) { - AggregationOnColumn o = (AggregationOnColumn) other; + } else if (other instanceof RollingAggregationOnColumn) { + RollingAggregationOnColumn o = (RollingAggregationOnColumn) other; return wrapped.equals(o.wrapped) && columnIndex == o.columnIndex; } return false; } + + long createNativeInstance() { + return wrapped.createNativeInstance(); + } + + long getDefaultOutput() { + return wrapped.getDefaultOutput(); + } } diff --git a/java/src/main/java/ai/rapids/cudf/ScanAggregation.java b/java/src/main/java/ai/rapids/cudf/ScanAggregation.java new file mode 100644 index 00000000000..08489562adc --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/ScanAggregation.java @@ -0,0 +1,100 @@ +/* + * + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +package ai.rapids.cudf; + +/** + * An aggregation that can be used for a scan. + */ +public final class ScanAggregation { + private final Aggregation wrapped; + + private ScanAggregation(Aggregation wrapped) { + this.wrapped = wrapped; + } + + long createNativeInstance() { + return wrapped.createNativeInstance(); + } + + long getDefaultOutput() { + return wrapped.getDefaultOutput(); + } + + Aggregation getWrapped() { + return wrapped; + } + + @Override + public int hashCode() { + return wrapped.hashCode(); + } + + @Override + public boolean equals(Object other) { + if (other == this) { + return true; + } else if (other instanceof ScanAggregation) { + ScanAggregation o = (ScanAggregation) other; + return wrapped.equals(o.wrapped); + } + return false; + } + + /** + * Sum Aggregation + */ + public static ScanAggregation sum() { + return new ScanAggregation(Aggregation.sum()); + } + + /** + * Product Aggregation. + */ + public static ScanAggregation product() { + return new ScanAggregation(Aggregation.product()); + } + + /** + * Min Aggregation + */ + public static ScanAggregation min() { + return new ScanAggregation(Aggregation.min()); + } + + /** + * Max Aggregation + */ + public static ScanAggregation max() { + return new ScanAggregation(Aggregation.max()); + } + + /** + * Get the row's ranking. + */ + public static ScanAggregation rank() { + return new ScanAggregation(Aggregation.rank()); + } + + /** + * Get the row's dense ranking. + */ + public static ScanAggregation denseRank() { + return new ScanAggregation(Aggregation.denseRank()); + } +} diff --git a/java/src/main/java/ai/rapids/cudf/Table.java b/java/src/main/java/ai/rapids/cudf/Table.java index 96a9b608f06..360bb5c7467 100644 --- a/java/src/main/java/ai/rapids/cudf/Table.java +++ b/java/src/main/java/ai/rapids/cudf/Table.java @@ -2456,7 +2456,7 @@ public static final class GroupByOperation { * 1, 2 * 2, 1 ==> aggregated count */ - public Table aggregate(AggregationOnColumn... 
aggregates) { + public Table aggregate(GroupByAggregationOnColumn... aggregates) { assert aggregates != null; // To improve performance and memory we want to remove duplicate operations @@ -2469,9 +2469,9 @@ public Table aggregate(AggregationOnColumn... aggregates) { int keysLength = operation.indices.length; int totalOps = 0; for (int outputIndex = 0; outputIndex < aggregates.length; outputIndex++) { - AggregationOnColumn agg = aggregates[outputIndex]; + GroupByAggregationOnColumn agg = aggregates[outputIndex]; ColumnOps ops = groupedOps.computeIfAbsent(agg.getColumnIndex(), (idx) -> new ColumnOps()); - totalOps += ops.add(agg, outputIndex + keysLength); + totalOps += ops.add(agg.getWrapped().getWrapped(), outputIndex + keysLength); } int[] aggColumnIndexes = new int[totalOps]; long[] aggOperationInstances = new long[totalOps]; @@ -2808,7 +2808,7 @@ public Table aggregateWindowsOverRanges(AggregationOverWindow... windowAggregate } } - public Table scan(AggregationOnColumn... aggregates) { + public Table scan(GroupByScanAggregationOnColumn... aggregates) { assert aggregates != null; // To improve performance and memory we want to remove duplicate operations @@ -2821,9 +2821,9 @@ public Table scan(AggregationOnColumn... 
aggregates) { int keysLength = operation.indices.length; int totalOps = 0; for (int outputIndex = 0; outputIndex < aggregates.length; outputIndex++) { - AggregationOnColumn agg = aggregates[outputIndex]; + GroupByScanAggregationOnColumn agg = aggregates[outputIndex]; ColumnOps ops = groupedOps.computeIfAbsent(agg.getColumnIndex(), (idx) -> new ColumnOps()); - totalOps += ops.add(agg, outputIndex + keysLength); + totalOps += ops.add(agg.getWrapped().getWrapped(), outputIndex + keysLength); } int[] aggColumnIndexes = new int[totalOps]; long[] aggOperationInstances = new long[totalOps]; diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index d3fdb0e19bb..4856071e296 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -2899,24 +2899,22 @@ void testPrefixSum() { @Test void testScanSum() { try (ColumnVector v1 = ColumnVector.fromBoxedInts(1, 2, null, 3, 5, 8, 10)) { - // Due to https://github.com/rapidsai/cudf/issues/8462 NullPolicy.INCLUDE - // tests have been disabled -// try (ColumnVector result = v1.scan(Aggregation.sum(), ScanType.INCLUSIVE, NullPolicy.INCLUDE); -// ColumnVector expected = ColumnVector.fromBoxedInts(1, 3, null, null, null, null, null)) { -// assertColumnsAreEqual(expected, result); -// } - - try (ColumnVector result = v1.scan(Aggregation.sum(), ScanType.INCLUSIVE, NullPolicy.EXCLUDE); + try (ColumnVector result = v1.scan(ScanAggregation.sum(), ScanType.INCLUSIVE, NullPolicy.INCLUDE); + ColumnVector expected = ColumnVector.fromBoxedInts(1, 3, null, null, null, null, null)) { + assertColumnsAreEqual(expected, result); + } + + try (ColumnVector result = v1.scan(ScanAggregation.sum(), ScanType.INCLUSIVE, NullPolicy.EXCLUDE); ColumnVector expected = ColumnVector.fromBoxedInts(1, 3, null, 6, 11, 19, 29)) { assertColumnsAreEqual(expected, result); } -// try (ColumnVector result = 
v1.scan(Aggregation.sum(), ScanType.EXCLUSIVE, NullPolicy.INCLUDE); -// ColumnVector expected = ColumnVector.fromBoxedInts(0, 1, 3, 3, 6, 11, 19)) { -// assertColumnsAreEqual(expected, result); -// } + try (ColumnVector result = v1.scan(ScanAggregation.sum(), ScanType.EXCLUSIVE, NullPolicy.INCLUDE); + ColumnVector expected = ColumnVector.fromBoxedInts(0, 1, 3, null, null, null, null)) { + assertColumnsAreEqual(expected, result); + } - try (ColumnVector result = v1.scan(Aggregation.sum(), ScanType.EXCLUSIVE, NullPolicy.EXCLUDE); + try (ColumnVector result = v1.scan(ScanAggregation.sum(), ScanType.EXCLUSIVE, NullPolicy.EXCLUDE); ColumnVector expected = ColumnVector.fromBoxedInts(0, 1, null, 3, 6, 11, 19)) { assertColumnsAreEqual(expected, result); } @@ -2925,25 +2923,23 @@ void testScanSum() { @Test void testScanMax() { - // Due to https://github.com/rapidsai/cudf/issues/8462 NullPolicy.INCLUDE - // tests have been disabled try (ColumnVector v1 = ColumnVector.fromBoxedInts(1, 2, null, 3, 5, 8, 10)) { -// try (ColumnVector result = v1.scan(Aggregation.max(), ScanType.INCLUSIVE, NullPolicy.INCLUDE); -// ColumnVector expected = ColumnVector.fromBoxedInts(1, 2, null, null, null, null, null)) { -// assertColumnsAreEqual(expected, result); -// } + try (ColumnVector result = v1.scan(ScanAggregation.max(), ScanType.INCLUSIVE, NullPolicy.INCLUDE); + ColumnVector expected = ColumnVector.fromBoxedInts(1, 2, null, null, null, null, null)) { + assertColumnsAreEqual(expected, result); + } - try (ColumnVector result = v1.scan(Aggregation.max(), ScanType.INCLUSIVE, NullPolicy.EXCLUDE); + try (ColumnVector result = v1.scan(ScanAggregation.max(), ScanType.INCLUSIVE, NullPolicy.EXCLUDE); ColumnVector expected = ColumnVector.fromBoxedInts(1, 2, null, 3, 5, 8, 10)) { assertColumnsAreEqual(expected, result); } -// try (ColumnVector result = v1.scan(Aggregation.max(), ScanType.EXCLUSIVE, NullPolicy.INCLUDE); -// ColumnVector expected = ColumnVector.fromBoxedInts(Integer.MIN_VALUE, 1, 2, 2, 
3, 5, 8)) { -// assertColumnsAreEqual(expected, result); -// } + try (ColumnVector result = v1.scan(ScanAggregation.max(), ScanType.EXCLUSIVE, NullPolicy.INCLUDE); + ColumnVector expected = ColumnVector.fromBoxedInts(Integer.MIN_VALUE, 1, 2, null, null, null, null)) { + assertColumnsAreEqual(expected, result); + } - try (ColumnVector result = v1.scan(Aggregation.max(), ScanType.EXCLUSIVE, NullPolicy.EXCLUDE); + try (ColumnVector result = v1.scan(ScanAggregation.max(), ScanType.EXCLUSIVE, NullPolicy.EXCLUDE); ColumnVector expected = ColumnVector.fromBoxedInts(Integer.MIN_VALUE, 1, null, 2, 3, 5, 8)) { assertColumnsAreEqual(expected, result); } @@ -2952,25 +2948,23 @@ void testScanMax() { @Test void testScanMin() { - // Due to https://github.com/rapidsai/cudf/issues/8462 NullPolicy.INCLUDE - // tests have been disabled try (ColumnVector v1 = ColumnVector.fromBoxedInts(1, 2, null, 3, 5, 8, 10)) { -// try (ColumnVector result = v1.scan(Aggregation.min(), ScanType.INCLUSIVE, NullPolicy.INCLUDE); -// ColumnVector expected = ColumnVector.fromBoxedInts(1, 1, null, null, null, null, null)) { -// assertColumnsAreEqual(expected, result); -// } + try (ColumnVector result = v1.scan(ScanAggregation.min(), ScanType.INCLUSIVE, NullPolicy.INCLUDE); + ColumnVector expected = ColumnVector.fromBoxedInts(1, 1, null, null, null, null, null)) { + assertColumnsAreEqual(expected, result); + } - try (ColumnVector result = v1.scan(Aggregation.min(), ScanType.INCLUSIVE, NullPolicy.EXCLUDE); + try (ColumnVector result = v1.scan(ScanAggregation.min(), ScanType.INCLUSIVE, NullPolicy.EXCLUDE); ColumnVector expected = ColumnVector.fromBoxedInts(1, 1, null, 1, 1, 1, 1)) { assertColumnsAreEqual(expected, result); } -// try (ColumnVector result = v1.scan(Aggregation.min(), ScanType.EXCLUSIVE, NullPolicy.INCLUDE); -// ColumnVector expected = ColumnVector.fromBoxedInts(Integer.MAX_VALUE, 1, 1, 1, 1, 1, 1)) { -// assertColumnsAreEqual(expected, result); -// } + try (ColumnVector result = 
v1.scan(ScanAggregation.min(), ScanType.EXCLUSIVE, NullPolicy.INCLUDE); + ColumnVector expected = ColumnVector.fromBoxedInts(Integer.MAX_VALUE, 1, 1, null, null, null, null)) { + assertColumnsAreEqual(expected, result); + } - try (ColumnVector result = v1.scan(Aggregation.min(), ScanType.EXCLUSIVE, NullPolicy.EXCLUDE); + try (ColumnVector result = v1.scan(ScanAggregation.min(), ScanType.EXCLUSIVE, NullPolicy.EXCLUDE); ColumnVector expected = ColumnVector.fromBoxedInts(Integer.MAX_VALUE, 1, null, 1, 1, 1, 1)) { assertColumnsAreEqual(expected, result); } @@ -2979,25 +2973,23 @@ void testScanMin() { @Test void testScanProduct() { - // Due to https://github.com/rapidsai/cudf/issues/8462 NullPolicy.INCLUDE - // tests have been disabled try (ColumnVector v1 = ColumnVector.fromBoxedInts(1, 2, null, 3, 5, 8, 10)) { -// try (ColumnVector result = v1.scan(Aggregation.product(), ScanType.INCLUSIVE, NullPolicy.INCLUDE); -// ColumnVector expected = ColumnVector.fromBoxedInts(1, 2, null, null, null, null, null)) { -// assertColumnsAreEqual(expected, result); -// } + try (ColumnVector result = v1.scan(ScanAggregation.product(), ScanType.INCLUSIVE, NullPolicy.INCLUDE); + ColumnVector expected = ColumnVector.fromBoxedInts(1, 2, null, null, null, null, null)) { + assertColumnsAreEqual(expected, result); + } - try (ColumnVector result = v1.scan(Aggregation.product(), ScanType.INCLUSIVE, NullPolicy.EXCLUDE); + try (ColumnVector result = v1.scan(ScanAggregation.product(), ScanType.INCLUSIVE, NullPolicy.EXCLUDE); ColumnVector expected = ColumnVector.fromBoxedInts(1, 2, null, 6, 30, 240, 2400)) { assertColumnsAreEqual(expected, result); } -// try (ColumnVector result = v1.scan(Aggregation.product(), ScanType.EXCLUSIVE, NullPolicy.INCLUDE); -// ColumnVector expected = ColumnVector.fromBoxedInts(1, 1, 2, 2, 6, 30, 240)) { -// assertColumnsAreEqual(expected, result); -// } + try (ColumnVector result = v1.scan(ScanAggregation.product(), ScanType.EXCLUSIVE, NullPolicy.INCLUDE); + ColumnVector 
expected = ColumnVector.fromBoxedInts(1, 1, 2, null, null, null, null)) { + assertColumnsAreEqual(expected, result); + } - try (ColumnVector result = v1.scan(Aggregation.product(), ScanType.EXCLUSIVE, NullPolicy.EXCLUDE); + try (ColumnVector result = v1.scan(ScanAggregation.product(), ScanType.EXCLUSIVE, NullPolicy.EXCLUDE); ColumnVector expected = ColumnVector.fromBoxedInts(1, 1, null, 2, 6, 30, 240)) { assertColumnsAreEqual(expected, result); } @@ -3011,13 +3003,13 @@ void testScanRank() { ColumnVector struct_order = ColumnVector.makeStruct(col1, col2); ColumnVector expected = ColumnVector.fromBoxedInts( 1, 1, 3, 4, 5, 6, 7, 7, 9, 9, 11, 12)) { - try (ColumnVector result = struct_order.scan(Aggregation.rank(), + try (ColumnVector result = struct_order.scan(ScanAggregation.rank(), ScanType.INCLUSIVE, NullPolicy.INCLUDE)) { assertColumnsAreEqual(expected, result); } // Exclude should have identical results - try (ColumnVector result = struct_order.scan(Aggregation.rank(), + try (ColumnVector result = struct_order.scan(ScanAggregation.rank(), ScanType.INCLUSIVE, NullPolicy.EXCLUDE) ) { assertColumnsAreEqual(expected, result); @@ -3034,13 +3026,13 @@ void testScanDenseRank() { ColumnVector struct_order = ColumnVector.makeStruct(col1, col2); ColumnVector expected = ColumnVector.fromBoxedInts( 1, 1, 2, 3, 4, 5, 6, 6, 7, 7, 8, 9)) { - try (ColumnVector result = struct_order.scan(Aggregation.denseRank(), + try (ColumnVector result = struct_order.scan(ScanAggregation.denseRank(), ScanType.INCLUSIVE, NullPolicy.INCLUDE)) { assertColumnsAreEqual(expected, result); } // Exclude should have identical results - try (ColumnVector result = struct_order.scan(Aggregation.denseRank(), + try (ColumnVector result = struct_order.scan(ScanAggregation.denseRank(), ScanType.INCLUSIVE, NullPolicy.EXCLUDE)) { assertColumnsAreEqual(expected, result); } @@ -3058,39 +3050,39 @@ void testWindowStatic() { .minPeriods(2).build()) { try (ColumnVector v1 = ColumnVector.fromInts(5, 4, 7, 6, 8)) { 
try (ColumnVector expected = ColumnVector.fromLongs(9, 16, 17, 21, 14); - ColumnVector result = v1.rollingWindow(Aggregation.sum(), options)) { + ColumnVector result = v1.rollingWindow(RollingAggregation.sum(), options)) { assertColumnsAreEqual(expected, result); } try (ColumnVector expected = ColumnVector.fromInts(4, 4, 4, 6, 6); - ColumnVector result = v1.rollingWindow(Aggregation.min(), options)) { + ColumnVector result = v1.rollingWindow(RollingAggregation.min(), options)) { assertColumnsAreEqual(expected, result); } try (ColumnVector expected = ColumnVector.fromInts(5, 7, 7, 8, 8); - ColumnVector result = v1.rollingWindow(Aggregation.max(), options)) { + ColumnVector result = v1.rollingWindow(RollingAggregation.max(), options)) { assertColumnsAreEqual(expected, result); } // The rolling window produces the same result type as the input try (ColumnVector expected = ColumnVector.fromDoubles(4.5, 16.0 / 3, 17.0 / 3, 7, 7); - ColumnVector result = v1.rollingWindow(Aggregation.mean(), options)) { + ColumnVector result = v1.rollingWindow(RollingAggregation.mean(), options)) { assertColumnsAreEqual(expected, result); } try (ColumnVector expected = ColumnVector.fromBoxedInts(4, 7, 6, 8, null); - ColumnVector result = v1.rollingWindow(Aggregation.lead(1), options)) { + ColumnVector result = v1.rollingWindow(RollingAggregation.lead(1), options)) { assertColumnsAreEqual(expected, result); } try (ColumnVector expected = ColumnVector.fromBoxedInts(null, 5, 4, 7, 6); - ColumnVector result = v1.rollingWindow(Aggregation.lag(1), options)) { + ColumnVector result = v1.rollingWindow(RollingAggregation.lag(1), options)) { assertColumnsAreEqual(expected, result); } try (ColumnVector defaultOutput = ColumnVector.fromInts(-1, -2, -3, -4, -5); ColumnVector expected = ColumnVector.fromBoxedInts(-1, 5, 4, 7, 6); - ColumnVector result = v1.rollingWindow(Aggregation.lag(1, defaultOutput), options)) { + ColumnVector result = v1.rollingWindow(RollingAggregation.lag(1, defaultOutput), 
options)) { assertColumnsAreEqual(expected, result); } } @@ -3106,11 +3098,11 @@ void testWindowStaticCounts() { .minPeriods(2).build()) { try (ColumnVector v1 = ColumnVector.fromBoxedInts(5, 4, null, 6, 8)) { try (ColumnVector expected = ColumnVector.fromInts(2, 2, 2, 2, 2); - ColumnVector result = v1.rollingWindow(Aggregation.count(NullPolicy.EXCLUDE), options)) { + ColumnVector result = v1.rollingWindow(RollingAggregation.count(NullPolicy.EXCLUDE), options)) { assertColumnsAreEqual(expected, result); } try (ColumnVector expected = ColumnVector.fromInts(2, 3, 3, 3, 2); - ColumnVector result = v1.rollingWindow(Aggregation.count(NullPolicy.INCLUDE), options)) { + ColumnVector result = v1.rollingWindow(RollingAggregation.count(NullPolicy.INCLUDE), options)) { assertColumnsAreEqual(expected, result); } } @@ -3125,7 +3117,7 @@ void testWindowDynamicNegative() { .minPeriods(2).window(precedingCol, followingCol).build()) { try (ColumnVector v1 = ColumnVector.fromInts(5, 4, 7, 6, 8); ColumnVector expected = ColumnVector.fromBoxedLongs(null, null, 9L, 16L, 25L); - ColumnVector result = v1.rollingWindow(Aggregation.sum(), window)) { + ColumnVector result = v1.rollingWindow(RollingAggregation.sum(), window)) { assertColumnsAreEqual(expected, result); } } @@ -3141,7 +3133,7 @@ void testWindowLag() { .window(two, negOne).build()) { try (ColumnVector v1 = ColumnVector.fromInts(5, 4, 7, 6, 8); ColumnVector expected = ColumnVector.fromBoxedInts(null, 5, 4, 7, 6); - ColumnVector result = v1.rollingWindow(Aggregation.max(), window)) { + ColumnVector result = v1.rollingWindow(RollingAggregation.max(), window)) { assertColumnsAreEqual(expected, result); } } @@ -3155,7 +3147,7 @@ void testWindowDynamic() { .window(precedingCol, followingCol).build()) { try (ColumnVector v1 = ColumnVector.fromInts(5, 4, 7, 6, 8); ColumnVector expected = ColumnVector.fromLongs(16, 22, 30, 14, 14); - ColumnVector result = v1.rollingWindow(Aggregation.sum(), window)) { + ColumnVector result = 
v1.rollingWindow(RollingAggregation.sum(), window)) { assertColumnsAreEqual(expected, result); } } @@ -3181,7 +3173,7 @@ void testWindowThrowsException() { .minPeriods(1) .orderByColumnIndex(0) .build()) { - arraywindowCol.rollingWindow(Aggregation.sum(), options); + arraywindowCol.rollingWindow(RollingAggregation.sum(), options); } }); } diff --git a/java/src/test/java/ai/rapids/cudf/ReductionTest.java b/java/src/test/java/ai/rapids/cudf/ReductionTest.java index 17b9ec3556f..2b26597c8f7 100644 --- a/java/src/test/java/ai/rapids/cudf/ReductionTest.java +++ b/java/src/test/java/ai/rapids/cudf/ReductionTest.java @@ -43,17 +43,17 @@ class ReductionTest extends CudfTestBase { Aggregation.Kind.ANY, Aggregation.Kind.ALL); - private static Scalar buildExpectedScalar(Aggregation op, DType baseType, Object expectedObject) { + private static Scalar buildExpectedScalar(ReductionAggregation op, DType baseType, Object expectedObject) { if (expectedObject == null) { return Scalar.fromNull(baseType); } - if (FLOAT_REDUCTIONS.contains(op.kind)) { + if (FLOAT_REDUCTIONS.contains(op.getWrapped().kind)) { if (baseType.equals(DType.FLOAT32)) { return Scalar.fromFloat((Float) expectedObject); } return Scalar.fromDouble((Double) expectedObject); } - if (BOOL_REDUCTIONS.contains(op.kind)) { + if (BOOL_REDUCTIONS.contains(op.getWrapped().kind)) { return Scalar.fromBool((Boolean) expectedObject); } switch (baseType.typeId) { @@ -88,165 +88,165 @@ private static Scalar buildExpectedScalar(Aggregation op, DType baseType, Object private static Stream createBooleanParams() { Boolean[] vals = new Boolean[]{true, true, null, false, true, false, null}; return Stream.of( - Arguments.of(Aggregation.sum(), new Boolean[0], null, 0.), - Arguments.of(Aggregation.sum(), new Boolean[]{null, null, null}, null, 0.), - Arguments.of(Aggregation.sum(), vals, true, 0.), - Arguments.of(Aggregation.min(), vals, false, 0.), - Arguments.of(Aggregation.max(), vals, true, 0.), - Arguments.of(Aggregation.product(), 
vals, false, 0.), - Arguments.of(Aggregation.sumOfSquares(), vals, true, 0.), - Arguments.of(Aggregation.mean(), vals, 0.6, DELTAD), - Arguments.of(Aggregation.standardDeviation(), vals, 0.5477225575051662, DELTAD), - Arguments.of(Aggregation.variance(), vals, 0.3, DELTAD), - Arguments.of(Aggregation.any(), vals, true, 0.), - Arguments.of(Aggregation.all(), vals, false, 0.) + Arguments.of(ReductionAggregation.sum(), new Boolean[0], null, 0.), + Arguments.of(ReductionAggregation.sum(), new Boolean[]{null, null, null}, null, 0.), + Arguments.of(ReductionAggregation.sum(), vals, true, 0.), + Arguments.of(ReductionAggregation.min(), vals, false, 0.), + Arguments.of(ReductionAggregation.max(), vals, true, 0.), + Arguments.of(ReductionAggregation.product(), vals, false, 0.), + Arguments.of(ReductionAggregation.sumOfSquares(), vals, true, 0.), + Arguments.of(ReductionAggregation.mean(), vals, 0.6, DELTAD), + Arguments.of(ReductionAggregation.standardDeviation(), vals, 0.5477225575051662, DELTAD), + Arguments.of(ReductionAggregation.variance(), vals, 0.3, DELTAD), + Arguments.of(ReductionAggregation.any(), vals, true, 0.), + Arguments.of(ReductionAggregation.all(), vals, false, 0.) 
); } private static Stream createByteParams() { Byte[] vals = new Byte[]{-1, 7, 123, null, 50, 60, 100}; return Stream.of( - Arguments.of(Aggregation.sum(), new Byte[0], null, 0.), - Arguments.of(Aggregation.sum(), new Byte[]{null, null, null}, null, 0.), - Arguments.of(Aggregation.sum(), vals, (byte) 83, 0.), - Arguments.of(Aggregation.min(), vals, (byte) -1, 0.), - Arguments.of(Aggregation.max(), vals, (byte) 123, 0.), - Arguments.of(Aggregation.product(), vals, (byte) 160, 0.), - Arguments.of(Aggregation.sumOfSquares(), vals, (byte) 47, 0.), - Arguments.of(Aggregation.mean(), vals, 56.5, DELTAD), - Arguments.of(Aggregation.standardDeviation(), vals, 49.24530434467839, DELTAD), - Arguments.of(Aggregation.variance(), vals, 2425.1, DELTAD), - Arguments.of(Aggregation.any(), vals, true, 0.), - Arguments.of(Aggregation.all(), vals, true, 0.) + Arguments.of(ReductionAggregation.sum(), new Byte[0], null, 0.), + Arguments.of(ReductionAggregation.sum(), new Byte[]{null, null, null}, null, 0.), + Arguments.of(ReductionAggregation.sum(), vals, (byte) 83, 0.), + Arguments.of(ReductionAggregation.min(), vals, (byte) -1, 0.), + Arguments.of(ReductionAggregation.max(), vals, (byte) 123, 0.), + Arguments.of(ReductionAggregation.product(), vals, (byte) 160, 0.), + Arguments.of(ReductionAggregation.sumOfSquares(), vals, (byte) 47, 0.), + Arguments.of(ReductionAggregation.mean(), vals, 56.5, DELTAD), + Arguments.of(ReductionAggregation.standardDeviation(), vals, 49.24530434467839, DELTAD), + Arguments.of(ReductionAggregation.variance(), vals, 2425.1, DELTAD), + Arguments.of(ReductionAggregation.any(), vals, true, 0.), + Arguments.of(ReductionAggregation.all(), vals, true, 0.) 
); } private static Stream createShortParams() { Short[] vals = new Short[]{-1, 7, 123, null, 50, 60, 100}; return Stream.of( - Arguments.of(Aggregation.sum(), new Short[0], null, 0.), - Arguments.of(Aggregation.sum(), new Short[]{null, null, null}, null, 0.), - Arguments.of(Aggregation.sum(), vals, (short) 339, 0.), - Arguments.of(Aggregation.min(), vals, (short) -1, 0.), - Arguments.of(Aggregation.max(), vals, (short) 123, 0.), - Arguments.of(Aggregation.product(), vals, (short) -22624, 0.), - Arguments.of(Aggregation.sumOfSquares(), vals, (short) 31279, 0.), - Arguments.of(Aggregation.mean(), vals, 56.5, DELTAD), - Arguments.of(Aggregation.standardDeviation(), vals, 49.24530434467839, DELTAD), - Arguments.of(Aggregation.variance(), vals, 2425.1, DELTAD), - Arguments.of(Aggregation.any(), vals, true, 0.), - Arguments.of(Aggregation.all(), vals, true, 0.) + Arguments.of(ReductionAggregation.sum(), new Short[0], null, 0.), + Arguments.of(ReductionAggregation.sum(), new Short[]{null, null, null}, null, 0.), + Arguments.of(ReductionAggregation.sum(), vals, (short) 339, 0.), + Arguments.of(ReductionAggregation.min(), vals, (short) -1, 0.), + Arguments.of(ReductionAggregation.max(), vals, (short) 123, 0.), + Arguments.of(ReductionAggregation.product(), vals, (short) -22624, 0.), + Arguments.of(ReductionAggregation.sumOfSquares(), vals, (short) 31279, 0.), + Arguments.of(ReductionAggregation.mean(), vals, 56.5, DELTAD), + Arguments.of(ReductionAggregation.standardDeviation(), vals, 49.24530434467839, DELTAD), + Arguments.of(ReductionAggregation.variance(), vals, 2425.1, DELTAD), + Arguments.of(ReductionAggregation.any(), vals, true, 0.), + Arguments.of(ReductionAggregation.all(), vals, true, 0.) 
); } private static Stream createIntParams() { Integer[] vals = new Integer[]{-1, 7, 123, null, 50, 60, 100}; return Stream.of( - Arguments.of(Aggregation.sum(), new Integer[0], null, 0.), - Arguments.of(Aggregation.sum(), new Integer[]{null, null, null}, null, 0.), - Arguments.of(Aggregation.sum(), vals, 339, 0.), - Arguments.of(Aggregation.min(), vals, -1, 0.), - Arguments.of(Aggregation.max(), vals, 123, 0.), - Arguments.of(Aggregation.product(), vals, -258300000, 0.), - Arguments.of(Aggregation.sumOfSquares(), vals, 31279, 0.), - Arguments.of(Aggregation.mean(), vals, 56.5, DELTAD), - Arguments.of(Aggregation.standardDeviation(), vals, 49.24530434467839, DELTAD), - Arguments.of(Aggregation.variance(), vals, 2425.1, DELTAD), - Arguments.of(Aggregation.any(), vals, true, 0.), - Arguments.of(Aggregation.all(), vals, true, 0.) + Arguments.of(ReductionAggregation.sum(), new Integer[0], null, 0.), + Arguments.of(ReductionAggregation.sum(), new Integer[]{null, null, null}, null, 0.), + Arguments.of(ReductionAggregation.sum(), vals, 339, 0.), + Arguments.of(ReductionAggregation.min(), vals, -1, 0.), + Arguments.of(ReductionAggregation.max(), vals, 123, 0.), + Arguments.of(ReductionAggregation.product(), vals, -258300000, 0.), + Arguments.of(ReductionAggregation.sumOfSquares(), vals, 31279, 0.), + Arguments.of(ReductionAggregation.mean(), vals, 56.5, DELTAD), + Arguments.of(ReductionAggregation.standardDeviation(), vals, 49.24530434467839, DELTAD), + Arguments.of(ReductionAggregation.variance(), vals, 2425.1, DELTAD), + Arguments.of(ReductionAggregation.any(), vals, true, 0.), + Arguments.of(ReductionAggregation.all(), vals, true, 0.) 
); } private static Stream createLongParams() { Long[] vals = new Long[]{-1L, 7L, 123L, null, 50L, 60L, 100L}; return Stream.of( - Arguments.of(Aggregation.sum(), new Long[0], null, 0.), - Arguments.of(Aggregation.sum(), new Long[]{null, null, null}, null, 0.), - Arguments.of(Aggregation.sum(), vals, 339L, 0.), - Arguments.of(Aggregation.min(), vals, -1L, 0.), - Arguments.of(Aggregation.max(), vals, 123L, 0.), - Arguments.of(Aggregation.product(), vals, -258300000L, 0.), - Arguments.of(Aggregation.sumOfSquares(), vals, 31279L, 0.), - Arguments.of(Aggregation.mean(), vals, 56.5, DELTAD), - Arguments.of(Aggregation.standardDeviation(), vals, 49.24530434467839, DELTAD), - Arguments.of(Aggregation.variance(), vals, 2425.1, DELTAD), - Arguments.of(Aggregation.any(), vals, true, 0.), - Arguments.of(Aggregation.all(), vals, true, 0.), - Arguments.of(Aggregation.quantile(0.5), vals, 55.0, DELTAD), - Arguments.of(Aggregation.quantile(0.9), vals, 111.5, DELTAD) + Arguments.of(ReductionAggregation.sum(), new Long[0], null, 0.), + Arguments.of(ReductionAggregation.sum(), new Long[]{null, null, null}, null, 0.), + Arguments.of(ReductionAggregation.sum(), vals, 339L, 0.), + Arguments.of(ReductionAggregation.min(), vals, -1L, 0.), + Arguments.of(ReductionAggregation.max(), vals, 123L, 0.), + Arguments.of(ReductionAggregation.product(), vals, -258300000L, 0.), + Arguments.of(ReductionAggregation.sumOfSquares(), vals, 31279L, 0.), + Arguments.of(ReductionAggregation.mean(), vals, 56.5, DELTAD), + Arguments.of(ReductionAggregation.standardDeviation(), vals, 49.24530434467839, DELTAD), + Arguments.of(ReductionAggregation.variance(), vals, 2425.1, DELTAD), + Arguments.of(ReductionAggregation.any(), vals, true, 0.), + Arguments.of(ReductionAggregation.all(), vals, true, 0.), + Arguments.of(ReductionAggregation.quantile(0.5), vals, 55.0, DELTAD), + Arguments.of(ReductionAggregation.quantile(0.9), vals, 111.5, DELTAD) ); } private static Stream createFloatParams() { Float[] vals = new 
Float[]{-1f, 7f, 123f, null, 50f, 60f, 100f}; return Stream.of( - Arguments.of(Aggregation.sum(), new Float[0], null, 0f), - Arguments.of(Aggregation.sum(), new Float[]{null, null, null}, null, 0f), - Arguments.of(Aggregation.sum(), vals, 339f, 0f), - Arguments.of(Aggregation.min(), vals, -1f, 0f), - Arguments.of(Aggregation.max(), vals, 123f, 0f), - Arguments.of(Aggregation.product(), vals, -258300000f, 0f), - Arguments.of(Aggregation.sumOfSquares(), vals, 31279f, 0f), - Arguments.of(Aggregation.mean(), vals, 56.5f, DELTAF), - Arguments.of(Aggregation.standardDeviation(), vals, 49.24530434467839f, DELTAF), - Arguments.of(Aggregation.variance(), vals, 2425.1f, DELTAF), - Arguments.of(Aggregation.any(), vals, true, 0f), - Arguments.of(Aggregation.all(), vals, true, 0f) + Arguments.of(ReductionAggregation.sum(), new Float[0], null, 0f), + Arguments.of(ReductionAggregation.sum(), new Float[]{null, null, null}, null, 0f), + Arguments.of(ReductionAggregation.sum(), vals, 339f, 0f), + Arguments.of(ReductionAggregation.min(), vals, -1f, 0f), + Arguments.of(ReductionAggregation.max(), vals, 123f, 0f), + Arguments.of(ReductionAggregation.product(), vals, -258300000f, 0f), + Arguments.of(ReductionAggregation.sumOfSquares(), vals, 31279f, 0f), + Arguments.of(ReductionAggregation.mean(), vals, 56.5f, DELTAF), + Arguments.of(ReductionAggregation.standardDeviation(), vals, 49.24530434467839f, DELTAF), + Arguments.of(ReductionAggregation.variance(), vals, 2425.1f, DELTAF), + Arguments.of(ReductionAggregation.any(), vals, true, 0f), + Arguments.of(ReductionAggregation.all(), vals, true, 0f) ); } private static Stream createDoubleParams() { Double[] vals = new Double[]{-1., 7., 123., null, 50., 60., 100.}; return Stream.of( - Arguments.of(Aggregation.sum(), new Double[0], null, 0.), - Arguments.of(Aggregation.sum(), new Double[]{null, null, null}, null, 0.), - Arguments.of(Aggregation.sum(), vals, 339., 0.), - Arguments.of(Aggregation.min(), vals, -1., 0.), - 
Arguments.of(Aggregation.max(), vals, 123., 0.), - Arguments.of(Aggregation.product(), vals, -258300000., 0.), - Arguments.of(Aggregation.sumOfSquares(), vals, 31279., 0.), - Arguments.of(Aggregation.mean(), vals, 56.5, DELTAD), - Arguments.of(Aggregation.standardDeviation(), vals, 49.24530434467839, DELTAD), - Arguments.of(Aggregation.variance(), vals, 2425.1, DELTAD), - Arguments.of(Aggregation.any(), vals, true, 0.), - Arguments.of(Aggregation.all(), vals, true, 0.), - Arguments.of(Aggregation.quantile(0.5), vals, 55.0, DELTAD), - Arguments.of(Aggregation.quantile(0.9), vals, 111.5, DELTAD) + Arguments.of(ReductionAggregation.sum(), new Double[0], null, 0.), + Arguments.of(ReductionAggregation.sum(), new Double[]{null, null, null}, null, 0.), + Arguments.of(ReductionAggregation.sum(), vals, 339., 0.), + Arguments.of(ReductionAggregation.min(), vals, -1., 0.), + Arguments.of(ReductionAggregation.max(), vals, 123., 0.), + Arguments.of(ReductionAggregation.product(), vals, -258300000., 0.), + Arguments.of(ReductionAggregation.sumOfSquares(), vals, 31279., 0.), + Arguments.of(ReductionAggregation.mean(), vals, 56.5, DELTAD), + Arguments.of(ReductionAggregation.standardDeviation(), vals, 49.24530434467839, DELTAD), + Arguments.of(ReductionAggregation.variance(), vals, 2425.1, DELTAD), + Arguments.of(ReductionAggregation.any(), vals, true, 0.), + Arguments.of(ReductionAggregation.all(), vals, true, 0.), + Arguments.of(ReductionAggregation.quantile(0.5), vals, 55.0, DELTAD), + Arguments.of(ReductionAggregation.quantile(0.9), vals, 111.5, DELTAD) ); } private static Stream createTimestampDaysParams() { Integer[] vals = new Integer[]{-1, 7, 123, null, 50, 60, 100}; return Stream.of( - Arguments.of(Aggregation.max(), new Integer[0], null), - Arguments.of(Aggregation.max(), new Integer[]{null, null, null}, null), - Arguments.of(Aggregation.max(), vals, 123), - Arguments.of(Aggregation.min(), vals, -1) + Arguments.of(ReductionAggregation.max(), new Integer[0], null), + 
Arguments.of(ReductionAggregation.max(), new Integer[]{null, null, null}, null), + Arguments.of(ReductionAggregation.max(), vals, 123), + Arguments.of(ReductionAggregation.min(), vals, -1) ); } private static Stream createTimestampResolutionParams() { Long[] vals = new Long[]{-1L, 7L, 123L, null, 50L, 60L, 100L}; return Stream.of( - Arguments.of(Aggregation.max(), new Long[0], null), - Arguments.of(Aggregation.max(), new Long[]{null, null, null}, null), - Arguments.of(Aggregation.min(), vals, -1L), - Arguments.of(Aggregation.max(), vals, 123L) + Arguments.of(ReductionAggregation.max(), new Long[0], null), + Arguments.of(ReductionAggregation.max(), new Long[]{null, null, null}, null), + Arguments.of(ReductionAggregation.min(), vals, -1L), + Arguments.of(ReductionAggregation.max(), vals, 123L) ); } - private static void assertEqualsDelta(Aggregation op, Scalar expected, Scalar result, + private static void assertEqualsDelta(ReductionAggregation op, Scalar expected, Scalar result, Double percentage) { - if (FLOAT_REDUCTIONS.contains(op.kind)) { + if (FLOAT_REDUCTIONS.contains(op.getWrapped().kind)) { assertEqualsWithinPercentage(expected.getDouble(), result.getDouble(), percentage); } else { assertEquals(expected, result); } } - private static void assertEqualsDelta(Aggregation op, Scalar expected, Scalar result, + private static void assertEqualsDelta(ReductionAggregation op, Scalar expected, Scalar result, Float percentage) { - if (FLOAT_REDUCTIONS.contains(op.kind)) { + if (FLOAT_REDUCTIONS.contains(op.getWrapped().kind)) { assertEqualsWithinPercentage(expected.getFloat(), result.getFloat(), percentage); } else { assertEquals(expected, result); @@ -255,7 +255,7 @@ private static void assertEqualsDelta(Aggregation op, Scalar expected, Scalar re @ParameterizedTest @MethodSource("createBooleanParams") - void testBoolean(Aggregation op, Boolean[] values, Object expectedObject, Double delta) { + void testBoolean(ReductionAggregation op, Boolean[] values, Object 
expectedObject, Double delta) { try (Scalar expected = buildExpectedScalar(op, DType.BOOL8, expectedObject); ColumnVector v = ColumnVector.fromBoxedBooleans(values); Scalar result = v.reduce(op, expected.getType())) { @@ -265,7 +265,7 @@ void testBoolean(Aggregation op, Boolean[] values, Object expectedObject, Double @ParameterizedTest @MethodSource("createByteParams") - void testByte(Aggregation op, Byte[] values, Object expectedObject, Double delta) { + void testByte(ReductionAggregation op, Byte[] values, Object expectedObject, Double delta) { try (Scalar expected = buildExpectedScalar(op, DType.INT8, expectedObject); ColumnVector v = ColumnVector.fromBoxedBytes(values); Scalar result = v.reduce(op, expected.getType())) { @@ -275,7 +275,7 @@ void testByte(Aggregation op, Byte[] values, Object expectedObject, Double delta @ParameterizedTest @MethodSource("createShortParams") - void testShort(Aggregation op, Short[] values, Object expectedObject, Double delta) { + void testShort(ReductionAggregation op, Short[] values, Object expectedObject, Double delta) { try (Scalar expected = buildExpectedScalar(op, DType.INT16, expectedObject); ColumnVector v = ColumnVector.fromBoxedShorts(values); Scalar result = v.reduce(op, expected.getType())) { @@ -285,7 +285,7 @@ void testShort(Aggregation op, Short[] values, Object expectedObject, Double del @ParameterizedTest @MethodSource("createIntParams") - void testInt(Aggregation op, Integer[] values, Object expectedObject, Double delta) { + void testInt(ReductionAggregation op, Integer[] values, Object expectedObject, Double delta) { try (Scalar expected = buildExpectedScalar(op, DType.INT32, expectedObject); ColumnVector v = ColumnVector.fromBoxedInts(values); Scalar result = v.reduce(op, expected.getType())) { @@ -295,7 +295,7 @@ void testInt(Aggregation op, Integer[] values, Object expectedObject, Double del @ParameterizedTest @MethodSource("createLongParams") - void testLong(Aggregation op, Long[] values, Object 
expectedObject, Double delta) { + void testLong(ReductionAggregation op, Long[] values, Object expectedObject, Double delta) { try (Scalar expected = buildExpectedScalar(op, DType.INT64, expectedObject); ColumnVector v = ColumnVector.fromBoxedLongs(values); Scalar result = v.reduce(op, expected.getType())) { @@ -305,7 +305,7 @@ void testLong(Aggregation op, Long[] values, Object expectedObject, Double delta @ParameterizedTest @MethodSource("createFloatParams") - void testFloat(Aggregation op, Float[] values, Object expectedObject, Float delta) { + void testFloat(ReductionAggregation op, Float[] values, Object expectedObject, Float delta) { try (Scalar expected = buildExpectedScalar(op, DType.FLOAT32, expectedObject); ColumnVector v = ColumnVector.fromBoxedFloats(values); Scalar result = v.reduce(op, expected.getType())) { @@ -315,7 +315,7 @@ void testFloat(Aggregation op, Float[] values, Object expectedObject, Float delt @ParameterizedTest @MethodSource("createDoubleParams") - void testDouble(Aggregation op, Double[] values, Object expectedObject, Double delta) { + void testDouble(ReductionAggregation op, Double[] values, Object expectedObject, Double delta) { try (Scalar expected = buildExpectedScalar(op, DType.FLOAT64, expectedObject); ColumnVector v = ColumnVector.fromBoxedDoubles(values); Scalar result = v.reduce(op, expected.getType())) { @@ -325,7 +325,7 @@ void testDouble(Aggregation op, Double[] values, Object expectedObject, Double d @ParameterizedTest @MethodSource("createTimestampDaysParams") - void testTimestampDays(Aggregation op, Integer[] values, Object expectedObject) { + void testTimestampDays(ReductionAggregation op, Integer[] values, Object expectedObject) { try (Scalar expected = buildExpectedScalar(op, DType.TIMESTAMP_DAYS, expectedObject); ColumnVector v = ColumnVector.timestampDaysFromBoxedInts(values); Scalar result = v.reduce(op, expected.getType())) { @@ -335,7 +335,7 @@ void testTimestampDays(Aggregation op, Integer[] values, Object 
expectedObject) @ParameterizedTest @MethodSource("createTimestampResolutionParams") - void testTimestampSeconds(Aggregation op, Long[] values, Object expectedObject) { + void testTimestampSeconds(ReductionAggregation op, Long[] values, Object expectedObject) { try (Scalar expected = buildExpectedScalar(op, DType.TIMESTAMP_SECONDS, expectedObject); ColumnVector v = ColumnVector.timestampSecondsFromBoxedLongs(values); Scalar result = v.reduce(op, expected.getType())) { @@ -345,7 +345,7 @@ void testTimestampSeconds(Aggregation op, Long[] values, Object expectedObject) @ParameterizedTest @MethodSource("createTimestampResolutionParams") - void testTimestampMilliseconds(Aggregation op, Long[] values, Object expectedObject) { + void testTimestampMilliseconds(ReductionAggregation op, Long[] values, Object expectedObject) { try (Scalar expected = buildExpectedScalar(op, DType.TIMESTAMP_MILLISECONDS, expectedObject); ColumnVector v = ColumnVector.timestampMilliSecondsFromBoxedLongs(values); Scalar result = v.reduce(op, expected.getType())) { @@ -355,7 +355,7 @@ void testTimestampMilliseconds(Aggregation op, Long[] values, Object expectedObj @ParameterizedTest @MethodSource("createTimestampResolutionParams") - void testTimestampMicroseconds(Aggregation op, Long[] values, Object expectedObject) { + void testTimestampMicroseconds(ReductionAggregation op, Long[] values, Object expectedObject) { try (Scalar expected = buildExpectedScalar(op, DType.TIMESTAMP_MICROSECONDS, expectedObject); ColumnVector v = ColumnVector.timestampMicroSecondsFromBoxedLongs(values); Scalar result = v.reduce(op, expected.getType())) { @@ -365,7 +365,7 @@ void testTimestampMicroseconds(Aggregation op, Long[] values, Object expectedObj @ParameterizedTest @MethodSource("createTimestampResolutionParams") - void testTimestampNanoseconds(Aggregation op, Long[] values, Object expectedObject) { + void testTimestampNanoseconds(ReductionAggregation op, Long[] values, Object expectedObject) { try (Scalar expected 
= buildExpectedScalar(op, DType.TIMESTAMP_NANOSECONDS, expectedObject); ColumnVector v = ColumnVector.timestampNanoSecondsFromBoxedLongs(values); Scalar result = v.reduce(op, expected.getType())) { diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index 360f3c04f5b..1b2ed1ad0b8 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -2902,12 +2902,12 @@ void testGroupByScan() { .withKeysSorted(true) .withKeysDescending(false, false) .build(), 0, 1) - .scan(Aggregation.sum().onColumn(2), - Aggregation.count(NullPolicy.INCLUDE).onColumn(2), - Aggregation.min().onColumn(2), - Aggregation.max().onColumn(2), - Aggregation.rank().onColumn(3), - Aggregation.denseRank().onColumn(3)); + .scan(GroupByScanAggregation.sum().onColumn(2), + GroupByScanAggregation.count(NullPolicy.INCLUDE).onColumn(2), + GroupByScanAggregation.min().onColumn(2), + GroupByScanAggregation.max().onColumn(2), + GroupByScanAggregation.rank().onColumn(3), + GroupByScanAggregation.denseRank().onColumn(3)); Table expected = new Table.TestBuilder() .column( "1", "1", "1", "1", "1", "1", "1", "2", "2", "2", "2") .column( 0, 1, 3, 3, 5, 5, 5, 5, 5, 5, 5) @@ -2957,7 +2957,7 @@ void testGroupByUniqueCount() { .build()) { try (Table t3 = t1 .groupBy(0, 1) - .aggregate(Aggregation.nunique().onColumn(0)); + .aggregate(GroupByAggregation.nunique().onColumn(0)); Table sorted = t3.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(2)); Table expected = new Table.TestBuilder() .column( "1", "1", "1", "1") @@ -2978,7 +2978,7 @@ void testGroupByUniqueCountNulls() { .build()) { try (Table t3 = t1 .groupBy(0, 1) - .aggregate(Aggregation.nunique(NullPolicy.INCLUDE).onColumn(0)); + .aggregate(GroupByAggregation.nunique(NullPolicy.INCLUDE).onColumn(0)); Table sorted = t3.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(2)); Table expected = new Table.TestBuilder() .column( 
"1", "1", "1", "1") @@ -2997,7 +2997,7 @@ void testGroupByCount() { .column(12.0, 14.0, 13.0, 17.0, 17.0, 17.0) .build()) { try (Table t3 = t1.groupBy(0, 1) - .aggregate(Aggregation.count().onColumn(0)); + .aggregate(GroupByAggregation.count().onColumn(0)); HostColumnVector aggOut1 = t3.getColumn(2).copyToHost()) { // verify t3 assertEquals(4, t3.getRowCount()); @@ -3048,9 +3048,9 @@ void testWindowingCount() { .build()) { try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation.count().onColumn(3).overWindow(window)); + .aggregateWindows(RollingAggregation.count().onColumn(3).overWindow(window)); Table decWindowAggResults = decSorted.groupBy(0, 4) - .aggregateWindows(Aggregation.count().onColumn(3).overWindow(window)); + .aggregateWindows(RollingAggregation.count().onColumn(3).overWindow(window)); ColumnVector expect = ColumnVector.fromBoxedInts(2, 3, 3, 2, 2, 3, 3, 2, 2, 3, 3, 2)) { assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); assertColumnsAreEqual(expect, decWindowAggResults.getColumn(0)); @@ -3088,9 +3088,9 @@ void testWindowingMin() { .build()) { try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation.min().onColumn(3).overWindow(window)); + .aggregateWindows(RollingAggregation.min().onColumn(3).overWindow(window)); Table decWindowAggResults = decSorted.groupBy(0, 4) - .aggregateWindows(Aggregation.min().onColumn(6).overWindow(window)); + .aggregateWindows(RollingAggregation.min().onColumn(6).overWindow(window)); ColumnVector expect = ColumnVector.fromBoxedInts(5, 1, 1, 1, 7, 7, 2, 2, 0, 0, 0, 6); ColumnVector decExpect = ColumnVector.decimalFromLongs(2, 5, 1, 1, 1, 7, 7, 2, 2, 0, 0, 0, 6)) { assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); @@ -3129,9 +3129,9 @@ void testWindowingMax() { .build()) { try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation.max().onColumn(3).overWindow(window)); + 
.aggregateWindows(RollingAggregation.max().onColumn(3).overWindow(window)); Table decWindowAggResults = decSorted.groupBy(0, 4) - .aggregateWindows(Aggregation.max().onColumn(6).overWindow(window)); + .aggregateWindows(RollingAggregation.max().onColumn(6).overWindow(window)); ColumnVector expect = ColumnVector.fromBoxedInts(7, 7, 9, 9, 9, 9, 9, 8, 8, 8, 6, 6); ColumnVector decExpect = ColumnVector.decimalFromLongs(2, 7, 7, 9, 9, 9, 9, 9, 8, 8, 8, 6, 6)) { assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); @@ -3163,7 +3163,7 @@ void testWindowingSum() { .build()) { try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation.sum().onColumn(3).overWindow(window)); + .aggregateWindows(RollingAggregation.sum().onColumn(3).overWindow(window)); ColumnVector expectAggResult = ColumnVector.fromBoxedLongs(12L, 13L, 15L, 10L, 16L, 24L, 19L, 10L, 8L, 14L, 12L, 12L)) { assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); } @@ -3199,12 +3199,12 @@ void testWindowingRowNumber() { WindowOptions options = windowBuilder.window(two, one).build(); WindowOptions options1 = windowBuilder.window(two, one).build()) { try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation + .aggregateWindows(RollingAggregation .rowNumber() .onColumn(3) .overWindow(options)); Table decWindowAggResults = decSorted.groupBy(0, 4) - .aggregateWindows(Aggregation + .aggregateWindows(RollingAggregation .rowNumber() .onColumn(6) .overWindow(options1)); @@ -3219,12 +3219,12 @@ void testWindowingRowNumber() { WindowOptions options = windowBuilder.window(three, two).build(); WindowOptions options1 = windowBuilder.window(three, two).build()) { try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation + .aggregateWindows(RollingAggregation .rowNumber() .onColumn(3) .overWindow(options)); Table decWindowAggResults = decSorted.groupBy(0, 4) - .aggregateWindows(Aggregation + .aggregateWindows(RollingAggregation 
.rowNumber() .onColumn(6) .overWindow(options1)); @@ -3239,12 +3239,12 @@ void testWindowingRowNumber() { WindowOptions options = windowBuilder.window(four, three).build(); WindowOptions options1 = windowBuilder.window(four, three).build()) { try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation + .aggregateWindows(RollingAggregation .rowNumber() .onColumn(3) .overWindow(options)); Table decWindowAggResults = decSorted.groupBy(0, 4) - .aggregateWindows(Aggregation + .aggregateWindows(RollingAggregation .rowNumber() .onColumn(6) .overWindow(options1)); @@ -3259,8 +3259,8 @@ void testWindowingRowNumber() { @Test void testWindowingCollectList() { - Aggregation aggCollectWithNulls = Aggregation.collectList(NullPolicy.INCLUDE); - Aggregation aggCollect = Aggregation.collectList(); + RollingAggregation aggCollectWithNulls = RollingAggregation.collectList(NullPolicy.INCLUDE); + RollingAggregation aggCollect = RollingAggregation.collectList(); try (Scalar two = Scalar.fromInt(2); Scalar one = Scalar.fromInt(1); WindowOptions winOpts = WindowOptions.builder() @@ -3335,12 +3335,12 @@ void testWindowingCollectList() { @Test void testWindowingCollectSet() { - Aggregation aggCollect = Aggregation.collectSet(); - Aggregation aggCollectWithEqNulls = Aggregation.collectSet(NullPolicy.INCLUDE, + RollingAggregation aggCollect = RollingAggregation.collectSet(); + RollingAggregation aggCollectWithEqNulls = RollingAggregation.collectSet(NullPolicy.INCLUDE, NullEquality.EQUAL, NaNEquality.UNEQUAL); - Aggregation aggCollectWithUnEqNulls = Aggregation.collectSet(NullPolicy.INCLUDE, + RollingAggregation aggCollectWithUnEqNulls = RollingAggregation.collectSet(NullPolicy.INCLUDE, NullEquality.UNEQUAL, NaNEquality.UNEQUAL); - Aggregation aggCollectWithEqNaNs = Aggregation.collectSet(NullPolicy.INCLUDE, + RollingAggregation aggCollectWithEqNaNs = RollingAggregation.collectSet(NullPolicy.INCLUDE, NullEquality.EQUAL, NaNEquality.ALL_EQUAL); try (Scalar two = 
Scalar.fromInt(2); @@ -3473,22 +3473,22 @@ void testWindowingLead() { Scalar one = Scalar.fromInt(1); WindowOptions options = windowBuilder.window(two, one).build(); Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation + .aggregateWindows(RollingAggregation .lead(0) .onColumn(3) // Int Agg Column .overWindow(options)); Table decWindowAggResults = decSorted.groupBy(0, 4) - .aggregateWindows(Aggregation + .aggregateWindows(RollingAggregation .lead(0) .onColumn(6) // Decimal Agg Column .overWindow(options)); Table listWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( - Aggregation + RollingAggregation .lead(0) .onColumn(7) // List Agg COLUMN .overWindow(options)); Table structWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( - Aggregation + RollingAggregation .lead(0) .onColumn(8) //STRUCT Agg COLUMN .overWindow(options)); @@ -3517,22 +3517,22 @@ void testWindowingLead() { Scalar one = Scalar.fromInt(1); WindowOptions options = windowBuilder.window(zero, one).build(); Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation + .aggregateWindows(RollingAggregation .lead(1) .onColumn(3) //Int Agg COLUMN .overWindow(options)); Table decWindowAggResults = sorted.groupBy(0, 4) - .aggregateWindows(Aggregation + .aggregateWindows(RollingAggregation .lead(1) .onColumn(6) //Decimal Agg COLUMN .overWindow(options)); Table listWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( - Aggregation + RollingAggregation .lead(1) .onColumn(7) //LIST Agg COLUMN .overWindow(options)); Table structWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( - Aggregation + RollingAggregation .lead(1) .onColumn(8) //STRUCT Agg COLUMN .overWindow(options)); @@ -3575,22 +3575,22 @@ null, new StructData(13, "s13"), new StructData(14, "s14"), null, new StructData(-111, "s111"), new StructData(null, "s112"), new StructData(-222, "s222"), new StructData(-333, "s333")); Table windowAggResults = sorted.groupBy(0, 1) - 
.aggregateWindows(Aggregation + .aggregateWindows(RollingAggregation .lead(1, defaultOutput) .onColumn(3) //Int Agg COLUMN .overWindow(options)); Table decWindowAggResults = sorted.groupBy(0, 4) - .aggregateWindows(Aggregation + .aggregateWindows(RollingAggregation .lead(1, decDefaultOutput) .onColumn(6) //Decimal Agg COLUMN .overWindow(options)); Table listWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( - Aggregation + RollingAggregation .lead(1, listDefaultOutput) .onColumn(7) //LIST Agg COLUMN .overWindow(options)); Table structWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( - Aggregation + RollingAggregation .lead(1, structDefaultOutput) .onColumn(8) //STRUCT Agg COLUMN .overWindow(options)); @@ -3619,22 +3619,22 @@ null, new StructData(13, "s13"), new StructData(14, "s14"), new StructData(-14, Scalar one = Scalar.fromInt(1); WindowOptions options = windowBuilder.window(zero, one).build(); Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation + .aggregateWindows(RollingAggregation .lead(3) .onColumn(3) //Int Agg COLUMN .overWindow(options)); Table decWindowAggResults = sorted.groupBy(0, 4) - .aggregateWindows(Aggregation + .aggregateWindows(RollingAggregation .lead(3) .onColumn(6) //Decimal Agg COLUMN .overWindow(options)); Table listWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( - Aggregation + RollingAggregation .lead(3) .onColumn(7) //LIST Agg COLUMN .overWindow(options)); Table structWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( - Aggregation + RollingAggregation .lead(3) .onColumn(8) //STRUCT Agg COLUMN .overWindow(options)); @@ -3694,22 +3694,22 @@ void testWindowingLag() { Scalar one = Scalar.fromInt(1); WindowOptions options = windowBuilder.window(two, one).build(); Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation + .aggregateWindows(RollingAggregation .lag(0) .onColumn(3) //Int Agg COLUMN .overWindow(options)); Table decWindowAggResults = sorted.groupBy(0, 
4) - .aggregateWindows(Aggregation + .aggregateWindows(RollingAggregation .lag(0) .onColumn(6) //Decimal Agg COLUMN .overWindow(options)); Table listWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( - Aggregation + RollingAggregation .lag(0) .onColumn(7) //LIST Agg COLUMN .overWindow(options)); Table structWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( - Aggregation + RollingAggregation .lag(0) .onColumn(8) //STRUCT Agg COLUMN .overWindow(options)); @@ -3737,22 +3737,22 @@ void testWindowingLag() { Scalar two = Scalar.fromInt(2); WindowOptions options = windowBuilder.window(two, zero).build(); Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation + .aggregateWindows(RollingAggregation .lag(1) .onColumn(3) //Int Agg COLUMN .overWindow(options)); Table decWindowAggResults = sorted.groupBy(0, 4) - .aggregateWindows(Aggregation + .aggregateWindows(RollingAggregation .lag(1) .onColumn(6) //Decimal Agg COLUMN .overWindow(options)); Table listWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( - Aggregation + RollingAggregation .lag(1) .onColumn(7) //LIST Agg COLUMN .overWindow(options)); Table structWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( - Aggregation + RollingAggregation .lag(1) .onColumn(8) //STRUCT Agg COLUMN .overWindow(options)); @@ -3794,22 +3794,22 @@ null, new StructData(111, "s111"), new StructData(null, "s112"), new StructData( new StructData(-11, "s11"), null, new StructData(-13, "s13"), new StructData(-14, "s14"), new StructData(-111, "s111"), new StructData(null, "s112"), new StructData(-222, "s222"), new StructData(-333, "s333")); Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation + .aggregateWindows(RollingAggregation .lag(1, defaultOutput) .onColumn(3) //Int Agg COLUMN .overWindow(options)); Table decWindowAggResults = sorted.groupBy(0, 4) - .aggregateWindows(Aggregation + .aggregateWindows(RollingAggregation .lag(1, decDefaultOutput) .onColumn(6) //Decimal Agg 
COLUMN .overWindow(options)); Table listWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( - Aggregation + RollingAggregation .lag(1, listDefaultOutput) .onColumn(7) //LIST Agg COLUMN .overWindow(options)); Table structWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( - Aggregation + RollingAggregation .lag(1, structDefaultOutput) .onColumn(8) //STRUCT Agg COLUMN .overWindow(options)); @@ -3838,22 +3838,22 @@ null, new StructData(111, "s111"), new StructData(null, "s112"), new StructData( Scalar one = Scalar.fromInt(1); WindowOptions options = windowBuilder.window(one, zero).build(); Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation + .aggregateWindows(RollingAggregation .lag(3) .onColumn(3) //Int Agg COLUMN .overWindow(options)); Table decWindowAggResults = sorted.groupBy(0, 4) - .aggregateWindows(Aggregation + .aggregateWindows(RollingAggregation .lag(3) .onColumn(6) //Decimal Agg COLUMN .overWindow(options)); Table listWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( - Aggregation + RollingAggregation .lag(3) .onColumn(7) //LIST Agg COLUMN .overWindow(options)); Table structWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( - Aggregation + RollingAggregation .lag(3) .onColumn(8) //STRUCT Agg COLUMN .overWindow(options)); @@ -3896,7 +3896,7 @@ void testWindowingMean() { .build()) { try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation.mean().onColumn(3).overWindow(window)); + .aggregateWindows(RollingAggregation.mean().onColumn(3).overWindow(window)); ColumnVector expect = ColumnVector.fromBoxedDoubles(6.0d, 5.0d, 5.0d, 5.0d, 8.0d, 8.0d, 7.0d, 6.0d, 4.0d, 4.0d, 4.0d, 6.0d)) { assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); } @@ -3941,10 +3941,10 @@ void testWindowingOnMultipleDifferentColumns() { try (Table windowAggResults = sorted.groupBy(0, 1) .aggregateWindows( - Aggregation.sum().onColumn(3).overWindow(window_1), - 
Aggregation.max().onColumn(3).overWindow(window_1), - Aggregation.sum().onColumn(3).overWindow(window_2), - Aggregation.min().onColumn(2).overWindow(window_3) + RollingAggregation.sum().onColumn(3).overWindow(window_1), + RollingAggregation.max().onColumn(3).overWindow(window_1), + RollingAggregation.sum().onColumn(3).overWindow(window_2), + RollingAggregation.min().onColumn(2).overWindow(window_3) ); ColumnVector expect_0 = ColumnVector.fromBoxedLongs(12L, 13L, 15L, 10L, 16L, 24L, 19L, 10L, 8L, 14L, 12L, 12L); ColumnVector expect_1 = ColumnVector.fromBoxedInts(7, 7, 9, 9, 9, 9, 9, 8, 8, 8, 6, 6); @@ -3979,8 +3979,8 @@ void testWindowingWithoutGroupByColumns() { .build()) { try (Table windowAggResults = sorted.groupBy().aggregateWindows( - Aggregation.sum().onColumn(1).overWindow(window)); - ColumnVector expectAggResult = ColumnVector.fromBoxedLongs(12L, 13L, 15L, 17L, 25L, 24L, 19L, 18L, 10L, 14L, 12L, 12L); + RollingAggregation.sum().onColumn(1).overWindow(window)); + ColumnVector expectAggResult = ColumnVector.fromBoxedLongs(12L, 13L, 15L, 17L, 25L, 24L, 19L, 18L, 10L, 14L, 12L, 12L) ) { assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); } @@ -4054,7 +4054,7 @@ void testRangeWindowingCount() { .orderByColumnIndex(orderIndex) .build()) { try (Table windowAggResults = sorted.groupBy(0, 1).aggregateWindowsOverRanges( - Aggregation.count().onColumn(2).overWindow(window)); + RollingAggregation.count().onColumn(2).overWindow(window)); ColumnVector expect = ColumnVector.fromBoxedInts(3, 3, 4, 2, 4, 4, 4, 4, 4, 4, 5, 5, 3)) { assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); } @@ -4098,7 +4098,7 @@ void testRangeWindowingLead() { .build()) { try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindowsOverRanges(Aggregation.lead(1) + .aggregateWindowsOverRanges(RollingAggregation.lead(1) .onColumn(2) .overWindow(window)); ColumnVector expect = ColumnVector.fromBoxedInts(5, 1, 9, null, 9, 8, 2, null, 0, 6, 6, 8, null)) { @@ 
-4144,7 +4144,7 @@ void testRangeWindowingMax() { .build()) { try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindowsOverRanges(Aggregation.max().onColumn(2).overWindow(window)); + .aggregateWindowsOverRanges(RollingAggregation.max().onColumn(2).overWindow(window)); ColumnVector expect = ColumnVector.fromBoxedInts(7, 7, 9, 9, 9, 9, 9, 9, 8, 8, 8, 8, 8)) { assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); } @@ -4158,7 +4158,7 @@ void testRangeWindowingMax() { .build()) { try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation.max().onColumn(2).overWindow(window)); + .aggregateWindows(RollingAggregation.max().onColumn(2).overWindow(window)); ColumnVector expect = ColumnVector.fromBoxedInts(7, 7, 9, 9, 9, 9, 9, 8, 8, 8, 6, 8, 8)) { assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); } @@ -4202,7 +4202,7 @@ void testRangeWindowingRowNumber() { .build()) { try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindowsOverRanges(Aggregation.rowNumber().onColumn(2).overWindow(window)); + .aggregateWindowsOverRanges(RollingAggregation.rowNumber().onColumn(2).overWindow(window)); ColumnVector expect = ColumnVector.fromBoxedInts(1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 5)) { assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); } @@ -4254,12 +4254,12 @@ void testRangeWindowingCountDescendingTimestamps() { .window(preceding_1, following_1) .orderByColumnIndex(orderIndex) .orderByDescending() - .build();) { + .build()) { try (Table windowAggResults = sorted.groupBy(0, 1) .aggregateWindowsOverRanges( - Aggregation.count().onColumn(2).overWindow(window_0), - Aggregation.sum().onColumn(2).overWindow(window_1)); + RollingAggregation.count().onColumn(2).overWindow(window_0), + RollingAggregation.sum().onColumn(2).overWindow(window_1)); ColumnVector expect_0 = ColumnVector.fromBoxedInts(3, 4, 4, 4, 3, 4, 4, 4, 3, 3, 5, 5, 5); ColumnVector expect_1 = ColumnVector.fromBoxedLongs(7L, 13L, 13L, 22L, 7L, 24L, 24L, 
26L, 8L, 8L, 14L, 28L, 28L)) { assertColumnsAreEqual(expect_0, windowAggResults.getColumn(0)); @@ -4303,7 +4303,7 @@ void testRangeWindowingWithoutGroupByColumns() { .build();) { try (Table windowAggResults = sorted.groupBy() - .aggregateWindowsOverRanges(Aggregation.count().onColumn(1).overWindow(window)); + .aggregateWindowsOverRanges(RollingAggregation.count().onColumn(1).overWindow(window)); ColumnVector expect = ColumnVector.fromBoxedInts(3, 3, 6, 6, 6, 6, 7, 7, 6, 6, 5, 5, 3)) { assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); } @@ -4333,7 +4333,7 @@ void testRangeWindowingOrderByUnsupportedDataTypeExceptions() { assertThrows(IllegalArgumentException.class, () -> table .groupBy(0, 1) - .aggregateWindowsOverRanges(Aggregation.max().onColumn(2).overWindow(rangeBasedWindow))); + .aggregateWindowsOverRanges(RollingAggregation.max().onColumn(2).overWindow(rangeBasedWindow))); } } } @@ -4353,7 +4353,7 @@ void testInvalidWindowTypeExceptions() { .minPeriods(1) .window(one, one) .build()) { - assertThrows(IllegalArgumentException.class, () -> table.groupBy(0, 1).aggregateWindowsOverRanges(Aggregation.max().onColumn(3).overWindow(rowBasedWindow))); + assertThrows(IllegalArgumentException.class, () -> table.groupBy(0, 1).aggregateWindowsOverRanges(RollingAggregation.max().onColumn(3).overWindow(rowBasedWindow))); } try (WindowOptions rangeBasedWindow = WindowOptions.builder() @@ -4361,7 +4361,7 @@ void testInvalidWindowTypeExceptions() { .window(one, one) .orderByColumnIndex(2) .build()) { - assertThrows(IllegalArgumentException.class, () -> table.groupBy(0, 1).aggregateWindows(Aggregation.max().onColumn(3).overWindow(rangeBasedWindow))); + assertThrows(IllegalArgumentException.class, () -> table.groupBy(0, 1).aggregateWindows(RollingAggregation.max().onColumn(3).overWindow(rangeBasedWindow))); } } } @@ -4399,7 +4399,7 @@ void testRangeWindowingCountUnboundedPreceding() { .build();) { try (Table windowAggResults = sorted.groupBy(0, 1) - 
.aggregateWindowsOverRanges(Aggregation.count().onColumn(2).overWindow(window)); + .aggregateWindowsOverRanges(RollingAggregation.count().onColumn(2).overWindow(window)); ColumnVector expect = ColumnVector.fromBoxedInts(3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5)) { assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); } @@ -4475,11 +4475,11 @@ void testRangeWindowingCountUnboundedASCWithNullsFirst() { try (Table windowAggResults = sorted.groupBy(0, 1) .aggregateWindowsOverRanges( - Aggregation.count().onColumn(2).overWindow(unboundedPrecedingOneFollowing), - Aggregation.count().onColumn(2).overWindow(onePrecedingUnboundedFollowing), - Aggregation.count().onColumn(2).overWindow(unboundedPrecedingAndFollowing), - Aggregation.count().onColumn(2).overWindow(unboundedPrecedingAndCurrentRow), - Aggregation.count().onColumn(2).overWindow(currentRowAndUnboundedFollowing)); + RollingAggregation.count().onColumn(2).overWindow(unboundedPrecedingOneFollowing), + RollingAggregation.count().onColumn(2).overWindow(onePrecedingUnboundedFollowing), + RollingAggregation.count().onColumn(2).overWindow(unboundedPrecedingAndFollowing), + RollingAggregation.count().onColumn(2).overWindow(unboundedPrecedingAndCurrentRow), + RollingAggregation.count().onColumn(2).overWindow(currentRowAndUnboundedFollowing)); ColumnVector expect_0 = ColumnVector.fromBoxedInts(3, 3, 3, 5, 5, 6, 2, 2, 4, 4, 6, 6, 7); ColumnVector expect_1 = ColumnVector.fromBoxedInts(6, 6, 6, 3, 3, 1, 7, 7, 5, 5, 3, 3, 1); ColumnVector expect_2 = ColumnVector.fromBoxedInts(6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7); @@ -4570,11 +4570,11 @@ void testRangeWindowingCountUnboundedDESCWithNullsFirst() { try (Table windowAggResults = sorted.groupBy(0, 1) .aggregateWindowsOverRanges( - Aggregation.count().onColumn(2).overWindow(unboundedPrecedingOneFollowing), - Aggregation.count().onColumn(2).overWindow(onePrecedingUnboundedFollowing), - Aggregation.count().onColumn(2).overWindow(unboundedPrecedingAndFollowing), - 
Aggregation.count().onColumn(2).overWindow(unboundedPrecedingAndCurrentRow), - Aggregation.count().onColumn(2).overWindow(currentRowAndUnboundedFollowing)); + RollingAggregation.count().onColumn(2).overWindow(unboundedPrecedingOneFollowing), + RollingAggregation.count().onColumn(2).overWindow(onePrecedingUnboundedFollowing), + RollingAggregation.count().onColumn(2).overWindow(unboundedPrecedingAndFollowing), + RollingAggregation.count().onColumn(2).overWindow(unboundedPrecedingAndCurrentRow), + RollingAggregation.count().onColumn(2).overWindow(currentRowAndUnboundedFollowing)); ColumnVector expect_0 = ColumnVector.fromBoxedInts(3, 3, 3, 4, 6, 6, 2, 2, 3, 5, 5, 7, 7); ColumnVector expect_1 = ColumnVector.fromBoxedInts(6, 6, 6, 3, 2, 2, 7, 7, 5, 4, 4, 2, 2); ColumnVector expect_2 = ColumnVector.fromBoxedInts(6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7); @@ -4658,11 +4658,11 @@ void testRangeWindowingCountUnboundedASCWithNullsLast() { try (Table windowAggResults = sorted.groupBy(0, 1) .aggregateWindowsOverRanges( - Aggregation.count().onColumn(2).overWindow(unboundedPrecedingOneFollowing), - Aggregation.count().onColumn(2).overWindow(onePrecedingUnboundedFollowing), - Aggregation.count().onColumn(2).overWindow(unboundedPrecedingAndFollowing), - Aggregation.count().onColumn(2).overWindow(unboundedPrecedingAndCurrentRow), - Aggregation.count().onColumn(2).overWindow(currentRowAndUnboundedFollowing)); + RollingAggregation.count().onColumn(2).overWindow(unboundedPrecedingOneFollowing), + RollingAggregation.count().onColumn(2).overWindow(onePrecedingUnboundedFollowing), + RollingAggregation.count().onColumn(2).overWindow(unboundedPrecedingAndFollowing), + RollingAggregation.count().onColumn(2).overWindow(unboundedPrecedingAndCurrentRow), + RollingAggregation.count().onColumn(2).overWindow(currentRowAndUnboundedFollowing)); ColumnVector expect_0 = ColumnVector.fromBoxedInts(2, 2, 3, 6, 6, 6, 2, 2, 4, 4, 5, 7, 7); ColumnVector expect_1 = ColumnVector.fromBoxedInts(6, 6, 4, 3, 3, 
3, 7, 7, 5, 5, 3, 2, 2); ColumnVector expect_2 = ColumnVector.fromBoxedInts(6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7); @@ -4752,11 +4752,11 @@ void testRangeWindowingCountUnboundedDESCWithNullsLast() { try (Table windowAggResults = sorted.groupBy(0, 1) .aggregateWindowsOverRanges( - Aggregation.count().onColumn(2).overWindow(unboundedPrecedingOneFollowing), - Aggregation.count().onColumn(2).overWindow(onePrecedingUnboundedFollowing), - Aggregation.count().onColumn(2).overWindow(unboundedPrecedingAndFollowing), - Aggregation.count().onColumn(2).overWindow(unboundedPrecedingAndCurrentRow), - Aggregation.count().onColumn(2).overWindow(currentRowAndUnboundedFollowing)); + RollingAggregation.count().onColumn(2).overWindow(unboundedPrecedingOneFollowing), + RollingAggregation.count().onColumn(2).overWindow(onePrecedingUnboundedFollowing), + RollingAggregation.count().onColumn(2).overWindow(unboundedPrecedingAndFollowing), + RollingAggregation.count().onColumn(2).overWindow(unboundedPrecedingAndCurrentRow), + RollingAggregation.count().onColumn(2).overWindow(currentRowAndUnboundedFollowing)); ColumnVector expect_0 = ColumnVector.fromBoxedInts(1, 3, 3, 6, 6, 6, 1, 3, 3, 5, 5, 7, 7); ColumnVector expect_1 = ColumnVector.fromBoxedInts(6, 5, 5, 3, 3, 3, 7, 6, 6, 4, 4, 2, 2); ColumnVector expect_2 = ColumnVector.fromBoxedInts(6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7); @@ -4784,9 +4784,9 @@ void testGroupByCountWithNulls() { .column( 1, 1, 1, null, 1, 1) .build()) { try (Table tmp = t1.groupBy(0).aggregate( - Aggregation.count().onColumn(1), - Aggregation.count().onColumn(2), - Aggregation.count().onColumn(3)); + GroupByAggregation.count().onColumn(1), + GroupByAggregation.count().onColumn(2), + GroupByAggregation.count().onColumn(3)); Table t3 = tmp.orderBy(OrderByArg.asc(0, true)); HostColumnVector groupCol = t3.getColumn(0).copyToHost(); HostColumnVector countCol = t3.getColumn(1).copyToHost(); @@ -4824,10 +4824,10 @@ void testGroupByCountWithNullsIncluded() { .column( 1, 1, 1, 
null, 1, 1) .build()) { try (Table tmp = t1.groupBy(0).aggregate( - Aggregation.count(NullPolicy.INCLUDE).onColumn(1), - Aggregation.count(NullPolicy.INCLUDE).onColumn(2), - Aggregation.count(NullPolicy.INCLUDE).onColumn(3), - Aggregation.count().onColumn(3)); + GroupByAggregation.count(NullPolicy.INCLUDE).onColumn(1), + GroupByAggregation.count(NullPolicy.INCLUDE).onColumn(2), + GroupByAggregation.count(NullPolicy.INCLUDE).onColumn(3), + GroupByAggregation.count().onColumn(3)); Table t3 = tmp.orderBy(OrderByArg.asc(0, true)); HostColumnVector groupCol = t3.getColumn(0).copyToHost(); HostColumnVector countCol = t3.getColumn(1).copyToHost(); @@ -4875,9 +4875,9 @@ void testGroupByCountWithCollapsingNulls() { .build(); try (Table tmp = t1.groupBy(options, 0).aggregate( - Aggregation.count().onColumn(1), - Aggregation.count().onColumn(2), - Aggregation.count().onColumn(3)); + GroupByAggregation.count().onColumn(1), + GroupByAggregation.count().onColumn(2), + GroupByAggregation.count().onColumn(3)); Table t3 = tmp.orderBy(OrderByArg.asc(0, true)); HostColumnVector groupCol = t3.getColumn(0).copyToHost(); HostColumnVector countCol = t3.getColumn(1).copyToHost(); @@ -4908,7 +4908,7 @@ void testGroupByMax() { .column( 1, 3, 3, 5, 5, 0) .column(12.0, 14.0, 13.0, 17.0, 17.0, 17.0) .build()) { - try (Table t3 = t1.groupBy(0, 1).aggregate(Aggregation.max().onColumn(2)); + try (Table t3 = t1.groupBy(0, 1).aggregate(GroupByAggregation.max().onColumn(2)); HostColumnVector aggOut1 = t3.getColumn(2).copyToHost()) { // verify t3 assertEquals(4, t3.getRowCount()); @@ -4943,7 +4943,7 @@ void testGroupByArgMax() { .column(17.0, 14.0, 14.0, 17.0, 17.1, 17.0) .build()) { try (Table t3 = t1.groupBy(0, 1) - .aggregate(Aggregation.argMax().onColumn(2)); + .aggregate(GroupByAggregation.argMax().onColumn(2)); Table sorted = t3 .orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(2)); Table expected = new Table.TestBuilder() @@ -4965,7 +4965,7 @@ void testGroupByArgMin() { 
.column(17.0, 14.0, 14.0, 17.0, 17.1, 17.0) .build()) { try (Table t3 = t1.groupBy(0, 1) - .aggregate(Aggregation.argMin().onColumn(2)); + .aggregate(GroupByAggregation.argMin().onColumn(2)); Table sorted = t3 .orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(2)); Table expected = new Table.TestBuilder() @@ -4983,7 +4983,7 @@ void testGroupByMinBool() { try (Table t1 = new Table.TestBuilder() .column(true, null, false, true, null, null) .column( 1, 1, 2, 2, 3, 3).build(); - Table other = t1.groupBy(1).aggregate(Aggregation.min().onColumn(0)); + Table other = t1.groupBy(1).aggregate(GroupByAggregation.min().onColumn(0)); Table ordered = other.orderBy(OrderByArg.asc(0)); Table expected = new Table.TestBuilder() .column(1, 2, 3) @@ -4998,7 +4998,7 @@ void testGroupByMaxBool() { try (Table t1 = new Table.TestBuilder() .column(false, null, false, true, null, null) .column( 1, 1, 2, 2, 3, 3).build(); - Table other = t1.groupBy(1).aggregate(Aggregation.max().onColumn(0)); + Table other = t1.groupBy(1).aggregate(GroupByAggregation.max().onColumn(0)); Table ordered = other.orderBy(OrderByArg.asc(0)); Table expected = new Table.TestBuilder() .column(1, 2, 3) @@ -5025,12 +5025,12 @@ void testGroupByDuplicateAggregates() { .column( 1, 2, 2, 1).build()) { try (Table t3 = t1.groupBy(0, 1) .aggregate( - Aggregation.max().onColumn(2), - Aggregation.min().onColumn(2), - Aggregation.min().onColumn(2), - Aggregation.max().onColumn(2), - Aggregation.min().onColumn(2), - Aggregation.count().onColumn(1)); + GroupByAggregation.max().onColumn(2), + GroupByAggregation.min().onColumn(2), + GroupByAggregation.min().onColumn(2), + GroupByAggregation.max().onColumn(2), + GroupByAggregation.min().onColumn(2), + GroupByAggregation.count().onColumn(1)); Table t4 = t3.orderBy(OrderByArg.asc(2))) { // verify t4 assertEquals(4, t4.getRowCount()); @@ -5053,7 +5053,7 @@ void testGroupByMin() { .column( 1, 3, 3, 5, 5, 0) .column( 12, 14, 13, 17, 17, 17) .build()) { - try (Table t3 = 
t1.groupBy(0, 1).aggregate(Aggregation.min().onColumn(2)); + try (Table t3 = t1.groupBy(0, 1).aggregate(GroupByAggregation.min().onColumn(2)); HostColumnVector aggOut0 = t3.getColumn(2).copyToHost()) { // verify t3 assertEquals(4, t3.getRowCount()); @@ -5088,7 +5088,7 @@ void testGroupBySum() { .column( 1, 3, 3, 5, 5, 0) .column(12.0, 14.0, 13.0, 17.0, 17.0, 17.0) .build()) { - try (Table t3 = t1.groupBy(0, 1).aggregate(Aggregation.sum().onColumn(2)); + try (Table t3 = t1.groupBy(0, 1).aggregate(GroupByAggregation.sum().onColumn(2)); HostColumnVector aggOut1 = t3.getColumn(2).copyToHost()) { // verify t3 assertEquals(4, t3.getRowCount()); @@ -5121,7 +5121,7 @@ void testGroupByM2() { try (Table input = new Table.TestBuilder().column(1, 2, 3, 1, 2, 2, 1, 3, 3, 2) .column(0, 1, -2, 3, -4, -5, -6, 7, -8, 9) .build(); - Table results = input.groupBy(0).aggregate(Aggregation.M2() + Table results = input.groupBy(0).aggregate(GroupByAggregation.M2() .onColumn(1)); Table expected = new Table.TestBuilder().column(1, 2, 3) .column(42.0, 122.75, 114.0) @@ -5134,7 +5134,7 @@ void testGroupByM2() { try (Table input = new Table.TestBuilder().column(1, 2, 5, 3, 4, 5, 2, 3, 2, 5) .column(0, null, null, 2, 3, null, 5, 6, 7, null) .build(); - Table results = input.groupBy(0).aggregate(Aggregation.M2() + Table results = input.groupBy(0).aggregate(GroupByAggregation.M2() .onColumn(1)); Table expected = new Table.TestBuilder().column(1, 2, 3, 4, 5) .column(0.0, 2.0, 8.0, 0.0, null) @@ -5146,7 +5146,7 @@ void testGroupByM2() { try (Table input = new Table.TestBuilder().column(4, 3, 1, 2, 3, 1, 2, 2, 1, null, 3, 2, 4, 4) .column(null, null, 0.0, 1.0, 2.0, 3.0, 4.0, Double.NaN, 6.0, 7.0, 8.0, 9.0, 10.0, Double.NaN) .build(); - Table results = input.groupBy(0).aggregate(Aggregation.M2() + Table results = input.groupBy(0).aggregate(GroupByAggregation.M2() .onColumn(1)); Table expected = new Table.TestBuilder().column(1, 2, 3, 4, null) .column(18.0, Double.NaN, 18.0, Double.NaN, 0.0) @@ 
-5179,7 +5179,7 @@ void testGroupByM2() { Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY) .build(); - Table results = input.groupBy(0).aggregate(Aggregation.M2() + Table results = input.groupBy(0).aggregate(GroupByAggregation.M2() .onColumn(1)); Table expected = new Table.TestBuilder().column(1, 2, 3, 4, 5) .column(Double.NaN, Double.NaN, Double.NaN, Double.NaN, 12.5) @@ -5237,7 +5237,7 @@ void testGroupByMergeM2() { partialResults3, partialResults4); Table finalResults = concatenatedResults.groupBy(0).aggregate( - Aggregation.mergeM2().onColumn(1)) + GroupByAggregation.mergeM2().onColumn(1)) ) { assertTablesAreEqual(expected, finalResults); } @@ -5255,7 +5255,7 @@ void testGroupByFirstExcludeNulls() { .column(13, 14) .build(); Table found = input.groupBy(0).aggregate( - Aggregation.nth(0, NullPolicy.EXCLUDE).onColumn(1))) { + GroupByAggregation.nth(0, NullPolicy.EXCLUDE).onColumn(1))) { assertTablesAreEqual(expected, found); } } @@ -5271,7 +5271,7 @@ void testGroupByLastExcludeNulls() { .column(12, 15) .build(); Table found = input.groupBy(0).aggregate( - Aggregation.nth(-1, NullPolicy.EXCLUDE).onColumn(1))) { + GroupByAggregation.nth(-1, NullPolicy.EXCLUDE).onColumn(1))) { assertTablesAreEqual(expected, found); } } @@ -5287,7 +5287,7 @@ void testGroupByFirstIncludeNulls() { .column(null, 14) .build(); Table found = input.groupBy(0).aggregate( - Aggregation.nth(0, NullPolicy.INCLUDE).onColumn(1))) { + GroupByAggregation.nth(0, NullPolicy.INCLUDE).onColumn(1))) { assertTablesAreEqual(expected, found); } } @@ -5303,7 +5303,7 @@ void testGroupByLastIncludeNulls() { .column(12, null) .build(); Table found = input.groupBy(0).aggregate( - Aggregation.nth(-1, NullPolicy.INCLUDE).onColumn(1))) { + GroupByAggregation.nth(-1, NullPolicy.INCLUDE).onColumn(1))) { assertTablesAreEqual(expected, found); } } @@ -5314,7 +5314,7 @@ void testGroupByAvg() { .column( 1, 3, 3, 5, 5, 0) .column(12, 14, 13, 1, 17, 17) .build()) { - try (Table t3 = t1.groupBy(0, 
1).aggregate(Aggregation.mean().onColumn(2)); + try (Table t3 = t1.groupBy(0, 1).aggregate(GroupByAggregation.mean().onColumn(2)); HostColumnVector aggOut1 = t3.getColumn(2).copyToHost()) { // verify t3 assertEquals(4, t3.getRowCount()); @@ -5349,11 +5349,11 @@ void testMultiAgg() { .column( 3, 1, 7, -1, 9, 0) .build()) { try (Table t2 = t1.groupBy(0, 1).aggregate( - Aggregation.count().onColumn(0), - Aggregation.max().onColumn(3), - Aggregation.min().onColumn(2), - Aggregation.mean().onColumn(2), - Aggregation.sum().onColumn(2)); + GroupByAggregation.count().onColumn(0), + GroupByAggregation.max().onColumn(3), + GroupByAggregation.min().onColumn(2), + GroupByAggregation.mean().onColumn(2), + GroupByAggregation.sum().onColumn(2)); HostColumnVector countOut = t2.getColumn(2).copyToHost(); HostColumnVector maxOut = t2.getColumn(3).copyToHost(); HostColumnVector minOut = t2.getColumn(4).copyToHost(); @@ -5419,7 +5419,7 @@ void testSumWithStrings() { .column(5289L, 5203L, 5303L, 5206L) .build(); Table result = t.groupBy(0).aggregate( - Aggregation.sum().onColumn(1)); + GroupByAggregation.sum().onColumn(1)); Table expected = new Table.TestBuilder() .column("1-URGENT", "3-MEDIUM") .column(5289L + 5303L, 5203L + 5206L) @@ -5517,7 +5517,7 @@ void testGroupByCollectListIncludeNulls() { Arrays.asList(0)) .build(); Table found = input.groupBy(0).aggregate( - Aggregation.collectList(NullPolicy.INCLUDE).onColumn(1))) { + GroupByAggregation.collectList(NullPolicy.INCLUDE).onColumn(1))) { assertTablesAreEqual(expected, found); } } @@ -5563,8 +5563,8 @@ void testGroupByMergeLists() { Arrays.asList(new StructData(333, "s333"), new StructData(222, "s222"), new StructData(111, "s111")), Arrays.asList(new StructData(222, "s222"), new StructData(444, "s444"))) .build(); - Table retListOfInts = input.groupBy(0).aggregate(Aggregation.mergeLists().onColumn(1)); - Table retListOfStructs = input.groupBy(0).aggregate(Aggregation.mergeLists().onColumn(2))) { + Table retListOfInts = 
input.groupBy(0).aggregate(GroupByAggregation.mergeLists().onColumn(1)); + Table retListOfStructs = input.groupBy(0).aggregate(GroupByAggregation.mergeLists().onColumn(2))) { assertTablesAreEqual(expectedListOfInts, retListOfInts); assertTablesAreEqual(expectedListOfStructs, retListOfStructs); } @@ -5573,7 +5573,7 @@ void testGroupByMergeLists() { @Test void testGroupByCollectSetIncludeNulls() { // test with null unequal and nan unequal - Aggregation collectSet = Aggregation.collectSet(NullPolicy.INCLUDE, + GroupByAggregation collectSet = GroupByAggregation.collectSet(NullPolicy.INCLUDE, NullEquality.UNEQUAL, NaNEquality.UNEQUAL); try (Table input = new Table.TestBuilder() .column(1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4) @@ -5589,7 +5589,7 @@ void testGroupByCollectSetIncludeNulls() { assertTablesAreEqual(expected, found); } // test with null equal and nan unequal - collectSet = Aggregation.collectSet(NullPolicy.INCLUDE, + collectSet = GroupByAggregation.collectSet(NullPolicy.INCLUDE, NullEquality.EQUAL, NaNEquality.UNEQUAL); try (Table input = new Table.TestBuilder() .column(1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4) @@ -5610,7 +5610,7 @@ void testGroupByCollectSetIncludeNulls() { assertTablesAreEqual(expected, found); } // test with null equal and nan equal - collectSet = Aggregation.collectSet(NullPolicy.INCLUDE, + collectSet = GroupByAggregation.collectSet(NullPolicy.INCLUDE, NullEquality.EQUAL, NaNEquality.ALL_EQUAL); try (Table input = new Table.TestBuilder() .column(1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4) @@ -5671,10 +5671,10 @@ void testGroupByMergeSets() { Arrays.asList(1e-3, 1e3, Double.NaN), Arrays.asList()) .build(); - Table retListOfInts = input.groupBy(0).aggregate(Aggregation.mergeSets().onColumn(1)); - Table retListOfDoubles = input.groupBy(0).aggregate(Aggregation.mergeSets().onColumn(2)); + Table retListOfInts = input.groupBy(0).aggregate(GroupByAggregation.mergeSets().onColumn(1)); + Table retListOfDoubles = 
input.groupBy(0).aggregate(GroupByAggregation.mergeSets().onColumn(2)); Table retListOfDoublesNaNEq = input.groupBy(0).aggregate( - Aggregation.mergeSets(NullEquality.UNEQUAL, NaNEquality.ALL_EQUAL).onColumn(2))) { + GroupByAggregation.mergeSets(NullEquality.UNEQUAL, NaNEquality.ALL_EQUAL).onColumn(2))) { assertTablesAreEqual(expectedListOfInts, retListOfInts); assertTablesAreEqual(expectedListOfDoubles, retListOfDoubles); assertTablesAreEqual(expectedListOfDoublesNaNEq, retListOfDoublesNaNEq);