diff --git a/java/src/main/java/ai/rapids/cudf/Aggregation.java b/java/src/main/java/ai/rapids/cudf/Aggregation.java index b8920cc59eb..7d8989571f7 100644 --- a/java/src/main/java/ai/rapids/cudf/Aggregation.java +++ b/java/src/main/java/ai/rapids/cudf/Aggregation.java @@ -1,6 +1,6 @@ /* * - * Copyright (c) 2020, NVIDIA CORPORATION. + * Copyright (c) 2020-2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -65,6 +65,18 @@ enum Kind { Kind(int nativeId) {this.nativeId = nativeId;} } + /* + * This is analogous to the native 'null_policy'. + */ + public enum NullPolicy { + EXCLUDE(false), + INCLUDE(true); + + NullPolicy(boolean includeNulls) { this.includeNulls = includeNulls; } + + final boolean includeNulls; + } + /** * An Aggregation that only needs a kind and nothing else. */ @@ -97,22 +109,22 @@ public boolean equals(Object other) { private static final class NthAggregation extends Aggregation { private final int offset; - private final boolean includeNulls; + private final NullPolicy nullPolicy; - public NthAggregation(int offset, boolean includeNulls) { + public NthAggregation(int offset, NullPolicy nullPolicy) { super(Kind.NTH_ELEMENT); this.offset = offset; - this.includeNulls = includeNulls; + this.nullPolicy = nullPolicy; } @Override long createNativeInstance() { - return Aggregation.createNthAgg(offset, includeNulls); + return Aggregation.createNthAgg(offset, nullPolicy.includeNulls); } @Override public int hashCode() { - return 31 * offset + Boolean.hashCode(includeNulls); + return 31 * offset + nullPolicy.hashCode(); } @Override @@ -121,7 +133,7 @@ public boolean equals(Object other) { return true; } else if (other instanceof NthAggregation) { NthAggregation o = (NthAggregation) other; - return o.offset == this.offset && o.includeNulls == this.includeNulls; + return o.offset == this.offset && o.nullPolicy == this.nullPolicy; } return false; } @@ -158,21 +170,21 @@ public boolean equals(Object other) { } private static final class CountLikeAggregation extends Aggregation { - private final boolean includeNulls; + private final NullPolicy nullPolicy; - public CountLikeAggregation(Kind kind, boolean includeNulls) { + public CountLikeAggregation(Kind kind, NullPolicy nullPolicy) { super(kind); - this.includeNulls = includeNulls; + this.nullPolicy = nullPolicy; } @Override long createNativeInstance() { - return Aggregation.createCountLikeAgg(kind.nativeId, includeNulls); + return Aggregation.createCountLikeAgg(kind.nativeId, nullPolicy.includeNulls); } @Override public int hashCode() { - return 31 * kind.hashCode() + Boolean.hashCode(includeNulls); + return 31 * kind.hashCode() + nullPolicy.hashCode(); } @Override @@ -181,7 +193,7 @@ public boolean equals(Object other) { return true; } else if (other instanceof CountLikeAggregation) { CountLikeAggregation o = (CountLikeAggregation) other; - return o.includeNulls == this.includeNulls; + return o.nullPolicy == this.nullPolicy; } return false; } @@ -268,6 +280,36 @@ long getDefaultOutput() { } } + private static final class CollectAggregation extends Aggregation { + private final NullPolicy nullPolicy; + + public CollectAggregation(NullPolicy nullPolicy) { + super(Kind.COLLECT); + this.nullPolicy = nullPolicy; + } + + @Override + long createNativeInstance() { + return Aggregation.createCollectAgg(nullPolicy.includeNulls); + } + + @Override + public int hashCode() { + return 31 * kind.hashCode() + nullPolicy.hashCode(); + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } else if (other instanceof CollectAggregation) { + CollectAggregation o = (CollectAggregation) other; + return o.nullPolicy == this.nullPolicy; + } + return false; + } + } + protected final Kind kind; protected Aggregation(Kind kind) { @@ -351,16 +393,27 @@ public static Aggregation max() { * Count number of valid, a.k.a. non-null, elements. */ public static Aggregation count() { - return count(false); + return count(NullPolicy.EXCLUDE); } /** * Count number of elements. + * (This is deprecated, use {@link Aggregation#count(NullPolicy nullPolicy)} instead) * @param includeNulls true if nulls should be counted. false if only non-null values should be * counted. */ + @Deprecated public static Aggregation count(boolean includeNulls) { - return new CountLikeAggregation(Kind.COUNT, includeNulls); + return count(includeNulls ? NullPolicy.INCLUDE : NullPolicy.EXCLUDE); + } + + /** + * Count number of elements. + * @param nullPolicy INCLUDE if nulls should be counted. EXCLUDE if only non-null values + * should be counted. + */ + public static Aggregation count(NullPolicy nullPolicy) { + return new CountLikeAggregation(Kind.COUNT, nullPolicy); } /** @@ -473,17 +526,29 @@ public static Aggregation argMin() { * Number of unique, non-null, elements. */ public static Aggregation nunique() { - return nunique(false); + return nunique(NullPolicy.EXCLUDE); } /** * Number of unique elements. + * (This is deprecated, use {@link Aggregation#nunique(NullPolicy nullPolicy)} instead) * @param includeNulls true if nulls should be counted else false. If nulls are counted they * compare as equal so multiple null values in a range would all only * increase the count by 1. */ + @Deprecated public static Aggregation nunique(boolean includeNulls) { - return new CountLikeAggregation(Kind.NUNIQUE, includeNulls); + return nunique(includeNulls ? NullPolicy.INCLUDE : NullPolicy.EXCLUDE); + } + + /** + * Number of unique elements. + * @param nullPolicy INCLUDE if nulls should be counted else EXCLUDE. If nulls are counted they + * compare as equal so multiple null values in a range would all only + * increase the count by 1. + */ + public static Aggregation nunique(NullPolicy nullPolicy) { + return new CountLikeAggregation(Kind.NUNIQUE, nullPolicy); } /** @@ -492,18 +557,31 @@ public static Aggregation nunique(boolean includeNulls) { * value outside of the group range results in a null. */ public static Aggregation nth(int offset) { - return nth(offset, true); + return nth(offset, NullPolicy.INCLUDE); } /** * Get the nth element in a group. + * (This is deprecated, use {@link Aggregation#nth(int offset, NullPolicy nullPolicy)} instead) * @param offset the offset to look at. Negative numbers go from the end of the group. Any * value outside of the group range results in a null. * @param includeNulls true if nulls should be included in the aggregation or false if they * should be skipped. */ + @Deprecated public static Aggregation nth(int offset, boolean includeNulls) { - return new NthAggregation(offset, includeNulls); + return nth(offset, includeNulls ? NullPolicy.INCLUDE : NullPolicy.EXCLUDE); + } + + /** + * Get the nth element in a group. + * @param offset the offset to look at. Negative numbers go from the end of the group. Any + * value outside of the group range results in a null. + * @param nullPolicy INCLUDE if nulls should be included in the aggregation or EXCLUDE if they + * should be skipped. + */ + public static Aggregation nth(int offset, NullPolicy nullPolicy) { + return new NthAggregation(offset, nullPolicy); } /** @@ -514,10 +592,19 @@ public static Aggregation rowNumber() { } /** - * Collect the values into a list. + * Collect the values into a list. nulls will be skipped. */ public static Aggregation collect() { - return new NoParamAggregation(Kind.COLLECT); + return collect(NullPolicy.EXCLUDE); + } + + /** + * Collect the values into a list. + * @param nullPolicy INCLUDE if nulls should be included in the aggregation or EXCLUDE if they + * should be skipped. + */ + public static Aggregation collect(NullPolicy nullPolicy) { + return new CollectAggregation(nullPolicy); } /** @@ -586,4 +673,9 @@ public static Aggregation lag(int offset, ColumnVector defaultOutput) { * Create a lead or lag aggregation. */ private static native long createLeadLagAgg(int kind, int offset); + + /** + * Create a collect aggregation including nulls or not. + */ + private static native long createCollectAgg(boolean includeNulls); } diff --git a/java/src/main/native/src/AggregationJni.cpp b/java/src/main/native/src/AggregationJni.cpp index 974a70a7683..aae7cb493a8 100644 --- a/java/src/main/native/src/AggregationJni.cpp +++ b/java/src/main/native/src/AggregationJni.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, NVIDIA CORPORATION. + * Copyright (c) 2020-2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -81,9 +81,7 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createNoParamAgg(JNIEnv case 17: // ROW_NUMBER ret = cudf::make_row_number_aggregation(); break; - case 18: // COLLECT - ret = cudf::make_collect_aggregation(); - break; + // case 18: COLLECT // case 19: LEAD // case 20: LAG // case 21: PTX @@ -201,4 +199,17 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createLeadLagAgg(JNIEnv CATCH_STD(env, 0); } +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createCollectAgg(JNIEnv *env, + jclass class_object, + jboolean include_nulls) { + try { + cudf::jni::auto_set_device(env); + cudf::null_policy policy = + include_nulls ? cudf::null_policy::INCLUDE : cudf::null_policy::EXCLUDE; + std::unique_ptr ret = cudf::make_collect_aggregation(policy); + return reinterpret_cast(ret.release()); + } + CATCH_STD(env, 0); +} + } // extern "C" diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index 02942846cd6..88196a4112a 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -2503,6 +2503,7 @@ void testWindowingRowNumber() { @Test void testWindowingCollect() { + Aggregation aggCollectWithNulls = Aggregation.collect(Aggregation.NullPolicy.INCLUDE); Aggregation aggCollect = Aggregation.collect(); WindowOptions winOpts = WindowOptions.builder() .minPeriods(1) @@ -2513,26 +2514,38 @@ void testWindowingCollect() { .column( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key .column( 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3) // GBY Key .column( 1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8) // OBY Key - .column( 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6) // Agg Column of INT32 + .column( 7, 5, 1, 9, 7, 9, 8, 2, null, 0, 6, null) // Agg Column of INT32 .column(nestedType, // Agg Column of Struct new StructData(1, "s1"), new StructData(2, "s2"), new StructData(3, "s3"), new StructData(4, "s4"), new StructData(11, "s11"), new StructData(22, "s22"), new StructData(33, "s33"), new StructData(44, "s44"), new StructData(111, "s111"), new StructData(222, "s222"), new StructData(333, "s333"), new StructData(444, "s444") ).build(); - ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6)) { + ColumnVector expectSortedAggColumn = ColumnVector + .fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, null, 0, 6, null)) { try (Table sorted = raw.orderBy(Table.asc(0), Table.asc(1), Table.asc(2))) { ColumnVector sortedAggColumn = sorted.getColumn(3); assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); // Primitive type: INT32 + // a) including nulls + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(aggCollectWithNulls.onColumn(3).overWindow(winOpts)); + ColumnVector expected = ColumnVector.fromLists( + new ListType(false, new BasicType(false, DType.INT32)), + Arrays.asList(7,5), Arrays.asList(7,5,1), Arrays.asList(5,1,9), Arrays.asList(1,9), + Arrays.asList(7,9), Arrays.asList(7,9,8), Arrays.asList(9,8,2), Arrays.asList(8,2), + Arrays.asList(null,0), Arrays.asList(null,0,6), Arrays.asList(0,6,null), Arrays.asList(6,null))) { + assertColumnsAreEqual(expected, windowAggResults.getColumn(0)); + } + // b) excluding nulls try (Table windowAggResults = sorted.groupBy(0, 1) .aggregateWindows(aggCollect.onColumn(3).overWindow(winOpts)); ColumnVector expected = ColumnVector.fromLists( new ListType(false, new BasicType(false, DType.INT32)), Arrays.asList(7,5), Arrays.asList(7,5,1), Arrays.asList(5,1,9), Arrays.asList(1,9), Arrays.asList(7,9), Arrays.asList(7,9,8), Arrays.asList(9,8,2), Arrays.asList(8,2), - Arrays.asList(8,0), Arrays.asList(8,0,6), Arrays.asList(0,6,6), Arrays.asList(6,6))) { + Arrays.asList(0), Arrays.asList(0,6), Arrays.asList(0,6), Arrays.asList(6))) { assertColumnsAreEqual(expected, windowAggResults.getColumn(0)); }