Skip to content

Commit

Permalink
Refactor AggregationJni to support collectSet (#8057)
Browse files Browse the repository at this point in the history
This pull request refactored AggregationJni to support `COLLECT_SET` as a kind of Aggregation (as well as `COLLECT_LIST`).

Authors:
  - Alfred Xu (https://github.com/sperlingxx)

Approvers:
  - Nghia Truong (https://github.com/ttnghia)
  - Jason Lowe (https://github.com/jlowe)

URL: #8057
  • Loading branch information
sperlingxx authored May 5, 2021
1 parent 44f21b3 commit 4715c83
Show file tree
Hide file tree
Showing 3 changed files with 245 additions and 31 deletions.
146 changes: 128 additions & 18 deletions java/src/main/java/ai/rapids/cudf/Aggregation.java
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,12 @@ enum Kind {
NUNIQUE(15),
NTH_ELEMENT(16),
ROW_NUMBER(17),
COLLECT(18),
LEAD(19),
LAG(20),
PTX(21),
CUDA(22);
COLLECT_LIST(18),
COLLECT_SET(19),
LEAD(20),
LAG(21),
PTX(22),
CUDA(23);

final int nativeId;

Expand All @@ -77,6 +78,30 @@ public enum NullPolicy {
final boolean includeNulls;
}

/*
* This is analogous to the native 'null_equality'.
*/
public enum NullEquality {
UNEQUAL(false),
EQUAL(true);

NullEquality(boolean nullsEqual) { this.nullsEqual = nullsEqual; }

final boolean nullsEqual;
}

/*
* This is analogous to the native 'nan_equality'.
*/
public enum NaNEquality {
UNEQUAL(false),
ALL_EQUAL(true);

NaNEquality(boolean nansEqual) { this.nansEqual = nansEqual; }

final boolean nansEqual;
}

/**
* An Aggregation that only needs a kind and nothing else.
*/
Expand Down Expand Up @@ -280,17 +305,17 @@ long getDefaultOutput() {
}
}

private static final class CollectAggregation extends Aggregation {
private static final class CollectListAggregation extends Aggregation {
private final NullPolicy nullPolicy;

public CollectAggregation(NullPolicy nullPolicy) {
super(Kind.COLLECT);
public CollectListAggregation(NullPolicy nullPolicy) {
super(Kind.COLLECT_LIST);
this.nullPolicy = nullPolicy;
}

@Override
long createNativeInstance() {
return Aggregation.createCollectAgg(nullPolicy.includeNulls);
return Aggregation.createCollectListAgg(nullPolicy.includeNulls);
}

@Override
Expand All @@ -302,14 +327,55 @@ public int hashCode() {
public boolean equals(Object other) {
if (this == other) {
return true;
} else if (other instanceof CollectAggregation) {
CollectAggregation o = (CollectAggregation) other;
} else if (other instanceof CollectListAggregation) {
CollectListAggregation o = (CollectListAggregation) other;
return o.nullPolicy == this.nullPolicy;
}
return false;
}
}

private static final class CollectSetAggregation extends Aggregation {
private final NullPolicy nullPolicy;
private final NullEquality nullEquality;
private final NaNEquality nanEquality;

public CollectSetAggregation(NullPolicy nullPolicy, NullEquality nullEquality, NaNEquality nanEquality) {
super(Kind.COLLECT_SET);
this.nullPolicy = nullPolicy;
this.nullEquality = nullEquality;
this.nanEquality = nanEquality;
}

@Override
long createNativeInstance() {
return Aggregation.createCollectSetAgg(nullPolicy.includeNulls,
nullEquality.nullsEqual,
nanEquality.nansEqual);
}

@Override
public int hashCode() {
return 31 * kind.hashCode()
+ Boolean.hashCode(nullPolicy.includeNulls)
+ Boolean.hashCode(nullEquality.nullsEqual)
+ Boolean.hashCode(nanEquality.nansEqual);
}

@Override
public boolean equals(Object other) {
if (this == other) {
return true;
} else if (other instanceof CollectSetAggregation) {
CollectSetAggregation o = (CollectSetAggregation) other;
return o.nullPolicy == this.nullPolicy &&
o.nullEquality == this.nullEquality &&
o.nanEquality == this.nanEquality;
}
return false;
}
}

protected final Kind kind;

protected Aggregation(Kind kind) {
Expand Down Expand Up @@ -592,19 +658,58 @@ public static Aggregation rowNumber() {
}

/**
* Collect the values into a list. nulls will be skipped.
* Collect the values into a list. Nulls will be skipped.
* @deprecated please use collectList as instead.
*/
@Deprecated
public static Aggregation collect() {
return collect(NullPolicy.EXCLUDE);
return collectList();
}

/**
* Collect the values into a list.
* @param nullPolicy INCLUDE if nulls should be included in the aggregation or EXCLUDE if they
* should be skipped.
* @deprecated please use collectList as instead.
*
* @param nullPolicy Indicates whether to include/exclude nulls during collection.
*/
@Deprecated
public static Aggregation collect(NullPolicy nullPolicy) {
return new CollectAggregation(nullPolicy);
return collectList(nullPolicy);
}

/**
* Collect the values into a list. Nulls will be skipped.
*/
public static Aggregation collectList() {
return collectList(NullPolicy.EXCLUDE);
}

/**
* Collect the values into a list.
*
* @param nullPolicy Indicates whether to include/exclude nulls during collection.
*/
public static Aggregation collectList(NullPolicy nullPolicy) {
return new CollectListAggregation(nullPolicy);
}

/**
* Collect the values into a set. All null values will be excluded, and all nan values are regarded as
* unique instances.
*/
public static Aggregation collectSet() {
return new CollectSetAggregation(NullPolicy.EXCLUDE, NullEquality.UNEQUAL, NaNEquality.UNEQUAL);
}

/**
* Collect the values into a set.
*
* @param nullPolicy Indicates whether to include/exclude nulls during collection.
* @param nullEquality Flag to specify whether null entries within each list should be considered equal.
* @param nanEquality Flag to specify whether NaN values in floating point column should be considered equal.
*/
public static Aggregation collectSet(NullPolicy nullPolicy, NullEquality nullEquality, NaNEquality nanEquality) {
return new CollectSetAggregation(nullPolicy, nullEquality, nanEquality);
}

/**
Expand Down Expand Up @@ -675,7 +780,12 @@ public static Aggregation lag(int offset, ColumnVector defaultOutput) {
private static native long createLeadLagAgg(int kind, int offset);

/**
* Create a collect aggregation including nulls or not.
* Create a collect list aggregation including nulls or not.
*/
private static native long createCollectListAgg(boolean includeNulls);

/**
* Create a collect set aggregation.
*/
private static native long createCollectAgg(boolean includeNulls);
private static native long createCollectSetAgg(boolean includeNulls, boolean nullsEqual, boolean nansEqual);
}
42 changes: 32 additions & 10 deletions java/src/main/native/src/AggregationJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,12 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createNoParamAgg(JNIEnv
case 17: // ROW_NUMBER
ret = cudf::make_row_number_aggregation();
break;
// case 18: COLLECT
// case 19: LEAD
// case 20: LAG
// case 21: PTX
// case 22: CUDA
// case 18: COLLECT_LIST
// case 19: COLLECT_SET
// case 20: LEAD
// case 21: LAG
// case 22: PTX
// case 23: CUDA
default: throw std::logic_error("Unsupported No Parameter Aggregation Operation");
}

Expand Down Expand Up @@ -186,10 +187,10 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createLeadLagAgg(JNIEnv
std::unique_ptr<cudf::aggregation> ret;
// These numbers come from Aggregation.java and must stay in sync
switch (kind) {
case 19: // LEAD
case 20: // LEAD
ret = cudf::make_lead_aggregation(offset);
break;
case 20: // LAG
case 21: // LAG
ret = cudf::make_lag_aggregation(offset);
break;
default: throw std::logic_error("Unsupported Lead/Lag Aggregation Operation");
Expand All @@ -199,9 +200,9 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createLeadLagAgg(JNIEnv
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createCollectAgg(JNIEnv *env,
jclass class_object,
jboolean include_nulls) {
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createCollectListAgg(JNIEnv *env,
jclass class_object,
jboolean include_nulls) {
try {
cudf::jni::auto_set_device(env);
cudf::null_policy policy =
Expand All @@ -212,4 +213,25 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createCollectAgg(JNIEnv
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createCollectSetAgg(JNIEnv *env,
jclass class_object,
jboolean include_nulls,
jboolean nulls_equal,
jboolean nans_equal) {
try {
cudf::jni::auto_set_device(env);
cudf::null_policy null_policy =
include_nulls ? cudf::null_policy::INCLUDE : cudf::null_policy::EXCLUDE;
cudf::null_equality null_equality =
nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL;
cudf::nan_equality nan_equality =
nans_equal ? cudf::nan_equality::ALL_EQUAL : cudf::nan_equality::UNEQUAL;
std::unique_ptr<cudf::aggregation> ret = cudf::make_collect_set_aggregation(null_policy,
null_equality,
nan_equality);
return reinterpret_cast<jlong>(ret.release());
}
CATCH_STD(env, 0);
}

} // extern "C"
88 changes: 85 additions & 3 deletions java/src/test/java/ai/rapids/cudf/TableTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -2945,9 +2945,9 @@ void testWindowingRowNumber() {
}

@Test
void testWindowingCollect() {
Aggregation aggCollectWithNulls = Aggregation.collect(Aggregation.NullPolicy.INCLUDE);
Aggregation aggCollect = Aggregation.collect();
void testWindowingCollectList() {
Aggregation aggCollectWithNulls = Aggregation.collectList(Aggregation.NullPolicy.INCLUDE);
Aggregation aggCollect = Aggregation.collectList();
WindowOptions winOpts = WindowOptions.builder()
.minPeriods(1)
.window(2, 1).build();
Expand Down Expand Up @@ -4403,6 +4403,88 @@ void testGroupByContiguousSplitGroups() {
}
}

@Test
void testGroupByCollectListIncludeNulls() {
try (Table input = new Table.TestBuilder()
.column(1, 1, 1, 1, 2, 2, 2, 2, 3, 4)
.column(null, 13, null, 12, 14, null, 15, null, null, 0)
.build();
Table expected = new Table.TestBuilder()
.column(1, 2, 3, 4)
.column(new ListType(false, new BasicType(true, DType.INT32)),
Arrays.asList(null, 13, null, 12),
Arrays.asList(14, null, 15, null),
Arrays.asList((Integer) null),
Arrays.asList(0))
.build();
Table found = input.groupBy(0).aggregate(
Aggregation.collectList(Aggregation.NullPolicy.INCLUDE).onColumn(1))) {
assertTablesAreEqual(expected, found);
}
}

@Test
void testGroupByCollectSetIncludeNulls() {
// test with null unequal and nan unequal
Aggregation collectSet = Aggregation.collectSet(Aggregation.NullPolicy.INCLUDE,
Aggregation.NullEquality.UNEQUAL, Aggregation.NaNEquality.UNEQUAL);
try (Table input = new Table.TestBuilder()
.column(1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4)
.column(null, 13, null, 13, 14, null, 15, null, 4, 1, 1, 4, 0, 0, 0, 0)
.build();
Table expected = new Table.TestBuilder()
.column(1, 2, 3, 4)
.column(new ListType(false, new BasicType(true, DType.INT32)),
Arrays.asList(13, null, null), Arrays.asList(14, 15, null, null),
Arrays.asList(1, 4), Arrays.asList(0))
.build();
Table found = input.groupBy(0).aggregate(collectSet.onColumn(1))) {
assertTablesAreEqual(expected, found);
}
// test with null equal and nan unequal
collectSet = Aggregation.collectSet(Aggregation.NullPolicy.INCLUDE,
Aggregation.NullEquality.EQUAL, Aggregation.NaNEquality.UNEQUAL);
try (Table input = new Table.TestBuilder()
.column(1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4)
.column(null, 13.0, null, 13.0,
14.1, Double.NaN, 13.9, Double.NaN,
Double.NaN, null, 1.0, null,
null, null, null, null)
.build();
Table expected = new Table.TestBuilder()
.column(1, 2, 3, 4)
.column(new ListType(false, new BasicType(true, DType.FLOAT64)),
Arrays.asList(13.0, null),
Arrays.asList(13.9, 14.1, Double.NaN, Double.NaN),
Arrays.asList(1.0, Double.NaN, null),
Arrays.asList((Integer) null))
.build();
Table found = input.groupBy(0).aggregate(collectSet.onColumn(1))) {
assertTablesAreEqual(expected, found);
}
// test with null equal and nan equal
collectSet = Aggregation.collectSet(Aggregation.NullPolicy.INCLUDE,
Aggregation.NullEquality.EQUAL, Aggregation.NaNEquality.ALL_EQUAL);
try (Table input = new Table.TestBuilder()
.column(1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4)
.column(null, 13.0, null, 13.0,
14.1, Double.NaN, 13.9, Double.NaN,
0.0, 0.0, 0.00, 0.0,
Double.NaN, Double.NaN, null, null)
.build();
Table expected = new Table.TestBuilder()
.column(1, 2, 3, 4)
.column(new ListType(false, new BasicType(true, DType.FLOAT64)),
Arrays.asList(13.0, null),
Arrays.asList(13.9, 14.1, Double.NaN),
Arrays.asList(0.0),
Arrays.asList(Double.NaN, (Integer) null))
.build();
Table found = input.groupBy(0).aggregate(collectSet.onColumn(1))) {
assertTablesAreEqual(expected, found);
}
}

@Test
void testRowBitCount() {
try (Table t = new Table.TestBuilder()
Expand Down

0 comments on commit 4715c83

Please sign in to comment.