diff --git a/java/src/main/java/ai/rapids/cudf/Aggregation.java b/java/src/main/java/ai/rapids/cudf/Aggregation.java index 734d9cb5694..1d73bd71246 100644 --- a/java/src/main/java/ai/rapids/cudf/Aggregation.java +++ b/java/src/main/java/ai/rapids/cudf/Aggregation.java @@ -65,7 +65,9 @@ enum Kind { M2(26), MERGE_M2(27), RANK(28), - DENSE_RANK(29); + DENSE_RANK(29), + TDIGEST(30), // This can take a delta argument for accuracy level + MERGE_TDIGEST(31); // This can take a delta argument for accuracy level final int nativeId; @@ -864,6 +866,44 @@ static MergeM2Aggregation mergeM2() { return new MergeM2Aggregation(); } + static class TDigestAggregation extends Aggregation { + private final int delta; + + public TDigestAggregation(Kind kind, int delta) { + super(kind); + this.delta = delta; + } + + @Override + long createNativeInstance() { + return Aggregation.createTDigestAgg(kind.nativeId, delta); + } + + @Override + public int hashCode() { + return 31 * kind.hashCode() + delta; + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } else if (other instanceof TDigestAggregation) { + TDigestAggregation o = (TDigestAggregation) other; + return o.delta == this.delta; + } + return false; + } + } + + static TDigestAggregation createTDigest(int delta) { + return new TDigestAggregation(Kind.TDIGEST, delta); + } + + static TDigestAggregation mergeTDigest(int delta) { + return new TDigestAggregation(Kind.MERGE_TDIGEST, delta); + } + /** * Create one of the aggregations that only needs a kind, no other parameters. This does not * work for all types and for code safety reasons each kind is added separately. @@ -909,4 +949,9 @@ static MergeM2Aggregation mergeM2() { * Create a merge sets aggregation. */ private static native long createMergeSetsAgg(boolean nullsEqual, boolean nansEqual); + + /** + * Create a TDigest aggregation. + */ + private static native long createTDigestAgg(int kind, int delta); } diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index 53a02d83dd1..ad081dc7614 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -1423,12 +1423,39 @@ public Scalar reduce(ReductionAggregation aggregation, DType outType) { } } + /** + * Calculate various percentiles of this ColumnVector, which must contain centroids produced by + * a t-digest aggregation. + * + * @param percentiles Required percentiles [0,1] + * @return Column containing the approximate percentile values as a list of doubles, in + * the same order as the input percentiles + */ + public final ColumnVector approxPercentile(double[] percentiles) { + try (ColumnVector cv = ColumnVector.fromDoubles(percentiles)) { + return approxPercentile(cv); + } + } + + /** + * Calculate various percentiles of this ColumnVector, which must contain centroids produced by + * a t-digest aggregation. + * + * @param percentiles Column containing percentiles [0,1] + * @return Column containing the approximate percentile values as a list of doubles, in + * the same order as the input percentiles + */ + public final ColumnVector approxPercentile(ColumnVector percentiles) { + return new ColumnVector(approxPercentile(getNativeView(), percentiles.getNativeView())); + } + /** * Calculate various quantiles of this ColumnVector. It is assumed that this is already sorted * in the desired order. * @param method the method used to calculate the quantiles * @param quantiles the quantile values [0,1] - * @return the quantiles as doubles, in the same order passed in. The type can be changed in future + * @return Column containing the approximate percentile values as a list of doubles, in + * the same order as the input percentiles */ public final ColumnVector quantile(QuantileMethod method, double[] quantiles) { return new ColumnVector(quantile(getNativeView(), method.nativeId, quantiles)); @@ -3544,6 +3571,15 @@ private static native long stringReplaceWithBackrefs(long columnView, String pat */ private static native long upperStrings(long cudfViewHandle); + /** + * Native method to compute approx percentiles. + * @param cudfColumnHandle T-Digest column + * @param percentilesHandle Percentiles + * @return native handle of the resulting cudf column, used to construct the Java column + * by the approxPercentile method. + */ + private static native long approxPercentile(long cudfColumnHandle, long percentilesHandle) throws CudfException; + private static native long quantile(long cudfColumnHandle, int quantileMethod, double[] quantiles) throws CudfException; private static native long rollingWindow( diff --git a/java/src/main/java/ai/rapids/cudf/GroupByAggregation.java b/java/src/main/java/ai/rapids/cudf/GroupByAggregation.java index dd2adf8bee8..682d844c43c 100644 --- a/java/src/main/java/ai/rapids/cudf/GroupByAggregation.java +++ b/java/src/main/java/ai/rapids/cudf/GroupByAggregation.java @@ -293,4 +293,26 @@ public static GroupByAggregation mergeSets(NullEquality nullEquality, NaNEqualit public static GroupByAggregation mergeM2() { return new GroupByAggregation(Aggregation.mergeM2()); } + + /** + * Compute a t-digest from on a fixed-width numeric input column. + * + * @param delta Required accuracy (number of buckets). + * @return A list of centroids per grouping, where each centroid has a mean value and a + * weight. The number of centroids will be <= delta. + */ + public static GroupByAggregation createTDigest(int delta) { + return new GroupByAggregation(Aggregation.createTDigest(delta)); + } + + /** + * Merge t-digests. + * + * @param delta Required accuracy (number of buckets). + * @return A list of centroids per grouping, where each centroid has a mean value and a + * weight. The number of centroids will be <= delta. + */ + public static GroupByAggregation mergeTDigest(int delta) { + return new GroupByAggregation(Aggregation.mergeTDigest(delta)); + } } diff --git a/java/src/main/native/src/AggregationJni.cpp b/java/src/main/native/src/AggregationJni.cpp index 0d180a55583..93a01854ced 100644 --- a/java/src/main/native/src/AggregationJni.cpp +++ b/java/src/main/native/src/AggregationJni.cpp @@ -85,7 +85,6 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createNoParamAgg(JNIEnv return cudf::make_rank_aggregation(); case 29: // DENSE_RANK return cudf::make_dense_rank_aggregation(); - default: throw std::logic_error("Unsupported No Parameter Aggregation Operation"); } }(); @@ -131,6 +130,28 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createDdofAgg(JNIEnv *en CATCH_STD(env, 0); } +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createTDigestAgg(JNIEnv *env, + jclass class_object, + jint kind, jint delta) { + try { + cudf::jni::auto_set_device(env); + + std::unique_ptr ret; + // These numbers come from Aggregation.java and must stay in sync + switch (kind) { + case 30: // TDIGEST + ret = cudf::make_tdigest_aggregation(delta); + break; + case 31: // MERGE_TDIGEST + ret = cudf::make_merge_tdigest_aggregation(delta); + break; + default: throw std::logic_error("Unsupported TDigest Aggregation Operation"); + } + return reinterpret_cast(ret.release()); + } + CATCH_STD(env, 0); +} + JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createCountLikeAgg(JNIEnv *env, jclass class_object, jint kind, diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp index adc0de12f25..cd5ff073edd 100644 --- a/java/src/main/native/src/ColumnViewJni.cpp +++ b/java/src/main/native/src/ColumnViewJni.cpp @@ -289,6 +289,25 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_scan(JNIEnv *env, jclass, CATCH_STD(env, 0); } +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_approxPercentile(JNIEnv *env, jclass clazz, + jlong input_column, + jlong percentiles_column) { + JNI_NULL_CHECK(env, input_column, "input_column native handle is null", 0); + JNI_NULL_CHECK(env, percentiles_column, "percentiles_column native handle is null", 0); + try { + cudf::jni::auto_set_device(env); + cudf::column_view *n_input_column = reinterpret_cast(input_column); + std::unique_ptr input_view = + std::make_unique(*n_input_column); + cudf::column_view *n_percentiles_column = + reinterpret_cast(percentiles_column); + std::unique_ptr result = + cudf::percentile_approx(*input_view, *n_percentiles_column); + return reinterpret_cast(result.release()); + } + CATCH_STD(env, 0); +} + JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_quantile(JNIEnv *env, jclass clazz, jlong input_column, jint quantile_method, diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index 0e7ac15a79e..aa9ef5bf766 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -3484,6 +3484,106 @@ void testGroupByReplaceNulls() { } } + @Test + void testGroupByApproxPercentileReproCase() { + double[] percentiles = {0.25, 0.50, 0.75}; + try (Table t1 = new Table.TestBuilder() + .column("a", "a", "b", "c", "d") + .column(1084.0, 1719.0, 15948.0, 148029.0, 1269761.0) + .build(); + Table t2 = t1 + .groupBy(0) + .aggregate(GroupByAggregation.createTDigest(100).onColumn(1)); + Table sorted = t2.orderBy(OrderByArg.asc(0)); + ColumnVector actual = sorted.getColumn(1).approxPercentile(percentiles); + ColumnVector expected = ColumnVector.fromLists( + new ListType(false, new BasicType(false, DType.FLOAT64)), + Arrays.asList(1084.0, 1084.0, 1719.0), + Arrays.asList(15948.0, 15948.0, 15948.0), + Arrays.asList(148029.0, 148029.0, 148029.0), + Arrays.asList(1269761.0, 1269761.0, 1269761.0) + )) { + assertColumnsAreEqual(expected, actual); + } + } + + @Test + void testGroupByApproxPercentile() { + double[] percentiles = {0.25, 0.50, 0.75}; + try (Table t1 = new Table.TestBuilder() + .column("a", "a", "a", "b", "b", "b") + .column(100, 150, 160, 70, 110, 160) + .build(); + Table t2 = t1 + .groupBy(0) + .aggregate(GroupByAggregation.createTDigest(1000).onColumn(1)); + Table sorted = t2.orderBy(OrderByArg.asc(0)); + ColumnVector actual = sorted.getColumn(1).approxPercentile(percentiles); + ColumnVector expected = ColumnVector.fromLists( + new ListType(false, new BasicType(false, DType.FLOAT64)), + Arrays.asList(100d, 150d, 160d), + Arrays.asList(70d, 110d, 160d) + )) { + assertColumnsAreEqual(expected, actual); + } + } + + @Test + void testMergeApproxPercentile() { + double[] percentiles = {0.25, 0.50, 0.75}; + try (Table t1 = new Table.TestBuilder() + .column("a", "a", "a", "b", "b", "b") + .column(100, 150, 160, 70, 110, 160) + .build(); + Table t2 = t1 + .groupBy(0) + .aggregate(GroupByAggregation.createTDigest(1000).onColumn(1)); + Table t3 = t1 + .groupBy(0) + .aggregate(GroupByAggregation.createTDigest(1000).onColumn(1)); + Table t4 = Table.concatenate(t2, t3); + Table t5 = t4 + .groupBy(0) + .aggregate(GroupByAggregation.mergeTDigest(1000).onColumn(1)); + Table sorted = t5.orderBy(OrderByArg.asc(0)); + ColumnVector actual = sorted.getColumn(1).approxPercentile(percentiles); + ColumnVector expected = ColumnVector.fromLists( + new ListType(false, new BasicType(false, DType.FLOAT64)), + Arrays.asList(100d, 150d, 160d), + Arrays.asList(70d, 110d, 160d) + )) { + assertColumnsAreEqual(expected, actual); + } + } + + @Test + void testMergeApproxPercentile2() { + double[] percentiles = {0.25, 0.50, 0.75}; + try (Table t1 = new Table.TestBuilder() + .column("a", "a", "a", "b", "b", "b") + .column(70, 110, 160, 100, 150, 160) + .build(); + Table t2 = t1 + .groupBy(0) + .aggregate(GroupByAggregation.createTDigest(1000).onColumn(1)); + Table t3 = t1 + .groupBy(0) + .aggregate(GroupByAggregation.createTDigest(1000).onColumn(1)); + Table t4 = Table.concatenate(t2, t3); + Table t5 = t4 + .groupBy(0) + .aggregate(GroupByAggregation.mergeTDigest(1000).onColumn(1)); + Table sorted = t5.orderBy(OrderByArg.asc(0)); + ColumnVector actual = sorted.getColumn(1).approxPercentile(percentiles); + ColumnVector expected = ColumnVector.fromLists( + new ListType(false, new BasicType(false, DType.FLOAT64)), + Arrays.asList(70d, 110d, 160d), + Arrays.asList(100d, 150d, 160d) + )) { + assertColumnsAreEqual(expected, actual); + } + } + @Test void testGroupByUniqueCount() { try (Table t1 = new Table.TestBuilder()