JNI support for segmented reduce (#10413)
This adds JNI support for #9621. It also adds a helper API that makes it easy to run the reduction over the elements of each row of a LIST column.

Authors:
  - Robert (Bobby) Evans (https://github.com/revans2)

Approvers:
  - Liangcai Li (https://github.com/firestarman)
  - Jason Lowe (https://github.com/jlowe)

URL: #10413
revans2 authored Mar 14, 2022
1 parent 0be0b00 commit a066e7f
Showing 4 changed files with 352 additions and 0 deletions.
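For orientation, here is a minimal usage sketch of the new list helper added by this commit. It is not taken from the commit itself; the class name ListReduceSketch and the data values are illustrative, and the column-building calls mirror the tests added at the bottom of this commit.

import java.util.Arrays;

import ai.rapids.cudf.ColumnVector;
import ai.rapids.cudf.DType;
import ai.rapids.cudf.HostColumnVector;
import ai.rapids.cudf.SegmentedReductionAggregation;

public class ListReduceSketch {
  public static void main(String[] args) {
    // A LIST<INT32> column with one null row and one null element.
    HostColumnVector.DataType dt = new HostColumnVector.ListType(true,
        new HostColumnVector.BasicType(true, DType.INT32));
    try (ColumnVector listCv = ColumnVector.fromLists(dt,
             Arrays.asList(1, 2, 3),
             Arrays.asList(2, 3, 4),
             null,
             Arrays.asList(null, 1, 2));
         // Nulls are excluded by default and the output type matches the child type (INT32).
         ColumnVector sums = listCv.listReduce(SegmentedReductionAggregation.sum())) {
      // sums holds [6, 9, null, 3]: the null row yields null, the null element is skipped.
    }
  }
}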
98 changes: 98 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
@@ -231,6 +231,14 @@ public final ColumnView getChildColumnView(int childIndex) {
return new ColumnView(childColumnView);
}

/**
* Get a ColumnView that is the offsets for this list.
*/
ColumnView getListOffsetsView() {
assert(getType().equals(DType.LIST));
return new ColumnView(getListOffsetCvPointer(viewHandle));
}

/**
* Gets the data buffer for the current column view (viewHandle).
* If the type is LIST or STRUCT, it returns null.
@@ -1496,6 +1504,91 @@ public Scalar reduce(ReductionAggregation aggregation, DType outType) {
}
}

/**
 * Do a segmented reduce where the offsets column marks the start and end of each segment
 * (group of rows) in this column to combine. The output type is the same as the input type
 * and nulls are excluded.
 * @param offsets an INT32 column with no nulls.
 * @param aggregation the aggregation to do
 * @return the result.
 */
public ColumnVector segmentedReduce(ColumnView offsets, SegmentedReductionAggregation aggregation) {
return segmentedReduce(offsets, aggregation, NullPolicy.EXCLUDE, type);
}

/**
 * Do a segmented reduce where the offsets column marks the start and end of each segment
 * (group of rows) in this column to combine. Nulls are excluded.
 * @param offsets an INT32 column with no nulls.
 * @param aggregation the aggregation to do
 * @param outType the output data type.
 * @return the result.
 */
public ColumnVector segmentedReduce(ColumnView offsets, SegmentedReductionAggregation aggregation,
DType outType) {
return segmentedReduce(offsets, aggregation, NullPolicy.EXCLUDE, outType);
}

/**
 * Do a segmented reduce where the offsets column marks the start and end of each segment
 * (group of rows) in this column to combine.
 * @param offsets an INT32 column with no nulls.
 * @param aggregation the aggregation to do
 * @param nullPolicy the null policy.
 * @param outType the output data type.
 * @return the result.
 */
public ColumnVector segmentedReduce(ColumnView offsets, SegmentedReductionAggregation aggregation,
NullPolicy nullPolicy, DType outType) {
long nativeId = aggregation.createNativeInstance();
try {
return new ColumnVector(segmentedReduce(getNativeView(), offsets.getNativeView(), nativeId,
nullPolicy.includeNulls, outType.typeId.getNativeId(), outType.getScale()));
} finally {
Aggregation.close(nativeId);
}
}

/**
 * Do a reduction on the values in each row of this LIST column. The output type will be the
 * type of the list's data (child) column, and nulls are excluded.
 * @param aggregation the aggregation to perform
 * @return the result.
 */
public ColumnVector listReduce(SegmentedReductionAggregation aggregation) {
if (!getType().equals(DType.LIST)) {
throw new IllegalArgumentException("listReduce only works on list types");
}
try (ColumnView offsets = getListOffsetsView();
ColumnView data = getChildColumnView(0)) {
return data.segmentedReduce(offsets, aggregation);
}
}

/**
 * Do a reduction on the values in each row of this LIST column. Nulls are excluded.
 * @param aggregation the aggregation to perform
 * @param outType the type of the output. Typically this should match the child type
 *                of the list.
 * @return the result.
 */
public ColumnVector listReduce(SegmentedReductionAggregation aggregation, DType outType) {
return listReduce(aggregation, NullPolicy.EXCLUDE, outType);
}

/**
 * Do a reduction on the values in each row of this LIST column.
 * @param aggregation the aggregation to perform
 * @param nullPolicy whether nulls should be included in or excluded from the aggregation.
 * @param outType the type of the output. Typically this should match the child type
 *                of the list.
 * @return the result.
 */
public ColumnVector listReduce(SegmentedReductionAggregation aggregation, NullPolicy nullPolicy,
DType outType) {
if (!getType().equals(DType.LIST)) {
throw new IllegalArgumentException("listReduce only works on list types");
}
try (ColumnView offsets = getListOffsetsView();
ColumnView data = getChildColumnView(0)) {
return data.segmentedReduce(offsets, aggregation, nullPolicy, outType);
}
}

/**
* Calculate various percentiles of this ColumnVector, which must contain centroids produced by
* a t-digest aggregation.
@@ -3897,6 +3990,9 @@ private static native long scan(long viewHandle, long aggregation,

private static native long reduce(long viewHandle, long aggregation, int dtype, int scale) throws CudfException;

private static native long segmentedReduce(long dataViewHandle, long offsetsViewHandle,
long aggregation, boolean includeNulls, int dtype, int scale) throws CudfException;

private static native long isNullNative(long viewHandle);

private static native long isNanNative(long viewHandle);
@@ -4024,6 +4120,8 @@ static native long makeCudfColumnView(int type, int scale, long data, long dataS

static native long getChildCvPointer(long viewHandle, int childIndex) throws CudfException;

private static native long getListOffsetCvPointer(long viewHandle) throws CudfException;

static native int getNativeNumChildren(long viewHandle) throws CudfException;

// calculate the amount of device memory used by this column including any child columns
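For reference, a minimal sketch of the lower-level segmentedReduce overload added above, driven by an explicit offsets column rather than a LIST column. It assumes the usual cudf offsets convention, where consecutive offsets such as [0, 3, 6] describe the segments [0, 3) and [3, 6); the class name SegmentedReduceSketch and the values are illustrative.

import ai.rapids.cudf.ColumnVector;
import ai.rapids.cudf.DType;
import ai.rapids.cudf.NullPolicy;
import ai.rapids.cudf.SegmentedReductionAggregation;

public class SegmentedReduceSketch {
  public static void main(String[] args) {
    try (ColumnVector data = ColumnVector.fromBoxedInts(1, 2, 3, 10, 20, 30);
         // An INT32 offsets column with no nulls: two segments, rows [0, 3) and [3, 6).
         ColumnVector offsets = ColumnVector.fromInts(0, 3, 6);
         ColumnVector maxes = data.segmentedReduce(offsets,
             SegmentedReductionAggregation.max(), NullPolicy.EXCLUDE, DType.INT32)) {
      // maxes holds [3, 30], one result per segment.
    }
  }
}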
104 changes: 104 additions & 0 deletions java/src/main/java/ai/rapids/cudf/SegmentedReductionAggregation.java
@@ -0,0 +1,104 @@
/*
*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package ai.rapids.cudf;

/**
 * An aggregation that can be used for a segmented reduce.
 */
public final class SegmentedReductionAggregation {
private final Aggregation wrapped;

private SegmentedReductionAggregation(Aggregation wrapped) {
this.wrapped = wrapped;
}

long createNativeInstance() {
return wrapped.createNativeInstance();
}

long getDefaultOutput() {
return wrapped.getDefaultOutput();
}

Aggregation getWrapped() {
return wrapped;
}

@Override
public int hashCode() {
return wrapped.hashCode();
}

@Override
public boolean equals(Object other) {
if (other == this) {
return true;
} else if (other instanceof SegmentedReductionAggregation) {
SegmentedReductionAggregation o = (SegmentedReductionAggregation) other;
return wrapped.equals(o.wrapped);
}
return false;
}

/**
 * Sum Aggregation.
*/
public static SegmentedReductionAggregation sum() {
return new SegmentedReductionAggregation(Aggregation.sum());
}

/**
* Product Aggregation.
*/
public static SegmentedReductionAggregation product() {
return new SegmentedReductionAggregation(Aggregation.product());
}

/**
 * Min Aggregation.
*/
public static SegmentedReductionAggregation min() {
return new SegmentedReductionAggregation(Aggregation.min());
}

/**
 * Max Aggregation.
*/
public static SegmentedReductionAggregation max() {
return new SegmentedReductionAggregation(Aggregation.max());
}

/**
 * Any reduction. Produces true or 1, depending on the output type, if any of the elements
 * in the range are true or non-zero; otherwise produces false or 0.
* Null values are skipped.
*/
public static SegmentedReductionAggregation any() {
return new SegmentedReductionAggregation(Aggregation.any());
}

/**
 * All reduction. Produces true or 1, depending on the output type, if all of the elements
 * in the range are true or non-zero; otherwise produces false or 0.
* Null values are skipped.
*/
public static SegmentedReductionAggregation all() {
return new SegmentedReductionAggregation(Aggregation.all());
}
}
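A small sketch contrasting the two null policies when one of these aggregations is applied through listReduce; the expected results match the tests added below, and the class name NullPolicySketch and the data values are illustrative.

import java.util.Arrays;

import ai.rapids.cudf.ColumnVector;
import ai.rapids.cudf.DType;
import ai.rapids.cudf.HostColumnVector;
import ai.rapids.cudf.NullPolicy;
import ai.rapids.cudf.SegmentedReductionAggregation;

public class NullPolicySketch {
  public static void main(String[] args) {
    HostColumnVector.DataType dt = new HostColumnVector.ListType(true,
        new HostColumnVector.BasicType(true, DType.INT32));
    try (ColumnVector listCv = ColumnVector.fromLists(dt, Arrays.asList(null, 1, 2));
         // EXCLUDE skips the null element, so the row reduces to 1.
         ColumnVector minExclude = listCv.listReduce(SegmentedReductionAggregation.min(),
             NullPolicy.EXCLUDE, DType.INT32);
         // INCLUDE propagates the null element, so the row's result is null.
         ColumnVector minInclude = listCv.listReduce(SegmentedReductionAggregation.min(),
             NullPolicy.INCLUDE, DType.INT32)) {
      // minExclude holds [1]; minInclude holds [null].
    }
  }
}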
35 changes: 35 additions & 0 deletions java/src/main/native/src/ColumnViewJni.cpp
@@ -266,6 +266,27 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_reduce(JNIEnv *env, jclas
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_segmentedReduce(
JNIEnv *env, jclass, jlong j_data_view, jlong j_offsets_view, jlong j_agg,
jboolean include_nulls, jint j_dtype, jint scale) {
JNI_NULL_CHECK(env, j_data_view, "data column view is null", 0);
JNI_NULL_CHECK(env, j_offsets_view, "offsets column view is null", 0);
JNI_NULL_CHECK(env, j_agg, "aggregation is null", 0);
try {
cudf::jni::auto_set_device(env);
auto data = reinterpret_cast<cudf::column_view *>(j_data_view);
auto offsets = reinterpret_cast<cudf::column_view *>(j_offsets_view);
auto agg = reinterpret_cast<cudf::aggregation *>(j_agg);
auto s_agg = dynamic_cast<cudf::segmented_reduce_aggregation *>(agg);
JNI_ARG_CHECK(env, s_agg != nullptr, "agg is not a cudf::segmented_reduce_aggregation", 0);
auto null_policy = include_nulls ? cudf::null_policy::INCLUDE : cudf::null_policy::EXCLUDE;
cudf::data_type out_dtype = cudf::jni::make_data_type(j_dtype, scale);
return release_as_jlong(
cudf::segmented_reduce(*data, *offsets, *s_agg, out_dtype, null_policy));
}
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_scan(JNIEnv *env, jclass, jlong j_col_view,
jlong j_agg, jboolean is_inclusive,
jboolean include_nulls) {
@@ -1783,6 +1804,20 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_getChildCvPointer(JNIEnv
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_getListOffsetCvPointer(JNIEnv *env,
jobject j_object,
jlong handle) {
JNI_NULL_CHECK(env, handle, "native handle is null", 0);
try {
cudf::jni::auto_set_device(env);
cudf::column_view *column = reinterpret_cast<cudf::column_view *>(handle);
cudf::lists_column_view view = cudf::lists_column_view(*column);
cudf::column_view offsets_view = view.offsets();
return ptr_as_jlong(new cudf::column_view(offsets_view));
}
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_getNativeOffsetsAddress(JNIEnv *env, jclass,
jlong handle) {
JNI_NULL_CHECK(env, handle, "native handle is null", 0);
115 changes: 115 additions & 0 deletions java/src/test/java/ai/rapids/cudf/SegmentedReductionTest.java
@@ -0,0 +1,115 @@
/*
*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package ai.rapids.cudf;

import org.junit.jupiter.api.Test;

import java.util.Arrays;

class SegmentedReductionTest extends CudfTestBase {

@Test
public void testListSum() {
HostColumnVector.DataType dt = new HostColumnVector.ListType(true,
new HostColumnVector.BasicType(true, DType.INT32));
try (ColumnVector listCv = ColumnVector.fromLists(dt,
Arrays.asList(1, 2, 3),
Arrays.asList(2, 3, 4),
null,
Arrays.asList(null, 1, 2));
ColumnVector excludeExpected = ColumnVector.fromBoxedInts(6, 9, null, 3);
ColumnVector nullExcluded = listCv.listReduce(SegmentedReductionAggregation.sum(), NullPolicy.EXCLUDE, DType.INT32);
ColumnVector includeExpected = ColumnVector.fromBoxedInts(6, 9, null, null);
ColumnVector nullIncluded = listCv.listReduce(SegmentedReductionAggregation.sum(), NullPolicy.INCLUDE, DType.INT32)) {
AssertUtils.assertColumnsAreEqual(excludeExpected, nullExcluded);
AssertUtils.assertColumnsAreEqual(includeExpected, nullIncluded);
}
}

@Test
public void testListMin() {
HostColumnVector.DataType dt = new HostColumnVector.ListType(true,
new HostColumnVector.BasicType(true, DType.INT32));
try (ColumnVector listCv = ColumnVector.fromLists(dt,
Arrays.asList(1, 2, 3),
Arrays.asList(2, 3, 4),
null,
Arrays.asList(null, 1, 2));
ColumnVector excludeExpected = ColumnVector.fromBoxedInts(1, 2, null, 1);
ColumnVector nullExcluded = listCv.listReduce(SegmentedReductionAggregation.min(), NullPolicy.EXCLUDE, DType.INT32);
ColumnVector includeExpected = ColumnVector.fromBoxedInts(1, 2, null, null);
ColumnVector nullIncluded = listCv.listReduce(SegmentedReductionAggregation.min(), NullPolicy.INCLUDE, DType.INT32)) {
AssertUtils.assertColumnsAreEqual(excludeExpected, nullExcluded);
AssertUtils.assertColumnsAreEqual(includeExpected, nullIncluded);
}
}

@Test
public void testListMax() {
HostColumnVector.DataType dt = new HostColumnVector.ListType(true,
new HostColumnVector.BasicType(true, DType.INT32));
try (ColumnVector listCv = ColumnVector.fromLists(dt,
Arrays.asList(1, 2, 3),
Arrays.asList(2, 3, 4),
null,
Arrays.asList(null, 1, 2));
ColumnVector excludeExpected = ColumnVector.fromBoxedInts(3, 4, null, 2);
ColumnVector nullExcluded = listCv.listReduce(SegmentedReductionAggregation.max(), NullPolicy.EXCLUDE, DType.INT32);
ColumnVector includeExpected = ColumnVector.fromBoxedInts(3, 4, null, null);
ColumnVector nullIncluded = listCv.listReduce(SegmentedReductionAggregation.max(), NullPolicy.INCLUDE, DType.INT32)) {
AssertUtils.assertColumnsAreEqual(excludeExpected, nullExcluded);
AssertUtils.assertColumnsAreEqual(includeExpected, nullIncluded);
}
}

@Test
public void testListAny() {
HostColumnVector.DataType dt = new HostColumnVector.ListType(true,
new HostColumnVector.BasicType(true, DType.BOOL8));
try (ColumnVector listCv = ColumnVector.fromLists(dt,
Arrays.asList(true, false, false),
Arrays.asList(false, false, false),
null,
Arrays.asList(null, true, false));
ColumnVector excludeExpected = ColumnVector.fromBoxedBooleans(true, false, null, true);
ColumnVector nullExcluded = listCv.listReduce(SegmentedReductionAggregation.any(), NullPolicy.EXCLUDE, DType.BOOL8);
ColumnVector includeExpected = ColumnVector.fromBoxedBooleans(true, false, null, null);
ColumnVector nullIncluded = listCv.listReduce(SegmentedReductionAggregation.any(), NullPolicy.INCLUDE, DType.BOOL8)) {
AssertUtils.assertColumnsAreEqual(excludeExpected, nullExcluded);
AssertUtils.assertColumnsAreEqual(includeExpected, nullIncluded);
}
}

@Test
public void testListAll() {
HostColumnVector.DataType dt = new HostColumnVector.ListType(true,
new HostColumnVector.BasicType(true, DType.BOOL8));
try (ColumnVector listCv = ColumnVector.fromLists(dt,
Arrays.asList(true, true, true),
Arrays.asList(false, true, false),
null,
Arrays.asList(null, true, true));
ColumnVector excludeExpected = ColumnVector.fromBoxedBooleans(true, false, null, true);
ColumnVector nullExcluded = listCv.listReduce(SegmentedReductionAggregation.all(), NullPolicy.EXCLUDE, DType.BOOL8);
ColumnVector includeExpected = ColumnVector.fromBoxedBooleans(true, false, null, null);
ColumnVector nullIncluded = listCv.listReduce(SegmentedReductionAggregation.all(), NullPolicy.INCLUDE, DType.BOOL8)) {
AssertUtils.assertColumnsAreEqual(excludeExpected, nullExcluded);
AssertUtils.assertColumnsAreEqual(includeExpected, nullIncluded);
}
}
}
