Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support MERGE_LISTS and MERGE_SETS in Java package [skip ci] #8516

Merged
merged 3 commits into from
Jun 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 78 additions & 5 deletions java/src/main/java/ai/rapids/cudf/Aggregation.java
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,12 @@ enum Kind {
ROW_NUMBER(17),
COLLECT_LIST(18),
COLLECT_SET(19),
LEAD(20),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@revans2 I assume that your problem is due to this.

Copy link
Contributor

@revans2 revans2 Jun 24, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No my problem is that @sperlingxx didn't even run the unit tests or he would have found this and there would have been no issue. I will post a patch shortly

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we arbitrarily define the value for aggregation enums, which is different from the enums defined in libcudf. I was talking about this before (probably with Jason). We should have some better way to automate this. Otherwise, we will continue to have issues with this in the future.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These just happen to match the values in libcudf. But in this case we never cast them so it does not have to match. This is the one place that we do it correctly. If we did cast them, then your patch that changed the values would have broken the java build, not this patch.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, it's my mistake! I apologized for my carelessness.

LAG(21),
PTX(22),
CUDA(23);
MERGE_LISTS(20),
MERGE_SETS(21),
LEAD(22),
LAG(23),
PTX(24),
CUDA(25);

final int nativeId;

Expand Down Expand Up @@ -342,6 +344,40 @@ public boolean equals(Object other) {
}
}

public static final class MergeSetsAggregation extends Aggregation {
private final NullEquality nullEquality;
private final NaNEquality nanEquality;

private MergeSetsAggregation(NullEquality nullEquality, NaNEquality nanEquality) {
super(Kind.MERGE_SETS);
this.nullEquality = nullEquality;
this.nanEquality = nanEquality;
}

@Override
long createNativeInstance() {
return Aggregation.createMergeSetsAgg(nullEquality.nullsEqual, nanEquality.nansEqual);
}

@Override
public int hashCode() {
return 31 * kind.hashCode()
+ Boolean.hashCode(nullEquality.nullsEqual)
+ Boolean.hashCode(nanEquality.nansEqual);
}

@Override
public boolean equals(Object other) {
if (this == other) {
return true;
} else if (other instanceof MergeSetsAggregation) {
MergeSetsAggregation o = (MergeSetsAggregation) other;
return o.nullEquality == this.nullEquality && o.nanEquality == this.nanEquality;
}
return false;
}
}

protected final Kind kind;

