JNI support for segmented reduce #10413

Merged 2 commits on Mar 14, 2022
98 changes: 98 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
@@ -231,6 +231,14 @@ public final ColumnView getChildColumnView(int childIndex) {
return new ColumnView(childColumnView);
}

/**
* Get a ColumnView of the offsets column for this LIST column. The returned view must be
* closed by the caller.
*/
ColumnView getListOffsetsView() {
assert(getType().equals(DType.LIST));
return new ColumnView(getListOffsetCvPointer(viewHandle));
}

/**
* Gets the data buffer for the current column view (viewHandle).
* If the type is LIST or STRUCT, it returns null.
@@ -1496,6 +1504,91 @@ public Scalar reduce(ReductionAggregation aggregation, DType outType) {
}
}

/**
* Do a segmented reduce where the offsets column indicates which groups of rows in this column
* to combine. The output type is the same as the input type.
* @param offsets an INT32 column of segment offsets with no nulls.
* @param aggregation the aggregation to do
* @return the result.
*/
public ColumnVector segmentedReduce(ColumnView offsets, SegmentedReductionAggregation aggregation) {
return segmentedReduce(offsets, aggregation, NullPolicy.EXCLUDE, type);
}

/**
* Do a segmented reduce where the offsets column indicates which groups of rows in this column to combine.
* @param offsets an INT32 column of segment offsets with no nulls.
* @param aggregation the aggregation to do
* @param outType the output data type.
* @return the result.
*/
public ColumnVector segmentedReduce(ColumnView offsets, SegmentedReductionAggregation aggregation,
DType outType) {
return segmentedReduce(offsets, aggregation, NullPolicy.EXCLUDE, outType);
}

/**
* Do a segmented reduce where the offsets column indicates which groups of rows in this column to combine.
* @param offsets an INT32 column of segment offsets with no nulls.
* @param aggregation the aggregation to do
* @param nullPolicy whether nulls should be included in or excluded from each reduction.
* @param outType the output data type.
* @return the result.
*/
public ColumnVector segmentedReduce(ColumnView offsets, SegmentedReductionAggregation aggregation,
NullPolicy nullPolicy, DType outType) {
long nativeId = aggregation.createNativeInstance();
try {
return new ColumnVector(segmentedReduce(getNativeView(), offsets.getNativeView(), nativeId,
nullPolicy.includeNulls, outType.typeId.getNativeId(), outType.getScale()));
} finally {
Aggregation.close(nativeId);
}
}

/**
* Do a reduction on the values in each list. The output type will be the type of the data column
* of this list.
* @param aggregation the aggregation to perform
*/
public ColumnVector listReduce(SegmentedReductionAggregation aggregation) {
if (!getType().equals(DType.LIST)) {
throw new IllegalArgumentException("listReduce only works on list types");
}
try (ColumnView offsets = getListOffsetsView();
ColumnView data = getChildColumnView(0)) {
return data.segmentedReduce(offsets, aggregation);
}
}

/**
* Do a reduction on the values in each list.
* @param aggregation the aggregation to perform
* @param outType the type of the output. Typically, this should match the child type
* of the list.
*/
public ColumnVector listReduce(SegmentedReductionAggregation aggregation, DType outType) {
return listReduce(aggregation, NullPolicy.EXCLUDE, outType);
}

/**
* Do a reduction on the values in each list.
* @param aggregation the aggregation to perform
* @param nullPolicy should nulls be included or excluded from the aggregation.
* @param outType the type of the output. Typically, this should match the child type
* of the list.
*/
public ColumnVector listReduce(SegmentedReductionAggregation aggregation, NullPolicy nullPolicy,
DType outType) {
if (!getType().equals(DType.LIST)) {
throw new IllegalArgumentException("listReduce only works on list types");
}
try (ColumnView offsets = getListOffsetsView();
ColumnView data = getChildColumnView(0)) {
return data.segmentedReduce(offsets, aggregation, nullPolicy, outType);
}
}

/**
* Calculate various percentiles of this ColumnVector, which must contain centroids produced by
* a t-digest aggregation.
@@ -3897,6 +3990,9 @@ private static native long scan(long viewHandle, long aggregation,

private static native long reduce(long viewHandle, long aggregation, int dtype, int scale) throws CudfException;

private static native long segmentedReduce(long dataViewHandle, long offsetsViewHandle,
long aggregation, boolean includeNulls, int dtype, int scale) throws CudfException;

private static native long isNullNative(long viewHandle);

private static native long isNanNative(long viewHandle);
@@ -4024,6 +4120,8 @@ static native long makeCudfColumnView(int type, int scale, long data, long dataS

static native long getChildCvPointer(long viewHandle, int childIndex) throws CudfException;

private static native long getListOffsetCvPointer(long viewHandle) throws CudfException;

static native int getNativeNumChildren(long viewHandle) throws CudfException;

// calculate the amount of device memory used by this column including any child columns
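The new public surface in ColumnView.java boils down to two entry points: segmentedReduce, which takes an explicit INT32 offsets column, and listReduce, which pulls the offsets out of a LIST column and reduces its child. The following is a minimal usage sketch, not part of this PR; it assumes the cudf Java artifact is on the classpath and a GPU is available, and the class name SegmentedReduceExample is made up for illustration.

import java.util.Arrays;

import ai.rapids.cudf.*;

public class SegmentedReduceExample {
  public static void main(String[] args) {
    HostColumnVector.DataType listType = new HostColumnVector.ListType(true,
        new HostColumnVector.BasicType(true, DType.INT32));
    try (ColumnVector listCv = ColumnVector.fromLists(listType,
             Arrays.asList(1, 2, 3),
             Arrays.asList(4, null, 6));
         // High-level path: reduce each list row, letting the column supply its own offsets.
         ColumnVector listSums = listCv.listReduce(SegmentedReductionAggregation.sum(),
             NullPolicy.EXCLUDE, DType.INT32);
         // Low-level path: the caller provides the segment offsets directly.
         ColumnVector data = ColumnVector.fromBoxedInts(1, 2, 3, 4, null, 6);
         ColumnVector offsets = ColumnVector.fromBoxedInts(0, 3, 6);
         ColumnVector segmentSums = data.segmentedReduce(offsets,
             SegmentedReductionAggregation.sum(), NullPolicy.EXCLUDE, DType.INT32)) {
      // Both results hold [6, 10]: the null in the second segment is excluded.
      System.out.println(listSums.getRowCount() + " list sums, "
          + segmentSums.getRowCount() + " segment sums");
    }
  }
}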
104 changes: 104 additions & 0 deletions java/src/main/java/ai/rapids/cudf/SegmentedReductionAggregation.java
@@ -0,0 +1,104 @@
/*
*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package ai.rapids.cudf;

/**
* An aggregation that can be used for a segmented reduce.
*/
public final class SegmentedReductionAggregation {
private final Aggregation wrapped;

private SegmentedReductionAggregation(Aggregation wrapped) {
this.wrapped = wrapped;
}

long createNativeInstance() {
return wrapped.createNativeInstance();
}

long getDefaultOutput() {
return wrapped.getDefaultOutput();
}

Aggregation getWrapped() {
return wrapped;
}

@Override
public int hashCode() {
return wrapped.hashCode();
}

@Override
public boolean equals(Object other) {
if (other == this) {
return true;
} else if (other instanceof SegmentedReductionAggregation) {
SegmentedReductionAggregation o = (SegmentedReductionAggregation) other;
return wrapped.equals(o.wrapped);
}
return false;
}

/**
* Sum Aggregation
*/
public static SegmentedReductionAggregation sum() {
return new SegmentedReductionAggregation(Aggregation.sum());
}

/**
* Product Aggregation.
*/
public static SegmentedReductionAggregation product() {
return new SegmentedReductionAggregation(Aggregation.product());
}

/**
* Min Aggregation
*/
public static SegmentedReductionAggregation min() {
return new SegmentedReductionAggregation(Aggregation.min());
}

/**
* Max Aggregation
*/
public static SegmentedReductionAggregation max() {
return new SegmentedReductionAggregation(Aggregation.max());
}

/**
* Any reduction. Produces true or 1, depending on the output type,
* if any of the elements in the range are true or non-zero, otherwise produces false or 0.
* Null values are skipped.
*/
public static SegmentedReductionAggregation any() {
return new SegmentedReductionAggregation(Aggregation.any());
}

/**
* All reduction. Produces true or 1, depending on the output type, if all of the elements in
* the range are true or non-zero, otherwise produces a false or 0.
* Null values are skipped.
*/
public static SegmentedReductionAggregation all() {
return new SegmentedReductionAggregation(Aggregation.all());
}
}
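The factory methods above only pick the operation; null handling and the output type are chosen at the call site. One thing the tests in SegmentedReductionTest.java later in this diff do not exercise is widening the output type. Below is a short sketch, not from this PR and assuming the underlying cudf segmented sum accepts an INT64 output for INT32 input, of summing into INT64 so large per-list sums cannot overflow. It reuses the imports from the sketch above.

HostColumnVector.DataType intList = new HostColumnVector.ListType(true,
    new HostColumnVector.BasicType(true, DType.INT32));
try (ColumnVector bigValues = ColumnVector.fromLists(intList,
         Arrays.asList(2000000000, 2000000000),
         Arrays.asList(1, 2, 3));
     // Request an INT64 result so the per-list sums do not overflow INT32.
     ColumnVector sums = bigValues.listReduce(SegmentedReductionAggregation.sum(), DType.INT64)) {
  // sums holds [4000000000L, 6L].
}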
35 changes: 35 additions & 0 deletions java/src/main/native/src/ColumnViewJni.cpp
@@ -262,6 +262,27 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_reduce(JNIEnv *env, jclas
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_segmentedReduce(
JNIEnv *env, jclass, jlong j_data_view, jlong j_offsets_view, jlong j_agg,
jboolean include_nulls, jint j_dtype, jint scale) {
JNI_NULL_CHECK(env, j_data_view, "data column view is null", 0);
JNI_NULL_CHECK(env, j_offsets_view, "offsets column view is null", 0);
JNI_NULL_CHECK(env, j_agg, "aggregation is null", 0);
try {
cudf::jni::auto_set_device(env);
auto data = reinterpret_cast<cudf::column_view *>(j_data_view);
auto offsets = reinterpret_cast<cudf::column_view *>(j_offsets_view);
auto agg = reinterpret_cast<cudf::aggregation *>(j_agg);
auto s_agg = dynamic_cast<cudf::segmented_reduce_aggregation *>(agg);
JNI_ARG_CHECK(env, s_agg != nullptr, "agg is not a cudf::segmented_reduce_aggregation", 0)
auto null_policy = include_nulls ? cudf::null_policy::INCLUDE : cudf::null_policy::EXCLUDE;
cudf::data_type out_dtype = cudf::jni::make_data_type(j_dtype, scale);
return release_as_jlong(
cudf::segmented_reduce(*data, *offsets, *s_agg, out_dtype, null_policy));
}
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_scan(JNIEnv *env, jclass, jlong j_col_view,
jlong j_agg, jboolean is_inclusive,
jboolean include_nulls) {
@@ -1775,6 +1796,20 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_getChildCvPointer(JNIEnv
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_getListOffsetCvPointer(JNIEnv *env,
jobject j_object,
jlong handle) {
JNI_NULL_CHECK(env, handle, "native handle is null", 0);
try {
cudf::jni::auto_set_device(env);
cudf::column_view *column = reinterpret_cast<cudf::column_view *>(handle);
cudf::lists_column_view view = cudf::lists_column_view(*column);
cudf::column_view offsets_view = view.offsets();
return ptr_as_jlong(new cudf::column_view(offsets_view));
}
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_getNativeOffsetsAddress(JNIEnv *env, jclass,
jlong handle) {
JNI_NULL_CHECK(env, handle, "native handle is null", 0);
115 changes: 115 additions & 0 deletions java/src/test/java/ai/rapids/cudf/SegmentedReductionTest.java
@@ -0,0 +1,115 @@
/*
*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package ai.rapids.cudf;

import org.junit.jupiter.api.Test;

import java.util.Arrays;

class SegmentedReductionTest extends CudfTestBase {

@Test
public void testListSum() {
HostColumnVector.DataType dt = new HostColumnVector.ListType(true,
new HostColumnVector.BasicType(true, DType.INT32));
try (ColumnVector listCv = ColumnVector.fromLists(dt,
Arrays.asList(1, 2, 3),
Arrays.asList(2, 3, 4),
null,
Arrays.asList(null, 1, 2));
ColumnVector excludeExpected = ColumnVector.fromBoxedInts(6, 9, null, 3);
ColumnVector nullExcluded = listCv.listReduce(SegmentedReductionAggregation.sum(), NullPolicy.EXCLUDE, DType.INT32);
ColumnVector includeExpected = ColumnVector.fromBoxedInts(6, 9, null, null);
ColumnVector nullIncluded = listCv.listReduce(SegmentedReductionAggregation.sum(), NullPolicy.INCLUDE, DType.INT32)) {
AssertUtils.assertColumnsAreEqual(excludeExpected, nullExcluded);
AssertUtils.assertColumnsAreEqual(includeExpected, nullIncluded);
}
}

@Test
public void testListMin() {
HostColumnVector.DataType dt = new HostColumnVector.ListType(true,
new HostColumnVector.BasicType(true, DType.INT32));
try (ColumnVector listCv = ColumnVector.fromLists(dt,
Arrays.asList(1, 2, 3),
Arrays.asList(2, 3, 4),
null,
Arrays.asList(null, 1, 2));
ColumnVector excludeExpected = ColumnVector.fromBoxedInts(1, 2, null, 1);
ColumnVector nullExcluded = listCv.listReduce(SegmentedReductionAggregation.min(), NullPolicy.EXCLUDE, DType.INT32);
ColumnVector includeExpected = ColumnVector.fromBoxedInts(1, 2, null, null);
ColumnVector nullIncluded = listCv.listReduce(SegmentedReductionAggregation.min(), NullPolicy.INCLUDE, DType.INT32)) {
AssertUtils.assertColumnsAreEqual(excludeExpected, nullExcluded);
AssertUtils.assertColumnsAreEqual(includeExpected, nullIncluded);
}
}

@Test
public void testListMax() {
HostColumnVector.DataType dt = new HostColumnVector.ListType(true,
new HostColumnVector.BasicType(true, DType.INT32));
try (ColumnVector listCv = ColumnVector.fromLists(dt,
Arrays.asList(1, 2, 3),
Arrays.asList(2, 3, 4),
null,
Arrays.asList(null, 1, 2));
ColumnVector excludeExpected = ColumnVector.fromBoxedInts(3, 4, null, 2);
ColumnVector nullExcluded = listCv.listReduce(SegmentedReductionAggregation.max(), NullPolicy.EXCLUDE, DType.INT32);
ColumnVector includeExpected = ColumnVector.fromBoxedInts(3, 4, null, null);
ColumnVector nullIncluded = listCv.listReduce(SegmentedReductionAggregation.max(), NullPolicy.INCLUDE, DType.INT32)) {
AssertUtils.assertColumnsAreEqual(excludeExpected, nullExcluded);
AssertUtils.assertColumnsAreEqual(includeExpected, nullIncluded);
}
}

@Test
public void testListAny() {
HostColumnVector.DataType dt = new HostColumnVector.ListType(true,
new HostColumnVector.BasicType(true, DType.BOOL8));
try (ColumnVector listCv = ColumnVector.fromLists(dt,
Arrays.asList(true, false, false),
Arrays.asList(false, false, false),
null,
Arrays.asList(null, true, false));
ColumnVector excludeExpected = ColumnVector.fromBoxedBooleans(true, false, null, true);
ColumnVector nullExcluded = listCv.listReduce(SegmentedReductionAggregation.any(), NullPolicy.EXCLUDE, DType.BOOL8);
ColumnVector includeExpected = ColumnVector.fromBoxedBooleans(true, false, null, null);
ColumnVector nullIncluded = listCv.listReduce(SegmentedReductionAggregation.any(), NullPolicy.INCLUDE, DType.BOOL8)) {
AssertUtils.assertColumnsAreEqual(excludeExpected, nullExcluded);
AssertUtils.assertColumnsAreEqual(includeExpected, nullIncluded);
}
}

@Test
public void testListAll() {
HostColumnVector.DataType dt = new HostColumnVector.ListType(true,
new HostColumnVector.BasicType(true, DType.BOOL8));
try (ColumnVector listCv = ColumnVector.fromLists(dt,
Arrays.asList(true, true, true),
Arrays.asList(false, true, false),
null,
Arrays.asList(null, true, true));
ColumnVector excludeExpected = ColumnVector.fromBoxedBooleans(true, false, null, true);
ColumnVector nullExcluded = listCv.listReduce(SegmentedReductionAggregation.all(), NullPolicy.EXCLUDE, DType.BOOL8);
ColumnVector includeExpected = ColumnVector.fromBoxedBooleans(true, false, null, null);
ColumnVector nullIncluded = listCv.listReduce(SegmentedReductionAggregation.all(), NullPolicy.INCLUDE, DType.BOOL8)) {
AssertUtils.assertColumnsAreEqual(excludeExpected, nullExcluded);
AssertUtils.assertColumnsAreEqual(includeExpected, nullIncluded);
}
}
}