Skip to content

Commit

Permalink
JNI: Support skipping nulls for collect aggregation (rapidsai#7457)
Browse files Browse the repository at this point in the history
This PR is to support skipping nulls for `collect ` aggregation in JVM by creating a new class `CollectAggregation` who accepts a `NullPolicy ` argument indicating whether to include nulls. 

Skipping nulls has already been supported by `collect ` aggregation with rolling in native (rapidsai#7264), so this PR just exposes the feaure in JVM.

This PR also introduces `NullPolicy ` and updates the related aggregates.

Signed-off-by: firestarman <[email protected]>

Authors:
  - Liangcai Li (@firestarman)

Approvers:
  - Robert (Bobby) Evans (@revans2)
  - MithunR (@mythrocks)

URL: rapidsai#7457
  • Loading branch information
firestarman authored and hyperbolic2346 committed Mar 23, 2021
1 parent d77a393 commit 94dd756
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 28 deletions.
134 changes: 113 additions & 21 deletions java/src/main/java/ai/rapids/cudf/Aggregation.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
*
* Copyright (c) 2020, NVIDIA CORPORATION.
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -65,6 +65,18 @@ enum Kind {
Kind(int nativeId) {this.nativeId = nativeId;}
}

/*
* This is analogous to the native 'null_policy'.
*/
public enum NullPolicy {
EXCLUDE(false),
INCLUDE(true);

NullPolicy(boolean includeNulls) { this.includeNulls = includeNulls; }

final boolean includeNulls;
}

/**
* An Aggregation that only needs a kind and nothing else.
*/
Expand Down Expand Up @@ -97,22 +109,22 @@ public boolean equals(Object other) {

private static final class NthAggregation extends Aggregation {
private final int offset;
private final boolean includeNulls;
private final NullPolicy nullPolicy;

public NthAggregation(int offset, boolean includeNulls) {
public NthAggregation(int offset, NullPolicy nullPolicy) {
super(Kind.NTH_ELEMENT);
this.offset = offset;
this.includeNulls = includeNulls;
this.nullPolicy = nullPolicy;
}

@Override
long createNativeInstance() {
return Aggregation.createNthAgg(offset, includeNulls);
return Aggregation.createNthAgg(offset, nullPolicy.includeNulls);
}

@Override
public int hashCode() {
return 31 * offset + Boolean.hashCode(includeNulls);
return 31 * offset + nullPolicy.hashCode();
}

@Override
Expand All @@ -121,7 +133,7 @@ public boolean equals(Object other) {
return true;
} else if (other instanceof NthAggregation) {
NthAggregation o = (NthAggregation) other;
return o.offset == this.offset && o.includeNulls == this.includeNulls;
return o.offset == this.offset && o.nullPolicy == this.nullPolicy;
}
return false;
}
Expand Down Expand Up @@ -158,21 +170,21 @@ public boolean equals(Object other) {
}

private static final class CountLikeAggregation extends Aggregation {
private final boolean includeNulls;
private final NullPolicy nullPolicy;

public CountLikeAggregation(Kind kind, boolean includeNulls) {
public CountLikeAggregation(Kind kind, NullPolicy nullPolicy) {
super(kind);
this.includeNulls = includeNulls;
this.nullPolicy = nullPolicy;
}

@Override
long createNativeInstance() {
return Aggregation.createCountLikeAgg(kind.nativeId, includeNulls);
return Aggregation.createCountLikeAgg(kind.nativeId, nullPolicy.includeNulls);
}

@Override
public int hashCode() {
return 31 * kind.hashCode() + Boolean.hashCode(includeNulls);
return 31 * kind.hashCode() + nullPolicy.hashCode();
}

@Override
Expand All @@ -181,7 +193,7 @@ public boolean equals(Object other) {
return true;
} else if (other instanceof CountLikeAggregation) {
CountLikeAggregation o = (CountLikeAggregation) other;
return o.includeNulls == this.includeNulls;
return o.nullPolicy == this.nullPolicy;
}
return false;
}
Expand Down Expand Up @@ -268,6 +280,36 @@ long getDefaultOutput() {
}
}

private static final class CollectAggregation extends Aggregation {
private final NullPolicy nullPolicy;

public CollectAggregation(NullPolicy nullPolicy) {
super(Kind.COLLECT);
this.nullPolicy = nullPolicy;
}

@Override
long createNativeInstance() {
return Aggregation.createCollectAgg(nullPolicy.includeNulls);
}

@Override
public int hashCode() {
return 31 * kind.hashCode() + nullPolicy.hashCode();
}

@Override
public boolean equals(Object other) {
if (this == other) {
return true;
} else if (other instanceof CollectAggregation) {
CollectAggregation o = (CollectAggregation) other;
return o.nullPolicy == this.nullPolicy;
}
return false;
}
}

protected final Kind kind;

protected Aggregation(Kind kind) {
Expand Down Expand Up @@ -351,16 +393,27 @@ public static Aggregation max() {
* Count number of valid, a.k.a. non-null, elements.
*/
public static Aggregation count() {
return count(false);
return count(NullPolicy.EXCLUDE);
}

/**
* Count number of elements.
* (This is deprecated, use {@link Aggregation#count(NullPolicy nullPolicy)} instead)
* @param includeNulls true if nulls should be counted. false if only non-null values should be
* counted.
*/
@Deprecated
public static Aggregation count(boolean includeNulls) {
return new CountLikeAggregation(Kind.COUNT, includeNulls);
return count(includeNulls ? NullPolicy.INCLUDE : NullPolicy.EXCLUDE);
}

/**
* Count number of elements.
* @param nullPolicy INCLUDE if nulls should be counted. EXCLUDE if only non-null values
* should be counted.
*/
public static Aggregation count(NullPolicy nullPolicy) {
return new CountLikeAggregation(Kind.COUNT, nullPolicy);
}

/**
Expand Down Expand Up @@ -473,17 +526,29 @@ public static Aggregation argMin() {
* Number of unique, non-null, elements.
*/
public static Aggregation nunique() {
return nunique(false);
return nunique(NullPolicy.EXCLUDE);
}

/**
* Number of unique elements.
* (This is deprecated, use {@link Aggregation#nunique(NullPolicy nullPolicy)} instead)
* @param includeNulls true if nulls should be counted else false. If nulls are counted they
* compare as equal so multiple null values in a range would all only
* increase the count by 1.
*/
@Deprecated
public static Aggregation nunique(boolean includeNulls) {
return new CountLikeAggregation(Kind.NUNIQUE, includeNulls);
return nunique(includeNulls ? NullPolicy.INCLUDE : NullPolicy.EXCLUDE);
}

/**
* Number of unique elements.
* @param nullPolicy INCLUDE if nulls should be counted else EXCLUDE. If nulls are counted they
* compare as equal so multiple null values in a range would all only
* increase the count by 1.
*/
public static Aggregation nunique(NullPolicy nullPolicy) {
return new CountLikeAggregation(Kind.NUNIQUE, nullPolicy);
}

/**
Expand All @@ -492,18 +557,31 @@ public static Aggregation nunique(boolean includeNulls) {
* value outside of the group range results in a null.
*/
public static Aggregation nth(int offset) {
return nth(offset, true);
return nth(offset, NullPolicy.INCLUDE);
}

/**
* Get the nth element in a group.
* (This is deprecated, use {@link Aggregation#nth(int offset, NullPolicy nullPolicy)} instead)
* @param offset the offset to look at. Negative numbers go from the end of the group. Any
* value outside of the group range results in a null.
* @param includeNulls true if nulls should be included in the aggregation or false if they
* should be skipped.
*/
@Deprecated
public static Aggregation nth(int offset, boolean includeNulls) {
return new NthAggregation(offset, includeNulls);
return nth(offset, includeNulls ? NullPolicy.INCLUDE : NullPolicy.EXCLUDE);
}

/**
* Get the nth element in a group.
* @param offset the offset to look at. Negative numbers go from the end of the group. Any
* value outside of the group range results in a null.
* @param nullPolicy INCLUDE if nulls should be included in the aggregation or EXCLUDE if they
* should be skipped.
*/
public static Aggregation nth(int offset, NullPolicy nullPolicy) {
return new NthAggregation(offset, nullPolicy);
}

/**
Expand All @@ -514,10 +592,19 @@ public static Aggregation rowNumber() {
}

/**
* Collect the values into a list.
* Collect the values into a list. nulls will be skipped.
*/
public static Aggregation collect() {
return new NoParamAggregation(Kind.COLLECT);
return collect(NullPolicy.EXCLUDE);
}

/**
* Collect the values into a list.
* @param nullPolicy INCLUDE if nulls should be included in the aggregation or EXCLUDE if they
* should be skipped.
*/
public static Aggregation collect(NullPolicy nullPolicy) {
return new CollectAggregation(nullPolicy);
}

/**
Expand Down Expand Up @@ -586,4 +673,9 @@ public static Aggregation lag(int offset, ColumnVector defaultOutput) {
* Create a lead or lag aggregation.
*/
private static native long createLeadLagAgg(int kind, int offset);

/**
* Create a collect aggregation including nulls or not.
*/
private static native long createCollectAgg(boolean includeNulls);
}
19 changes: 15 additions & 4 deletions java/src/main/native/src/AggregationJni.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -81,9 +81,7 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createNoParamAgg(JNIEnv
case 17: // ROW_NUMBER
ret = cudf::make_row_number_aggregation();
break;
case 18: // COLLECT
ret = cudf::make_collect_aggregation();
break;
// case 18: COLLECT
// case 19: LEAD
// case 20: LAG
// case 21: PTX
Expand Down Expand Up @@ -201,4 +199,17 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createLeadLagAgg(JNIEnv
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createCollectAgg(JNIEnv *env,
jclass class_object,
jboolean include_nulls) {
try {
cudf::jni::auto_set_device(env);
cudf::null_policy policy =
include_nulls ? cudf::null_policy::INCLUDE : cudf::null_policy::EXCLUDE;
std::unique_ptr<cudf::aggregation> ret = cudf::make_collect_aggregation(policy);
return reinterpret_cast<jlong>(ret.release());
}
CATCH_STD(env, 0);
}

} // extern "C"
19 changes: 16 additions & 3 deletions java/src/test/java/ai/rapids/cudf/TableTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -2503,6 +2503,7 @@ void testWindowingRowNumber() {

@Test
void testWindowingCollect() {
Aggregation aggCollectWithNulls = Aggregation.collect(Aggregation.NullPolicy.INCLUDE);
Aggregation aggCollect = Aggregation.collect();
WindowOptions winOpts = WindowOptions.builder()
.minPeriods(1)
Expand All @@ -2513,26 +2514,38 @@ void testWindowingCollect() {
.column( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key
.column( 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3) // GBY Key
.column( 1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8) // OBY Key
.column( 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6) // Agg Column of INT32
.column( 7, 5, 1, 9, 7, 9, 8, 2, null, 0, 6, null) // Agg Column of INT32
.column(nestedType, // Agg Column of Struct
new StructData(1, "s1"), new StructData(2, "s2"), new StructData(3, "s3"),
new StructData(4, "s4"), new StructData(11, "s11"), new StructData(22, "s22"),
new StructData(33, "s33"), new StructData(44, "s44"), new StructData(111, "s111"),
new StructData(222, "s222"), new StructData(333, "s333"), new StructData(444, "s444")
).build();
ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6)) {
ColumnVector expectSortedAggColumn = ColumnVector
.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, null, 0, 6, null)) {
try (Table sorted = raw.orderBy(Table.asc(0), Table.asc(1), Table.asc(2))) {
ColumnVector sortedAggColumn = sorted.getColumn(3);
assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn);

// Primitive type: INT32
// a) including nulls
try (Table windowAggResults = sorted.groupBy(0, 1)
.aggregateWindows(aggCollectWithNulls.onColumn(3).overWindow(winOpts));
ColumnVector expected = ColumnVector.fromLists(
new ListType(false, new BasicType(false, DType.INT32)),
Arrays.asList(7,5), Arrays.asList(7,5,1), Arrays.asList(5,1,9), Arrays.asList(1,9),
Arrays.asList(7,9), Arrays.asList(7,9,8), Arrays.asList(9,8,2), Arrays.asList(8,2),
Arrays.asList(null,0), Arrays.asList(null,0,6), Arrays.asList(0,6,null), Arrays.asList(6,null))) {
assertColumnsAreEqual(expected, windowAggResults.getColumn(0));
}
// b) excluding nulls
try (Table windowAggResults = sorted.groupBy(0, 1)
.aggregateWindows(aggCollect.onColumn(3).overWindow(winOpts));
ColumnVector expected = ColumnVector.fromLists(
new ListType(false, new BasicType(false, DType.INT32)),
Arrays.asList(7,5), Arrays.asList(7,5,1), Arrays.asList(5,1,9), Arrays.asList(1,9),
Arrays.asList(7,9), Arrays.asList(7,9,8), Arrays.asList(9,8,2), Arrays.asList(8,2),
Arrays.asList(8,0), Arrays.asList(8,0,6), Arrays.asList(0,6,6), Arrays.asList(6,6))) {
Arrays.asList(0), Arrays.asList(0,6), Arrays.asList(0,6), Arrays.asList(6))) {
assertColumnsAreEqual(expected, windowAggResults.getColumn(0));
}

Expand Down

0 comments on commit 94dd756

Please sign in to comment.