Skip to content

Commit

Permalink
Support MERGE_LISTS and MERGE_SETS in Java package (#8516)
Browse files Browse the repository at this point in the history
Closes #8445

This PR provides the Java bindings for the MERGE_LISTS and MERGE_SETS aggregations introduced in #8436.

Authors:
  - Alfred Xu (https://github.com/sperlingxx)

Approvers:
  - Nghia Truong (https://github.com/ttnghia)
  - Robert (Bobby) Evans (https://github.com/revans2)

URL: #8516
  • Loading branch information
sperlingxx authored Jun 24, 2021
1 parent 086be4a commit a73d3b3
Show file tree
Hide file tree
Showing 3 changed files with 201 additions and 9 deletions.
83 changes: 78 additions & 5 deletions java/src/main/java/ai/rapids/cudf/Aggregation.java
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,12 @@ enum Kind {
ROW_NUMBER(17),
COLLECT_LIST(18),
COLLECT_SET(19),
LEAD(20),
LAG(21),
PTX(22),
CUDA(23);
MERGE_LISTS(20),
MERGE_SETS(21),
LEAD(22),
LAG(23),
PTX(24),
CUDA(25);

final int nativeId;

Expand Down Expand Up @@ -342,6 +344,40 @@ public boolean equals(Object other) {
}
}

/**
 * Aggregation of kind {@code MERGE_SETS}, parameterised by how null entries and NaN values
 * are compared against each other.
 * Instances are created through the {@code mergeSets} factory methods of {@link Aggregation}.
 */
public static final class MergeSetsAggregation extends Aggregation {
  private final NullEquality nullEquality;
  private final NaNEquality nanEquality;

  private MergeSetsAggregation(NullEquality nullEquality, NaNEquality nanEquality) {
    super(Kind.MERGE_SETS);
    this.nullEquality = nullEquality;
    this.nanEquality = nanEquality;
  }

  /**
   * Builds the native counterpart of this aggregation, forwarding both equality flags.
   *
   * @return the value returned by the native {@code createMergeSetsAgg} call.
   */
  @Override
  long createNativeInstance() {
    return Aggregation.createMergeSetsAgg(nullEquality.nullsEqual, nanEquality.nansEqual);
  }

  /**
   * Hash consistent with {@link #equals(Object)}. Each field is folded in with its own
   * multiplier so that swapping the two flags (for example nulls-equal/NaNs-unequal versus
   * nulls-unequal/NaNs-equal) does not produce the same hash code.
   */
  @Override
  public int hashCode() {
    int result = kind.hashCode();
    result = 31 * result + Boolean.hashCode(nullEquality.nullsEqual);
    result = 31 * result + Boolean.hashCode(nanEquality.nansEqual);
    return result;
  }

  /**
   * Two merge-sets aggregations are equal when both equality settings match.
   */
  @Override
  public boolean equals(Object other) {
    if (this == other) {
      return true;
    } else if (other instanceof MergeSetsAggregation) {
      MergeSetsAggregation o = (MergeSetsAggregation) other;
      return o.nullEquality == this.nullEquality && o.nanEquality == this.nanEquality;
    }
    return false;
  }
}

protected final Kind kind;

protected Aggregation(Kind kind) {
Expand Down Expand Up @@ -713,7 +749,7 @@ public static CollectListAggregation collectList(NullPolicy nullPolicy) {
* unique instances.
*/
public static CollectSetAggregation collectSet() {
  // Delegate to the fully parameterised factory so the defaults live in one place:
  // NullPolicy.EXCLUDE, with nulls and NaNs each compared as UNEQUAL.
  // (The stale direct-constructor return left over from the diff made this return unreachable.)
  return collectSet(NullPolicy.EXCLUDE, NullEquality.UNEQUAL, NaNEquality.UNEQUAL);
}

/**
Expand All @@ -727,6 +763,38 @@ public static CollectSetAggregation collectSet(NullPolicy nullPolicy, NullEquali
return new CollectSetAggregation(nullPolicy, nullEquality, nanEquality);
}

/**
 * Parameterless aggregation of kind {@code MERGE_LISTS}.
 * Obtain instances through {@code Aggregation.mergeLists()}; the constructor is private.
 */
public static final class MergeListsAggregation extends NoParamAggregation {

  /** Registers this aggregation under {@code Kind.MERGE_LISTS}. */
  private MergeListsAggregation() {
    super(Kind.MERGE_LISTS);
  }
}

/**
 * Create an aggregation that merges the partial lists produced by multiple
 * CollectListAggregations.
 * NOTICE: The lists being merged must NOT themselves be null (no null list elements),
 * although an individual list may contain null entries.
 *
 * @return a new MergeListsAggregation.
 */
public static MergeListsAggregation mergeLists() {
  final MergeListsAggregation mergeAgg = new MergeListsAggregation();
  return mergeAgg;
}

/**
 * Create an aggregation that merges the partial sets produced by multiple
 * CollectSetAggregations, using {@code NullEquality.UNEQUAL} and {@code NaNEquality.UNEQUAL}:
 * each null/NaN value is regarded as a unique instance.
 *
 * @return a new MergeSetsAggregation with the default equality settings.
 */
public static MergeSetsAggregation mergeSets() {
  final NullEquality defaultNullEquality = NullEquality.UNEQUAL;
  final NaNEquality defaultNanEquality = NaNEquality.UNEQUAL;
  return mergeSets(defaultNullEquality, defaultNanEquality);
}

/**
 * Create an aggregation that merges the partial sets produced by multiple
 * CollectSetAggregations.
 *
 * @param nullEquality Whether null entries within each list should be considered equal.
 * @param nanEquality  Whether NaN values in a floating point column should be considered equal.
 * @return a new MergeSetsAggregation carrying the given equality settings.
 */
public static MergeSetsAggregation mergeSets(NullEquality nullEquality, NaNEquality nanEquality) {
  final MergeSetsAggregation mergeAgg = new MergeSetsAggregation(nullEquality, nanEquality);
  return mergeAgg;
}

public static class LeadAggregation extends LeadLagAggregation
implements RollingAggregation<LeadAggregation> {
private LeadAggregation(int offset, ColumnVector defaultOutput) {
Expand Down Expand Up @@ -818,4 +886,9 @@ public static LagAggregation lag(int offset, ColumnVector defaultOutput) {
* Create a collect set aggregation.
*/
private static native long createCollectSetAgg(boolean includeNulls, boolean nullsEqual, boolean nansEqual);

/**
 * Create a merge sets aggregation.
 *
 * @param nullsEqual true if null entries should be considered equal to each other.
 * @param nansEqual true if NaN values should be considered equal to each other.
 * @return the long produced by the native side for the new aggregation
 *         (presumably a native pointer owned by the caller; confirm against the JNI code).
 */
private static native long createMergeSetsAgg(boolean nullsEqual, boolean nansEqual);
}
30 changes: 26 additions & 4 deletions java/src/main/native/src/AggregationJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,15 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createNoParamAgg(JNIEnv
break;
// case 18: COLLECT_LIST
// case 19: COLLECT_SET
// case 20: LEAD
// case 21: LAG
// case 22: PTX
// case 23: CUDA
// case 20: MERGE_LISTS
case 20:
ret = cudf::make_merge_lists_aggregation();
break;
// case 21: MERGE_SETS
// case 22: LEAD
// case 23: LAG
// case 24: PTX
// case 25: CUDA
default: throw std::logic_error("Unsupported No Parameter Aggregation Operation");
}

Expand Down Expand Up @@ -234,4 +239,21 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createCollectSetAgg(JNIE
CATCH_STD(env, 0);
}

// JNI entry point backing Aggregation.createMergeSetsAgg(boolean, boolean).
// Translates the two Java booleans into the cudf equality enums, builds a MERGE_SETS
// aggregation and returns the released pointer as a jlong. Exceptions are routed through
// CATCH_STD with 0 as its second argument.
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createMergeSetsAgg(JNIEnv *env,
                                                                           jclass class_object,
                                                                           jboolean nulls_equal,
                                                                           jboolean nans_equal) {
  try {
    cudf::jni::auto_set_device(env);

    // Map each flag onto the corresponding cudf enum value.
    auto const null_eq = nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL;
    auto const nan_eq = nans_equal ? cudf::nan_equality::ALL_EQUAL : cudf::nan_equality::UNEQUAL;

    std::unique_ptr<cudf::aggregation> agg = cudf::make_merge_sets_aggregation(null_eq, nan_eq);

    // Give up ownership: the raw pointer travels back to Java as a jlong.
    return reinterpret_cast<jlong>(agg.release());
  }
  CATCH_STD(env, 0);
}

} // extern "C"
97 changes: 97 additions & 0 deletions java/src/test/java/ai/rapids/cudf/TableTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -5103,6 +5103,54 @@ void testGroupByCollectListIncludeNulls() {
}
}

/**
 * Groups a ten-row table by its first column and applies Aggregation.mergeLists() to a
 * list-of-INT32 column and to a list-of-struct(INT32, STRING) column, then compares each
 * result with a hand-built expected table.
 */
@Test
void testGroupByMergeLists() {
ListType listOfInts = new ListType(false, new BasicType(false, DType.INT32));
ListType listOfStructs = new ListType(false, new StructType(false,
new BasicType(false, DType.INT32), new BasicType(false, DType.STRING)));
try (Table input = new Table.TestBuilder()
// Column 0: group keys. Keys 1 and 2 have four rows each, keys 3 and 4 one row each.
.column(1, 1, 1, 1, 2, 2, 2, 2, 3, 4)
// Column 1: one list of ints per row.
.column(listOfInts,
Arrays.asList(1, 2), Arrays.asList(3), Arrays.asList(7, 8), Arrays.asList(4, 5, 6),
Arrays.asList(8, 9), Arrays.asList(8, 9, 10), Arrays.asList(10, 11), Arrays.asList(11, 12),
Arrays.asList(13, 13), Arrays.asList(14, 15, 15))
// Column 2: one list of structs per row; some rows hold an empty list.
.column(listOfStructs,
Arrays.asList(new StructData(1, "s1"), new StructData(2, "s2")),
Arrays.asList(new StructData(2, "s3"), new StructData(3, "s4")),
Arrays.asList(new StructData(2, "s2")),
Arrays.asList(),
Arrays.asList(new StructData(11, "s11")),
Arrays.asList(new StructData(22, "s22"), new StructData(33, "s33")),
Arrays.asList(),
Arrays.asList(new StructData(22, "s22"), new StructData(33, "s33"), new StructData(44, "s44")),
Arrays.asList(new StructData(333, "s333"), new StructData(222, "s222"), new StructData(111, "s111")),
Arrays.asList(new StructData(222, "s222"), new StructData(444, "s444")))
.build();
// Expected ints: each group's lists concatenated in input row order, duplicates retained.
Table expectedListOfInts = new Table.TestBuilder()
.column(1, 2, 3, 4)
.column(listOfInts,
Arrays.asList(1, 2, 3, 7 ,8, 4, 5, 6),
Arrays.asList(8, 9, 8, 9, 10, 10, 11, 11, 12),
Arrays.asList(13, 13),
Arrays.asList(14, 15, 15))
.build();
// Expected structs: same concatenation; the empty input lists add no entries.
Table expectedListOfStructs = new Table.TestBuilder()
.column(1, 2, 3, 4)
.column(listOfStructs,
Arrays.asList(new StructData(1, "s1"), new StructData(2, "s2"),
new StructData(2, "s3"), new StructData(3, "s4"), new StructData(2, "s2")),
Arrays.asList(new StructData(11, "s11"), new StructData(22, "s22"), new StructData(33, "s33"),
new StructData(22, "s22"), new StructData(33, "s33"), new StructData(44, "s44")),
Arrays.asList(new StructData(333, "s333"), new StructData(222, "s222"), new StructData(111, "s111")),
Arrays.asList(new StructData(222, "s222"), new StructData(444, "s444")))
.build();
// Run the aggregation separately on column 1 and column 2, grouping by column 0.
Table retListOfInts = input.groupBy(0).aggregate(Aggregation.mergeLists().onColumn(1));
Table retListOfStructs = input.groupBy(0).aggregate(Aggregation.mergeLists().onColumn(2))) {
assertTablesAreEqual(expectedListOfInts, retListOfInts);
assertTablesAreEqual(expectedListOfStructs, retListOfStructs);
}
}

@Test
void testGroupByCollectSetIncludeNulls() {
// test with null unequal and nan unequal
Expand Down Expand Up @@ -5165,6 +5213,55 @@ void testGroupByCollectSetIncludeNulls() {
}
}

/**
 * Groups a ten-row table by its first column and applies Aggregation.mergeSets() to a
 * list-of-INT32 column and a list-of-FLOAT64 column. The double column is checked twice:
 * once with the no-argument mergeSets() and once with NaNEquality.ALL_EQUAL.
 */
@Test
void testGroupByMergeSets() {
ListType listOfInts = new ListType(false, new BasicType(false, DType.INT32));
ListType listOfDoubles = new ListType(false, new BasicType(false, DType.FLOAT64));
try (Table input = new Table.TestBuilder()
// Column 0: group keys. Keys 1 and 2 have four rows each, keys 3 and 4 one row each.
.column(1, 1, 1, 1, 2, 2, 2, 2, 3, 4)
// Column 1: lists of ints with values repeated both within and across rows of a group.
.column(listOfInts,
Arrays.asList(1, 2), Arrays.asList(3), Arrays.asList(7, 8), Arrays.asList(4, 5, 6),
Arrays.asList(8, 9), Arrays.asList(8, 9, 10), Arrays.asList(10, 11), Arrays.asList(11, 12),
Arrays.asList(13, 13), Arrays.asList(14, 15, 15))
// Column 2: lists of doubles, including NaN values and empty lists.
.column(listOfDoubles,
Arrays.asList(Double.NaN, 1.2), Arrays.asList(), Arrays.asList(Double.NaN), Arrays.asList(-3e10),
Arrays.asList(1.1, 2.2, 3.3), Arrays.asList(3.3, 2.2), Arrays.asList(), Arrays.asList(),
Arrays.asList(1e3, Double.NaN, 1e-3, Double.NaN), Arrays.asList())
.build();
// Expected ints: the distinct values of each group, written in ascending order.
Table expectedListOfInts = new Table.TestBuilder()
.column(1, 2, 3, 4)
.column(listOfInts,
Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8),
Arrays.asList(8, 9, 10, 11, 12),
Arrays.asList(13),
Arrays.asList(14, 15))
.build();
// Expected doubles for the no-argument mergeSets(): every NaN is kept as its own entry.
Table expectedListOfDoubles = new Table.TestBuilder()
.column(1, 2, 3, 4)
.column(listOfDoubles,
Arrays.asList(-3e10, 1.2, Double.NaN, Double.NaN),
Arrays.asList(1.1, 2.2, 3.3),
Arrays.asList(1e-3, 1e3, Double.NaN, Double.NaN),
Arrays.asList())
.build();
// Expected doubles with NaNEquality.ALL_EQUAL: the NaNs of a group collapse into one entry.
Table expectedListOfDoublesNaNEq = new Table.TestBuilder()
.column(1, 2, 3, 4)
.column(listOfDoubles,
Arrays.asList(-3e10, 1.2, Double.NaN),
Arrays.asList(1.1, 2.2, 3.3),
Arrays.asList(1e-3, 1e3, Double.NaN),
Arrays.asList())
.build();
// Run the aggregation on column 1, then on column 2 with each NaN setting.
Table retListOfInts = input.groupBy(0).aggregate(Aggregation.mergeSets().onColumn(1));
Table retListOfDoubles = input.groupBy(0).aggregate(Aggregation.mergeSets().onColumn(2));
Table retListOfDoublesNaNEq = input.groupBy(0).aggregate(
Aggregation.mergeSets(NullEquality.UNEQUAL, NaNEquality.ALL_EQUAL).onColumn(2))) {
assertTablesAreEqual(expectedListOfInts, retListOfInts);
assertTablesAreEqual(expectedListOfDoubles, retListOfDoubles);
assertTablesAreEqual(expectedListOfDoublesNaNEq, retListOfDoublesNaNEq);
}
}

@Test
void testRowBitCount() {
try (Table t = new Table.TestBuilder()
Expand Down

0 comments on commit a73d3b3

Please sign in to comment.