diff --git a/java/src/main/java/ai/rapids/cudf/ReductionAggregation.java b/java/src/main/java/ai/rapids/cudf/ReductionAggregation.java index 9147d6763ac..eab1c94fd2c 100644 --- a/java/src/main/java/ai/rapids/cudf/ReductionAggregation.java +++ b/java/src/main/java/ai/rapids/cudf/ReductionAggregation.java @@ -211,6 +211,20 @@ public static ReductionAggregation nth(int offset, NullPolicy nullPolicy) { } /** + * tDigest reduction. + */ + public static ReductionAggregation createTDigest(int delta) { + return new ReductionAggregation(Aggregation.createTDigest(delta)); + } + + /** + * tDigest merge reduction. + */ + public static ReductionAggregation mergeTDigest(int delta) { + return new ReductionAggregation(Aggregation.mergeTDigest(delta)); + } + + /* * Collect the values into a list. Nulls will be skipped. */ public static ReductionAggregation collectList() { diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index 269c9d7eda1..9f34006043a 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -4148,6 +4148,97 @@ void testMergeApproxPercentile2() { } } + @Test + void testCreateTDigestReduction() { + try (Table t1 = new Table.TestBuilder() + .column(100, 150, 160, 70, 110, 160) + .build(); + Scalar tdigest = t1.getColumn(0) + .reduce(ReductionAggregation.createTDigest(1000), DType.STRUCT)) { + assertEquals(DType.STRUCT, tdigest.getType()); + + try (CloseableArray columns = CloseableArray.wrap(tdigest.getChildrenFromStructScalar())) { + assertEquals(3, columns.size()); + try (HostColumnVector centroids = ((ColumnView) columns.get(0)).copyToHost(); + HostColumnVector min = ((ColumnView) columns.get(1)).copyToHost(); + HostColumnVector max = ((ColumnView) columns.get(2)).copyToHost()) { + assertEquals(DType.LIST, centroids.getType()); + assertEquals(DType.FLOAT64, min.getType()); + assertEquals(DType.FLOAT64, max.getType()); + assertEquals(1, min.getRowCount()); + assertEquals(1, max.getRowCount()); + assertEquals(70, min.getDouble(0)); + assertEquals(160, max.getDouble(0)); + } + } + } + } + + @Test + void testMergeTDigestReduction() { + StructType centroidStruct = new StructType(false, + new BasicType(false, DType.FLOAT64), // mean + new BasicType(false, DType.FLOAT64)); // weight + + ListType centroidList = new ListType(false, centroidStruct); + + StructType tdigestType = new StructType(false, + centroidList, + new BasicType(false, DType.FLOAT64), // min + new BasicType(false, DType.FLOAT64)); // max + + try (ColumnVector tdigests = ColumnVector.fromStructs(tdigestType, + new StructData(Arrays.asList( + new StructData(1.0, 100.0), + new StructData(2.0, 50.0)), + 1.0, // min + 2.0), // max + new StructData(Arrays.asList( + new StructData(3.0, 200.0), + new StructData(4.0, 99.0)), + 3.0, // min + 4.0)); // max + Scalar merged = tdigests.reduce(ReductionAggregation.mergeTDigest(1000), DType.STRUCT)) { + + assertEquals(DType.STRUCT, merged.getType()); + try (CloseableArray columns = CloseableArray.wrap(merged.getChildrenFromStructScalar())) { + assertEquals(3, columns.size()); + try (HostColumnVector centroids = ((ColumnView) columns.get(0)).copyToHost(); + HostColumnVector min = ((ColumnView) columns.get(1)).copyToHost(); + HostColumnVector max = ((ColumnView) columns.get(2)).copyToHost()) { + assertEquals(3, columns.size()); + assertEquals(DType.LIST, centroids.getType()); + assertEquals(DType.FLOAT64, min.getType()); + assertEquals(DType.FLOAT64, max.getType()); + assertEquals(1, min.getRowCount()); + assertEquals(1, max.getRowCount()); + assertEquals(1.0, min.getDouble(0)); + assertEquals(4.0, max.getDouble(0)); + assertEquals(1, centroids.rows); + + List list = centroids.getList(0); + assertEquals(4, list.size()); + + StructData data = (StructData) list.get(0); + assertEquals(1.0, data.dataRecord.get(0)); + assertEquals(100.0, data.dataRecord.get(1)); + + data = (StructData) list.get(1); + assertEquals(2.0, data.dataRecord.get(0)); + assertEquals(50.0, data.dataRecord.get(1)); + + data = (StructData) list.get(2); + assertEquals(3.0, data.dataRecord.get(0)); + assertEquals(200.0, data.dataRecord.get(1)); + + data = (StructData) list.get(3); + assertEquals(4.0, data.dataRecord.get(0)); + assertEquals(99.0, data.dataRecord.get(1)); + } + } + } + } + @Test void testGroupByMinMaxDecimal() { try (Table t1 = new Table.TestBuilder()