Skip to content

Commit

Permalink
Merge pull request #10514 from rapidsai/branch-22.04
Browse files Browse the repository at this point in the history
[gpuCI] Forward-merge branch-22.04 to branch-22.06 [skip gpuci]
  • Loading branch information
GPUtester authored Mar 25, 2022
2 parents 17c913c + c71fe1b commit d73d91f
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 0 deletions.
14 changes: 14 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ReductionAggregation.java
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,20 @@ public static ReductionAggregation nth(int offset, NullPolicy nullPolicy) {
}

/**
* Builds a reduction aggregation that wraps the aggregation returned by
* {@code Aggregation.createTDigest(delta)}.
*
* @param delta value forwarded unchanged to {@code Aggregation.createTDigest}.
*              NOTE(review): presumably the tDigest accuracy/compression setting;
*              confirm against the documentation of that method.
* @return a new {@code ReductionAggregation} wrapping the tDigest aggregation
*/
public static ReductionAggregation createTDigest(int delta) {
return new ReductionAggregation(Aggregation.createTDigest(delta));
}

/**
* Builds a reduction aggregation that wraps the aggregation returned by
* {@code Aggregation.mergeTDigest(delta)}.
*
* @param delta value forwarded unchanged to {@code Aggregation.mergeTDigest}.
*              NOTE(review): presumably the tDigest accuracy/compression setting used
*              for the merged result; confirm against the documentation of that method.
* @return a new {@code ReductionAggregation} wrapping the tDigest merge aggregation
*/
public static ReductionAggregation mergeTDigest(int delta) {
return new ReductionAggregation(Aggregation.mergeTDigest(delta));
}

/**
* Collect the values into a list. Nulls will be skipped.
*/
public static ReductionAggregation collectList() {
Expand Down
91 changes: 91 additions & 0 deletions java/src/test/java/ai/rapids/cudf/TableTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -4148,6 +4148,97 @@ void testMergeApproxPercentile2() {
}
}

@Test
void testCreateTDigestReduction() {
  // One column of six integers; the smallest value is 70 and the largest is 160.
  try (Table input = new Table.TestBuilder()
      .column(100, 150, 160, 70, 110, 160)
      .build()) {
    // Reduce the only column to a single tDigest, requesting a STRUCT result.
    try (Scalar digest = input.getColumn(0)
        .reduce(ReductionAggregation.createTDigest(1000), DType.STRUCT)) {
      assertEquals(DType.STRUCT, digest.getType());

      // The struct scalar is expected to expose three children: centroids, min, max.
      try (CloseableArray children = CloseableArray.wrap(digest.getChildrenFromStructScalar())) {
        assertEquals(3, children.size());

        try (HostColumnVector centroidColumn = ((ColumnView) children.get(0)).copyToHost();
             HostColumnVector minColumn = ((ColumnView) children.get(1)).copyToHost();
             HostColumnVector maxColumn = ((ColumnView) children.get(2)).copyToHost()) {
          // Child types.
          assertEquals(DType.LIST, centroidColumn.getType());
          assertEquals(DType.FLOAT64, minColumn.getType());
          assertEquals(DType.FLOAT64, maxColumn.getType());

          // A reduction yields exactly one min row and one max row.
          assertEquals(1, minColumn.getRowCount());
          assertEquals(1, maxColumn.getRowCount());

          // min/max must match the extremes of the input column.
          assertEquals(70.0, minColumn.getDouble(0));
          assertEquals(160.0, maxColumn.getDouble(0));
        }
      }
    }
  }
}

@Test
void testMergeTDigestReduction() {
  // Schema of a single centroid: a struct of two FLOAT64 fields (mean, weight).
  StructType centroidStruct = new StructType(false,
      new BasicType(false, DType.FLOAT64), // mean
      new BasicType(false, DType.FLOAT64)); // weight

  ListType centroidList = new ListType(false, centroidStruct);

  // Schema of a whole tDigest: the centroid list followed by FLOAT64 min and max.
  StructType tdigestType = new StructType(false,
      centroidList,
      new BasicType(false, DType.FLOAT64), // min
      new BasicType(false, DType.FLOAT64)); // max

  // Centroids expected in the merged digest, in order, as {mean, weight} pairs.
  // They are exactly the centroids of the two input digests built below.
  double[][] expectedCentroids = {
      {1.0, 100.0},
      {2.0, 50.0},
      {3.0, 200.0},
      {4.0, 99.0}
  };

  try (ColumnVector tdigests = ColumnVector.fromStructs(tdigestType,
           new StructData(Arrays.asList(
               new StructData(1.0, 100.0),
               new StructData(2.0, 50.0)),
               1.0, // min
               2.0), // max
           new StructData(Arrays.asList(
               new StructData(3.0, 200.0),
               new StructData(4.0, 99.0)),
               3.0, // min
               4.0)); // max
       Scalar merged = tdigests.reduce(ReductionAggregation.mergeTDigest(1000), DType.STRUCT)) {

    assertEquals(DType.STRUCT, merged.getType());
    try (CloseableArray columns = CloseableArray.wrap(merged.getChildrenFromStructScalar())) {
      // The merged struct scalar must expose three children: centroids, min, max.
      assertEquals(3, columns.size());
      try (HostColumnVector centroids = ((ColumnView) columns.get(0)).copyToHost();
           HostColumnVector min = ((ColumnView) columns.get(1)).copyToHost();
           HostColumnVector max = ((ColumnView) columns.get(2)).copyToHost()) {
        assertEquals(DType.LIST, centroids.getType());
        assertEquals(DType.FLOAT64, min.getType());
        assertEquals(DType.FLOAT64, max.getType());

        // The reduction collapses both input digests into a single row per child.
        assertEquals(1, centroids.getRowCount());
        assertEquals(1, min.getRowCount());
        assertEquals(1, max.getRowCount());

        // Merged min/max are the overall extremes of the two inputs.
        assertEquals(1.0, min.getDouble(0));
        assertEquals(4.0, max.getDouble(0));

        List<?> mergedCentroids = centroids.getList(0);
        assertEquals(expectedCentroids.length, mergedCentroids.size());

        for (int i = 0; i < expectedCentroids.length; i++) {
          StructData centroid = (StructData) mergedCentroids.get(i);
          assertEquals(expectedCentroids[i][0], centroid.dataRecord.get(0),
              "mean of centroid " + i);
          assertEquals(expectedCentroids[i][1], centroid.dataRecord.get(1),
              "weight of centroid " + i);
        }
      }
    }
  }
}

@Test
void testGroupByMinMaxDecimal() {
try (Table t1 = new Table.TestBuilder()
Expand Down

0 comments on commit d73d91f

Please sign in to comment.