protected Aggregation(Kind kind) {
Expand Down Expand Up @@ -713,7 +749,7 @@ public static CollectListAggregation collectList(NullPolicy nullPolicy) {
* unique instances.
*/
public static CollectSetAggregation collectSet() {
return new CollectSetAggregation(NullPolicy.EXCLUDE, NullEquality.UNEQUAL, NaNEquality.UNEQUAL);
return collectSet(NullPolicy.EXCLUDE, NullEquality.UNEQUAL, NaNEquality.UNEQUAL);
}

/**
Expand All @@ -727,6 +763,38 @@ public static CollectSetAggregation collectSet(NullPolicy nullPolicy, NullEquali
return new CollectSetAggregation(nullPolicy, nullEquality, nanEquality);
}

public static final class MergeListsAggregation extends NoParamAggregation {
private MergeListsAggregation() {
super(Kind.MERGE_LISTS);
}
}

/**
* Merge the partial lists produced by multiple CollectListAggregations.
* NOTICE: The partial lists to be merged should NOT include any null list element (but can include null list entries).
*/
public static MergeListsAggregation mergeLists() {
return new MergeListsAggregation();
}

/**
* Merge the partial sets produced by multiple CollectSetAggregations. Each null/nan value will be regarded as
* a unique instance.
*/
public static MergeSetsAggregation mergeSets() {
return mergeSets(NullEquality.UNEQUAL, NaNEquality.UNEQUAL);
}

/**
* Merge the partial sets produced by multiple CollectSetAggregations.
*
* @param nullEquality Flag to specify whether null entries within each list should be considered equal.
* @param nanEquality Flag to specify whether NaN values in floating point column should be considered equal.
*/
public static MergeSetsAggregation mergeSets(NullEquality nullEquality, NaNEquality nanEquality) {
return new MergeSetsAggregation(nullEquality, nanEquality);
}

public static class LeadAggregation extends LeadLagAggregation
implements RollingAggregation<LeadAggregation> {
private LeadAggregation(int offset, ColumnVector defaultOutput) {
Expand Down Expand Up @@ -818,4 +886,9 @@ public static LagAggregation lag(int offset, ColumnVector defaultOutput) {
* Create a collect set aggregation.
*/
private static native long createCollectSetAgg(boolean includeNulls, boolean nullsEqual, boolean nansEqual);

/**
* Create a merge sets aggregation.
*/
private static native long createMergeSetsAgg(boolean nullsEqual, boolean nansEqual);
}
30 changes: 26 additions & 4 deletions java/src/main/native/src/AggregationJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,15 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createNoParamAgg(JNIEnv
break;
// case 18: COLLECT_LIST
// case 19: COLLECT_SET
// case 20: LEAD
// case 21: LAG
// case 22: PTX
// case 23: CUDA
// case 20: MERGE_LISTS
case 20:
ret = cudf::make_merge_lists_aggregation();
break;
// case 21: MERGE_SETS
// case 22: LEAD
// case 23: LAG
// case 24: PTX
// case 25: CUDA
default: throw std::logic_error("Unsupported No Parameter Aggregation Operation");
}

Expand Down Expand Up @@ -234,4 +239,21 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createCollectSetAgg(JNIE
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createMergeSetsAgg(JNIEnv *env,
jclass class_object,
jboolean nulls_equal,
jboolean nans_equal) {
try {
cudf::jni::auto_set_device(env);
cudf::null_equality null_equality =
nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL;
cudf::nan_equality nan_equality =
nans_equal ? cudf::nan_equality::ALL_EQUAL : cudf::nan_equality::UNEQUAL;
std::unique_ptr<cudf::aggregation> ret = cudf::make_merge_sets_aggregation(null_equality,
nan_equality);
return reinterpret_cast<jlong>(ret.release());
}
CATCH_STD(env, 0);
}

} // extern "C"
97 changes: 97 additions & 0 deletions java/src/test/java/ai/rapids/cudf/TableTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -5103,6 +5103,54 @@ void testGroupByCollectListIncludeNulls() {
}
}

@Test
void testGroupByMergeLists() {
ListType listOfInts = new ListType(false, new BasicType(false, DType.INT32));
ListType listOfStructs = new ListType(false, new StructType(false,
new BasicType(false, DType.INT32), new BasicType(false, DType.STRING)));
try (Table input = new Table.TestBuilder()
.column(1, 1, 1, 1, 2, 2, 2, 2, 3, 4)
.column(listOfInts,
Arrays.asList(1, 2), Arrays.asList(3), Arrays.asList(7, 8), Arrays.asList(4, 5, 6),
Arrays.asList(8, 9), Arrays.asList(8, 9, 10), Arrays.asList(10, 11), Arrays.asList(11, 12),
Arrays.asList(13, 13), Arrays.asList(14, 15, 15))
.column(listOfStructs,
Arrays.asList(new StructData(1, "s1"), new StructData(2, "s2")),
Arrays.asList(new StructData(2, "s3"), new StructData(3, "s4")),
Arrays.asList(new StructData(2, "s2")),
Arrays.asList(),
Arrays.asList(new StructData(11, "s11")),
Arrays.asList(new StructData(22, "s22"), new StructData(33, "s33")),
Arrays.asList(),
Arrays.asList(new StructData(22, "s22"), new StructData(33, "s33"), new StructData(44, "s44")),
Arrays.asList(new StructData(333, "s333"), new StructData(222, "s222"), new StructData(111, "s111")),
Arrays.asList(new StructData(222, "s222"), new StructData(444, "s444")))
.build();
Table expectedListOfInts = new Table.TestBuilder()
.column(1, 2, 3, 4)
.column(listOfInts,
Arrays.asList(1, 2, 3, 7 ,8, 4, 5, 6),
Arrays.asList(8, 9, 8, 9, 10, 10, 11, 11, 12),
Arrays.asList(13, 13),
Arrays.asList(14, 15, 15))
.build();
Table expectedListOfStructs = new Table.TestBuilder()
.column(1, 2, 3, 4)
.column(listOfStructs,
Arrays.asList(new StructData(1, "s1"), new StructData(2, "s2"),
new StructData(2, "s3"), new StructData(3, "s4"), new StructData(2, "s2")),
Arrays.asList(new StructData(11, "s11"), new StructData(22, "s22"), new StructData(33, "s33"),
new StructData(22, "s22"), new StructData(33, "s33"), new StructData(44, "s44")),
Arrays.asList(new StructData(333, "s333"), new StructData(222, "s222"), new StructData(111, "s111")),
Arrays.asList(new StructData(222, "s222"), new StructData(444, "s444")))
.build();
Table retListOfInts = input.groupBy(0).aggregate(Aggregation.mergeLists().onColumn(1));
Table retListOfStructs = input.groupBy(0).aggregate(Aggregation.mergeLists().onColumn(2))) {
assertTablesAreEqual(expectedListOfInts, retListOfInts);
assertTablesAreEqual(expectedListOfStructs, retListOfStructs);
}
}

@Test
void testGroupByCollectSetIncludeNulls() {
// test with null unequal and nan unequal
Expand Down Expand Up @@ -5165,6 +5213,55 @@ void testGroupByCollectSetIncludeNulls() {
}
}

@Test
void testGroupByMergeSets() {
ListType listOfInts = new ListType(false, new BasicType(false, DType.INT32));
ListType listOfDoubles = new ListType(false, new BasicType(false, DType.FLOAT64));
try (Table input = new Table.TestBuilder()
.column(1, 1, 1, 1, 2, 2, 2, 2, 3, 4)
.column(listOfInts,
Arrays.asList(1, 2), Arrays.asList(3), Arrays.asList(7, 8), Arrays.asList(4, 5, 6),
Arrays.asList(8, 9), Arrays.asList(8, 9, 10), Arrays.asList(10, 11), Arrays.asList(11, 12),
Arrays.asList(13, 13), Arrays.asList(14, 15, 15))
.column(listOfDoubles,
Arrays.asList(Double.NaN, 1.2), Arrays.asList(), Arrays.asList(Double.NaN), Arrays.asList(-3e10),
Arrays.asList(1.1, 2.2, 3.3), Arrays.asList(3.3, 2.2), Arrays.asList(), Arrays.asList(),
Arrays.asList(1e3, Double.NaN, 1e-3, Double.NaN), Arrays.asList())
.build();
Table expectedListOfInts = new Table.TestBuilder()
.column(1, 2, 3, 4)
.column(listOfInts,
Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8),
Arrays.asList(8, 9, 10, 11, 12),
Arrays.asList(13),
Arrays.asList(14, 15))
.build();
Table expectedListOfDoubles = new Table.TestBuilder()
.column(1, 2, 3, 4)
.column(listOfDoubles,
Arrays.asList(-3e10, 1.2, Double.NaN, Double.NaN),
Arrays.asList(1.1, 2.2, 3.3),
Arrays.asList(1e-3, 1e3, Double.NaN, Double.NaN),
Arrays.asList())
.build();
Table expectedListOfDoublesNaNEq = new Table.TestBuilder()
.column(1, 2, 3, 4)
.column(listOfDoubles,
Arrays.asList(-3e10, 1.2, Double.NaN),
Arrays.asList(1.1, 2.2, 3.3),
Arrays.asList(1e-3, 1e3, Double.NaN),
Arrays.asList())
.build();
Table retListOfInts = input.groupBy(0).aggregate(Aggregation.mergeSets().onColumn(1));
Table retListOfDoubles = input.groupBy(0).aggregate(Aggregation.mergeSets().onColumn(2));
Table retListOfDoublesNaNEq = input.groupBy(0).aggregate(
Aggregation.mergeSets(NullEquality.UNEQUAL, NaNEquality.ALL_EQUAL).onColumn(2))) {
assertTablesAreEqual(expectedListOfInts, retListOfInts);
assertTablesAreEqual(expectedListOfDoubles, retListOfDoubles);
assertTablesAreEqual(expectedListOfDoublesNaNEq, retListOfDoublesNaNEq);
}
}

@Test
void testRowBitCount() {
try (Table t = new Table.TestBuilder()
Expand Down