diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index cd826707de2..3fe244c0112 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -231,6 +231,14 @@ public final ColumnView getChildColumnView(int childIndex) { return new ColumnView(childColumnView); } + /** + * Get a ColumnView that is the offsets for this list. + */ + ColumnView getListOffsetsView() { + assert(getType().equals(DType.LIST)); + return new ColumnView(getListOffsetCvPointer(viewHandle)); + } + /** * Gets the data buffer for the current column view (viewHandle). * If the type is LIST, STRUCT it returns null. @@ -1496,6 +1504,91 @@ public Scalar reduce(ReductionAggregation aggregation, DType outType) { } } + /** + * Do a segmented reduce where the offsets column indicates which groups in this to combine. The + * output type is the same as the input type. + * @param offsets an INT32 column with no nulls. + * @param aggregation the aggregation to do + * @return the result. + */ + public ColumnVector segmentedReduce(ColumnView offsets, SegmentedReductionAggregation aggregation) { + return segmentedReduce(offsets, aggregation, NullPolicy.EXCLUDE, type); + } + + /** + * Do a segmented reduce where the offsets column indicates which groups in this to combine. + * @param offsets an INT32 column with no nulls. + * @param aggregation the aggregation to do + * @param outType the output data type. + * @return the result. + */ + public ColumnVector segmentedReduce(ColumnView offsets, SegmentedReductionAggregation aggregation, + DType outType) { + return segmentedReduce(offsets, aggregation, NullPolicy.EXCLUDE, outType); + } + + /** + * Do a segmented reduce where the offsets column indicates which groups in this to combine. + * @param offsets an INT32 column with no nulls. + * @param aggregation the aggregation to do + * @param nullPolicy the null policy. + * @param outType the output data type. + * @return the result. + */ + public ColumnVector segmentedReduce(ColumnView offsets, SegmentedReductionAggregation aggregation, + NullPolicy nullPolicy, DType outType) { + long nativeId = aggregation.createNativeInstance(); + try { + return new ColumnVector(segmentedReduce(getNativeView(), offsets.getNativeView(), nativeId, + nullPolicy.includeNulls, outType.typeId.getNativeId(), outType.getScale())); + } finally { + Aggregation.close(nativeId); + } + } + + /** + * Do a reduction on the values in a list. The output type will be the type of the data column + * of this list. + * @param aggregation the aggregation to perform + */ + public ColumnVector listReduce(SegmentedReductionAggregation aggregation) { + if (!getType().equals(DType.LIST)) { + throw new IllegalArgumentException("listReduce only works on list types"); + } + try (ColumnView offsets = getListOffsetsView(); + ColumnView data = getChildColumnView(0)) { + return data.segmentedReduce(offsets, aggregation); + } + } + + /** + * Do a reduction on the values in a list. + * @param aggregation the aggregation to perform + * @param outType the type of the output. Typically, this should match with the child type + * of the list. + */ + public ColumnVector listReduce(SegmentedReductionAggregation aggregation, DType outType) { + return listReduce(aggregation, NullPolicy.EXCLUDE, outType); + } + + /** + * Do a reduction on the values in a list. + * @param aggregation the aggregation to perform + * @param nullPolicy should nulls be included or excluded from the aggregation. + * @param outType the type of the output. Typically, this should match with the child type + * of the list. + */ + public ColumnVector listReduce(SegmentedReductionAggregation aggregation, NullPolicy nullPolicy, + DType outType) { + if (!getType().equals(DType.LIST)) { + throw new IllegalArgumentException("listReduce only works on list types"); + } + try (ColumnView offsets = getListOffsetsView(); + ColumnView data = getChildColumnView(0)) { + return data.segmentedReduce(offsets, aggregation, nullPolicy, outType); + } + } + /** * Calculate various percentiles of this ColumnVector, which must contain centroids produced by * a t-digest aggregation. @@ -3897,6 +3990,9 @@ private static native long scan(long viewHandle, long aggregation, private static native long reduce(long viewHandle, long aggregation, int dtype, int scale) throws CudfException; + private static native long segmentedReduce(long dataViewHandle, long offsetsViewHandle, + long aggregation, boolean includeNulls, int dtype, int scale) throws CudfException; + private static native long isNullNative(long viewHandle); private static native long isNanNative(long viewHandle); @@ -4024,6 +4120,8 @@ static native long makeCudfColumnView(int type, int scale, long data, long dataS static native long getChildCvPointer(long viewHandle, int childIndex) throws CudfException; + private static native long getListOffsetCvPointer(long viewHandle) throws CudfException; + static native int getNativeNumChildren(long viewHandle) throws CudfException; // calculate the amount of device memory used by this column including any child columns diff --git a/java/src/main/java/ai/rapids/cudf/SegmentedReductionAggregation.java b/java/src/main/java/ai/rapids/cudf/SegmentedReductionAggregation.java new file mode 100644 index 00000000000..7ed150a2fec --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/SegmentedReductionAggregation.java @@ -0,0 +1,104 @@ +/* + * + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package ai.rapids.cudf; + +/** + * An aggregation that can be used for a reduce. + */ +public final class SegmentedReductionAggregation { + private final Aggregation wrapped; + + private SegmentedReductionAggregation(Aggregation wrapped) { + this.wrapped = wrapped; + } + + long createNativeInstance() { + return wrapped.createNativeInstance(); + } + + long getDefaultOutput() { + return wrapped.getDefaultOutput(); + } + + Aggregation getWrapped() { + return wrapped; + } + + @Override + public int hashCode() { + return wrapped.hashCode(); + } + + @Override + public boolean equals(Object other) { + if (other == this) { + return true; + } else if (other instanceof SegmentedReductionAggregation) { + SegmentedReductionAggregation o = (SegmentedReductionAggregation) other; + return wrapped.equals(o.wrapped); + } + return false; + } + + /** + * Sum Aggregation + */ + public static SegmentedReductionAggregation sum() { + return new SegmentedReductionAggregation(Aggregation.sum()); + } + + /** + * Product Aggregation. + */ + public static SegmentedReductionAggregation product() { + return new SegmentedReductionAggregation(Aggregation.product()); + } + + /** + * Min Aggregation + */ + public static SegmentedReductionAggregation min() { + return new SegmentedReductionAggregation(Aggregation.min()); + } + + /** + * Max Aggregation + */ + public static SegmentedReductionAggregation max() { + return new SegmentedReductionAggregation(Aggregation.max()); + } + + /** + * Any reduction. Produces a true or 1, depending on the output type, + * if any of the elements in the range are true or non-zero, otherwise produces a false or 0. + * Null values are skipped. + */ + public static SegmentedReductionAggregation any() { + return new SegmentedReductionAggregation(Aggregation.any()); + } + + /** + * All reduction. Produces true or 1, depending on the output type, if all of the elements in + * the range are true or non-zero, otherwise produces a false or 0. + * Null values are skipped. + */ + public static SegmentedReductionAggregation all() { + return new SegmentedReductionAggregation(Aggregation.all()); + } +} diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp index a69c7c29900..d417c81758e 100644 --- a/java/src/main/native/src/ColumnViewJni.cpp +++ b/java/src/main/native/src/ColumnViewJni.cpp @@ -262,6 +262,27 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_reduce(JNIEnv *env, jclas CATCH_STD(env, 0); } +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_segmentedReduce( + JNIEnv *env, jclass, jlong j_data_view, jlong j_offsets_view, jlong j_agg, + jboolean include_nulls, jint j_dtype, jint scale) { + JNI_NULL_CHECK(env, j_data_view, "data column view is null", 0); + JNI_NULL_CHECK(env, j_offsets_view, "offsets column view is null", 0); + JNI_NULL_CHECK(env, j_agg, "aggregation is null", 0); + try { + cudf::jni::auto_set_device(env); + auto data = reinterpret_cast(j_data_view); + auto offsets = reinterpret_cast(j_offsets_view); + auto agg = reinterpret_cast(j_agg); + auto s_agg = dynamic_cast(agg); + JNI_ARG_CHECK(env, s_agg != nullptr, "agg is not a cudf::segmented_reduce_aggregation", 0) + auto null_policy = include_nulls ? cudf::null_policy::INCLUDE : cudf::null_policy::EXCLUDE; + cudf::data_type out_dtype = cudf::jni::make_data_type(j_dtype, scale); + return release_as_jlong( + cudf::segmented_reduce(*data, *offsets, *s_agg, out_dtype, null_policy)); + } + CATCH_STD(env, 0); +} + JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_scan(JNIEnv *env, jclass, jlong j_col_view, jlong j_agg, jboolean is_inclusive, jboolean include_nulls) { @@ -1775,6 +1796,20 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_getChildCvPointer(JNIEnv CATCH_STD(env, 0); } +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_getListOffsetCvPointer(JNIEnv *env, + jobject j_object, + jlong handle) { + JNI_NULL_CHECK(env, handle, "native handle is null", 0); + try { + cudf::jni::auto_set_device(env); + cudf::column_view *column = reinterpret_cast(handle); + cudf::lists_column_view view = cudf::lists_column_view(*column); + cudf::column_view offsets_view = view.offsets(); + return ptr_as_jlong(new cudf::column_view(offsets_view)); + } + CATCH_STD(env, 0); +} + JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_getNativeOffsetsAddress(JNIEnv *env, jclass, jlong handle) { JNI_NULL_CHECK(env, handle, "native handle is null", 0); diff --git a/java/src/test/java/ai/rapids/cudf/SegmentedReductionTest.java b/java/src/test/java/ai/rapids/cudf/SegmentedReductionTest.java new file mode 100644 index 00000000000..c902ab97c52 --- /dev/null +++ b/java/src/test/java/ai/rapids/cudf/SegmentedReductionTest.java @@ -0,0 +1,115 @@ +/* + * + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ +package ai.rapids.cudf; + +import org.junit.jupiter.api.Test; + +import java.util.Arrays; + +class SegmentedReductionTest extends CudfTestBase { + + @Test + public void testListSum() { + HostColumnVector.DataType dt = new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.INT32)); + try (ColumnVector listCv = ColumnVector.fromLists(dt, + Arrays.asList(1, 2, 3), + Arrays.asList(2, 3, 4), + null, + Arrays.asList(null, 1, 2)); + ColumnVector excludeExpected = ColumnVector.fromBoxedInts(6, 9, null, 3); + ColumnVector nullExcluded = listCv.listReduce(SegmentedReductionAggregation.sum(), NullPolicy.EXCLUDE, DType.INT32); + ColumnVector includeExpected = ColumnVector.fromBoxedInts(6, 9, null, null); + ColumnVector nullIncluded = listCv.listReduce(SegmentedReductionAggregation.sum(), NullPolicy.INCLUDE, DType.INT32)) { + AssertUtils.assertColumnsAreEqual(excludeExpected, nullExcluded); + AssertUtils.assertColumnsAreEqual(includeExpected, nullIncluded); + } + } + + @Test + public void testListMin() { + HostColumnVector.DataType dt = new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.INT32)); + try (ColumnVector listCv = ColumnVector.fromLists(dt, + Arrays.asList(1, 2, 3), + Arrays.asList(2, 3, 4), + null, + Arrays.asList(null, 1, 2)); + ColumnVector excludeExpected = ColumnVector.fromBoxedInts(1, 2, null, 1); + ColumnVector nullExcluded = listCv.listReduce(SegmentedReductionAggregation.min(), NullPolicy.EXCLUDE, DType.INT32); + ColumnVector includeExpected = ColumnVector.fromBoxedInts(1, 2, null, null); + ColumnVector nullIncluded = listCv.listReduce(SegmentedReductionAggregation.min(), NullPolicy.INCLUDE, DType.INT32)) { + AssertUtils.assertColumnsAreEqual(excludeExpected, nullExcluded); + AssertUtils.assertColumnsAreEqual(includeExpected, nullIncluded); + } + } + + @Test + public void testListMax() { + HostColumnVector.DataType dt = new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.INT32)); + try (ColumnVector listCv = ColumnVector.fromLists(dt, + Arrays.asList(1, 2, 3), + Arrays.asList(2, 3, 4), + null, + Arrays.asList(null, 1, 2)); + ColumnVector excludeExpected = ColumnVector.fromBoxedInts(3, 4, null, 2); + ColumnVector nullExcluded = listCv.listReduce(SegmentedReductionAggregation.max(), NullPolicy.EXCLUDE, DType.INT32); + ColumnVector includeExpected = ColumnVector.fromBoxedInts(3, 4, null, null); + ColumnVector nullIncluded = listCv.listReduce(SegmentedReductionAggregation.max(), NullPolicy.INCLUDE, DType.INT32)) { + AssertUtils.assertColumnsAreEqual(excludeExpected, nullExcluded); + AssertUtils.assertColumnsAreEqual(includeExpected, nullIncluded); + } + } + + @Test + public void testListAny() { + HostColumnVector.DataType dt = new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.BOOL8)); + try (ColumnVector listCv = ColumnVector.fromLists(dt, + Arrays.asList(true, false, false), + Arrays.asList(false, false, false), + null, + Arrays.asList(null, true, false)); + ColumnVector excludeExpected = ColumnVector.fromBoxedBooleans(true, false, null, true); + ColumnVector nullExcluded = listCv.listReduce(SegmentedReductionAggregation.any(), NullPolicy.EXCLUDE, DType.BOOL8); + ColumnVector includeExpected = ColumnVector.fromBoxedBooleans(true, false, null, null); + ColumnVector nullIncluded = listCv.listReduce(SegmentedReductionAggregation.any(), NullPolicy.INCLUDE, DType.BOOL8)) { + AssertUtils.assertColumnsAreEqual(excludeExpected, nullExcluded); + AssertUtils.assertColumnsAreEqual(includeExpected, nullIncluded); + } + } + + @Test + public void testListAll() { + HostColumnVector.DataType dt = new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.BOOL8)); + try (ColumnVector listCv = ColumnVector.fromLists(dt, + Arrays.asList(true, true, true), + Arrays.asList(false, true, false), + null, + Arrays.asList(null, true, true)); + ColumnVector excludeExpected = ColumnVector.fromBoxedBooleans(true, false, null, true); + ColumnVector nullExcluded = listCv.listReduce(SegmentedReductionAggregation.all(), NullPolicy.EXCLUDE, DType.BOOL8); + ColumnVector includeExpected = ColumnVector.fromBoxedBooleans(true, false, null, null); + ColumnVector nullIncluded = listCv.listReduce(SegmentedReductionAggregation.all(), NullPolicy.INCLUDE, DType.BOOL8)) { + AssertUtils.assertColumnsAreEqual(excludeExpected, nullExcluded); + AssertUtils.assertColumnsAreEqual(includeExpected, nullIncluded); + } + } +}