Add Java bindings for t-digest reduction #10446

Merged 30 commits on Mar 25, 2022
Commits
245e68c
Add scan_aggregation and reduce_aggregations. C++ side only.
nvdbaranec Feb 22, 2022
c884d5c
Java bindings.
nvdbaranec Feb 23, 2022
321c9b2
Merge branch 'branch-22.04' into scan_reduce_aggregations
nvdbaranec Feb 24, 2022
900d55c
Python bindings.
nvdbaranec Feb 25, 2022
0398a0d
Copyright updates.
nvdbaranec Feb 25, 2022
a3a71b8
PR review comments.
nvdbaranec Mar 7, 2022
56a6c0f
Formatting
nvdbaranec Mar 7, 2022
8917445
Centralize tdigest aggregation code to quantiles/tdigest.
nvdbaranec Mar 7, 2022
e693562
Clean up some test code.
nvdbaranec Mar 9, 2022
f49e2c9
Merge branch 'scan_reduce_aggregations' into tdigest_code_move
nvdbaranec Mar 9, 2022
23cae44
Small test tweak.
nvdbaranec Mar 9, 2022
7fdc9f5
Merge branch 'scan_reduce_aggregations' into tdigest_code_move
nvdbaranec Mar 9, 2022
3088ec8
tdigest reduce_aggregation functionality and tests.
nvdbaranec Mar 10, 2022
6f940fd
Merge branch 'branch-22.04' into scan_reduce_aggregations
nvdbaranec Mar 11, 2022
13c776a
Merge branch 'scan_reduce_aggregations' into tdigest_code_move
nvdbaranec Mar 11, 2022
3140f5f
Merge branch 'tdigest_code_move' into tdigest_reduction
nvdbaranec Mar 11, 2022
27a854e
Merge branch 'branch-22.04' into tdigest_code_move
nvdbaranec Mar 11, 2022
6827e8f
Copyright update.
nvdbaranec Mar 11, 2022
25c1849
cmake format fixes.
nvdbaranec Mar 14, 2022
b86b3db
Merge branch 'tdigest_code_move' into tdigest_reduction
nvdbaranec Mar 14, 2022
83f4d31
Merge branch 'branch-22.04' into tdigest_reduction
nvdbaranec Mar 14, 2022
6a2d50e
Merge tdigest aggregation for cudf::reduce
nvdbaranec Mar 14, 2022
0fdd74e
Formatting fixes.
nvdbaranec Mar 14, 2022
1ab070c
Add JNI methods for tDigest reductions
andygrove Mar 16, 2022
af26226
Add unit test for createTDigestReduction
andygrove Mar 16, 2022
9d08477
merge from branch-22.04
andygrove Mar 23, 2022
a249d1a
Add test for merge tdigest
andygrove Mar 24, 2022
59a74f9
remove unused import
andygrove Mar 24, 2022
902866b
Merge remote-tracking branch 'rapidsai/branch-22.04' into tdigest_red…
andygrove Mar 24, 2022
6a7455b
fix resource leak
andygrove Mar 24, 2022
14 changes: 14 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ReductionAggregation.java
@@ -211,6 +211,20 @@ public static ReductionAggregation nth(int offset, NullPolicy nullPolicy) {
  }

  /**
   * tDigest reduction.
   */
  public static ReductionAggregation createTDigest(int delta) {
    return new ReductionAggregation(Aggregation.createTDigest(delta));
  }

  /**
   * tDigest merge reduction.
   */
  public static ReductionAggregation mergeTDigest(int delta) {
    return new ReductionAggregation(Aggregation.mergeTDigest(delta));
  }

  /*
   * Collect the values into a list. Nulls will be skipped.
   */
  public static ReductionAggregation collectList() {
91 changes: 91 additions & 0 deletions java/src/test/java/ai/rapids/cudf/TableTest.java
@@ -4148,6 +4148,97 @@ void testMergeApproxPercentile2() {
    }
  }

  @Test
  void testCreateTDigestReduction() {
    try (Table t1 = new Table.TestBuilder()
        .column(100, 150, 160, 70, 110, 160)
        .build();
         Scalar tdigest = t1.getColumn(0)
             .reduce(ReductionAggregation.createTDigest(1000), DType.STRUCT)) {
      assertEquals(DType.STRUCT, tdigest.getType());

      try (CloseableArray columns = CloseableArray.wrap(tdigest.getChildrenFromStructScalar())) {
        assertEquals(3, columns.size());
        try (HostColumnVector centroids = ((ColumnView) columns.get(0)).copyToHost();
             HostColumnVector min = ((ColumnView) columns.get(1)).copyToHost();
             HostColumnVector max = ((ColumnView) columns.get(2)).copyToHost()) {
          assertEquals(DType.LIST, centroids.getType());
          assertEquals(DType.FLOAT64, min.getType());
          assertEquals(DType.FLOAT64, max.getType());
          assertEquals(1, min.getRowCount());
          assertEquals(1, max.getRowCount());
          assertEquals(70, min.getDouble(0));
          assertEquals(160, max.getDouble(0));
        }
      }
    }
  }

  @Test
  void testMergeTDigestReduction() {
    StructType centroidStruct = new StructType(false,
        new BasicType(false, DType.FLOAT64), // mean
        new BasicType(false, DType.FLOAT64)); // weight

    ListType centroidList = new ListType(false, centroidStruct);

    StructType tdigestType = new StructType(false,
        centroidList,
        new BasicType(false, DType.FLOAT64), // min
        new BasicType(false, DType.FLOAT64)); // max

    try (ColumnVector tdigests = ColumnVector.fromStructs(tdigestType,
        new StructData(Arrays.asList(
            new StructData(1.0, 100.0),
            new StructData(2.0, 50.0)),
            1.0, // min
            2.0), // max
        new StructData(Arrays.asList(
            new StructData(3.0, 200.0),
            new StructData(4.0, 99.0)),
            3.0, // min
            4.0)); // max
         Scalar merged = tdigests.reduce(ReductionAggregation.mergeTDigest(1000), DType.STRUCT)) {

      assertEquals(DType.STRUCT, merged.getType());
      try (CloseableArray columns = CloseableArray.wrap(merged.getChildrenFromStructScalar())) {
        assertEquals(3, columns.size());
        try (HostColumnVector centroids = ((ColumnView) columns.get(0)).copyToHost();
             HostColumnVector min = ((ColumnView) columns.get(1)).copyToHost();
             HostColumnVector max = ((ColumnView) columns.get(2)).copyToHost()) {
          assertEquals(3, columns.size());
          assertEquals(DType.LIST, centroids.getType());
          assertEquals(DType.FLOAT64, min.getType());
          assertEquals(DType.FLOAT64, max.getType());
          assertEquals(1, min.getRowCount());
          assertEquals(1, max.getRowCount());
          assertEquals(1.0, min.getDouble(0));
          assertEquals(4.0, max.getDouble(0));
          assertEquals(1, centroids.rows);

          List list = centroids.getList(0);
          assertEquals(4, list.size());

          StructData data = (StructData) list.get(0);
          assertEquals(1.0, data.dataRecord.get(0));
          assertEquals(100.0, data.dataRecord.get(1));

          data = (StructData) list.get(1);
          assertEquals(2.0, data.dataRecord.get(0));
          assertEquals(50.0, data.dataRecord.get(1));

          data = (StructData) list.get(2);
          assertEquals(3.0, data.dataRecord.get(0));
          assertEquals(200.0, data.dataRecord.get(1));

          data = (StructData) list.get(3);
          assertEquals(4.0, data.dataRecord.get(0));
          assertEquals(99.0, data.dataRecord.get(1));
        }
      }
    }
  }

  @Test
  void testGroupByMinMaxDecimal() {
    try (Table t1 = new Table.TestBuilder()
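
For readers without a GPU at hand, the expectations in testMergeTDigestReduction can be reproduced with a small host-side sketch. This is not the cudf implementation: the class TDigestMergeSketch and its Centroid and Digest types are hypothetical names made up for illustration, and the sketch deliberately omits the centroid compression a real t-digest merge performs. It only assumes what the test itself checks, namely that with a delta of 1000 and four input centroids, all four centroids survive in order of mean, and that the merged min and max are the smallest min and largest max of the inputs.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

public class TDigestMergeSketch {
  /** Hypothetical host-side centroid: a mean and the weight it carries. */
  public static final class Centroid {
    public final double mean;
    public final double weight;

    public Centroid(double mean, double weight) {
      this.mean = mean;
      this.weight = weight;
    }
  }

  /** Hypothetical stand-in for one row of the t-digest struct: centroids, min, max. */
  public static final class Digest {
    public final List<Centroid> centroids;
    public final double min;
    public final double max;

    public Digest(List<Centroid> centroids, double min, double max) {
      this.centroids = centroids;
      this.min = min;
      this.max = max;
    }
  }

  /**
   * Simplified merge: pool all centroids, sort them by mean, and combine the
   * min/max bounds. Compression against a delta is intentionally left out.
   */
  public static Digest merge(List<Digest> digests) {
    List<Centroid> all = new ArrayList<>();
    double min = Double.POSITIVE_INFINITY;
    double max = Double.NEGATIVE_INFINITY;
    for (Digest d : digests) {
      all.addAll(d.centroids);
      min = Math.min(min, d.min);
      max = Math.max(max, d.max);
    }
    all.sort(Comparator.comparingDouble((Centroid c) -> c.mean));
    return new Digest(all, min, max);
  }

  public static void main(String[] args) {
    // Same input values as the two rows built in testMergeTDigestReduction.
    Digest first = new Digest(
        Arrays.asList(new Centroid(1.0, 100.0), new Centroid(2.0, 50.0)), 1.0, 2.0);
    Digest second = new Digest(
        Arrays.asList(new Centroid(3.0, 200.0), new Centroid(4.0, 99.0)), 3.0, 4.0);

    Digest merged = merge(Arrays.asList(first, second));
    // prints "4 centroids, min=1.0, max=4.0"
    System.out.println(merged.centroids.size() + " centroids, min=" + merged.min
        + ", max=" + merged.max);
  }
}
```

The sketch matches the test only because the inputs are tiny; with more centroids than the delta allows, a real merge would combine neighbouring centroids and the output would no longer be a plain sorted concatenation.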