diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index 402c64dd83d..93786ce5ee2 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -1303,6 +1303,18 @@ public final ColumnVector rollingWindow(Aggregation op, WindowOptions options) { } } + /** + * Compute the cumulative sum/prefix sum of the values in this column. + * This is similar to a rolling window SUM with unbounded preceding and none following. + * Input values 1, 2, 3 produce output values 1, 3, 6. + * Only non-nullable long (INT64) columns are supported for now; this may be expanded if needed. + * @return a new ColumnVector holding the running totals. + * @throws CudfException if this column is not INT64 or contains nulls. + */ + public final ColumnVector prefixSum() { + return new ColumnVector(prefixSum(getNativeView())); + } + ///////////////////////////////////////////////////////////////////////////// // LOGICAL ///////////////////////////////////////////////////////////////////////////// @@ -2910,6 +2922,8 @@ private static native long rollingWindow( long preceding_col, long following_col); + private static native long prefixSum(long viewHandle) throws CudfException; + private static native long nansToNulls(long viewHandle) throws CudfException; private static native long charLengths(long viewHandle) throws CudfException; diff --git a/java/src/main/java/ai/rapids/cudf/Table.java b/java/src/main/java/ai/rapids/cudf/Table.java index ea04e615bb6..a4fe3acab08 100644 --- a/java/src/main/java/ai/rapids/cudf/Table.java +++ b/java/src/main/java/ai/rapids/cudf/Table.java @@ -1261,7 +1261,7 @@ public Table repeat(int count) { * @return the new Table. * @throws CudfException on any error. */ - public Table repeat(ColumnVector counts) { + public Table repeat(ColumnView counts) { return repeat(counts, true); } @@ -1276,7 +1276,7 @@ public Table repeat(ColumnVector counts) { * @return the new Table. * @throws CudfException on any error. 
*/ - public Table repeat(ColumnVector counts, boolean checkCount) { + public Table repeat(ColumnView counts, boolean checkCount) { return new Table(repeatColumnCount(this.nativeHandle, counts.getNativeView(), checkCount)); } @@ -1719,7 +1719,7 @@ private int[] copyAndValidate(int[] indices) { * @return table containing copy of all elements of this table passing * the filter defined by the boolean mask */ - public Table filter(ColumnVector mask) { + public Table filter(ColumnView mask) { assert mask.getType().equals(DType.BOOL8) : "Mask column must be of type BOOL8"; assert getRowCount() == 0 || getRowCount() == mask.getRowCount() : "Mask column has incorrect size"; return new Table(filter(nativeHandle, mask.getNativeView())); @@ -1955,7 +1955,7 @@ public ColumnVector rowBitCount() { * @param gatherMap the map of indexes. Must be non-nullable and integral type. * @return the resulting Table. */ - public Table gather(ColumnVector gatherMap) { + public Table gather(ColumnView gatherMap) { return gather(gatherMap, true); } @@ -1973,7 +1973,7 @@ public Table gather(ColumnVector gatherMap) { * when setting this to false. * @return the resulting Table. */ - public Table gather(ColumnVector gatherMap, boolean checkBounds) { + public Table gather(ColumnView gatherMap, boolean checkBounds) { return new Table(gather(nativeHandle, gatherMap.getNativeView(), checkBounds)); } @@ -2191,7 +2191,7 @@ public ColumnVector[] convertToRows() { * @param schema the types of each column. * @return the parsed table. */ - public static Table convertFromRows(ColumnVector vec, DType ... schema) { + public static Table convertFromRows(ColumnView vec, DType ... 
schema) { // TODO at some point we need a schema that support nesting so we can support nested types // TODO we will need scale at some point very soon too int[] types = new int[schema.length]; diff --git a/java/src/main/native/CMakeLists.txt b/java/src/main/native/CMakeLists.txt index 17776288b49..179a6936d8b 100755 --- a/java/src/main/native/CMakeLists.txt +++ b/java/src/main/native/CMakeLists.txt @@ -248,6 +248,7 @@ set(SOURCE_FILES "src/RmmJni.cpp" "src/ScalarJni.cpp" "src/TableJni.cpp" + "src/prefix_sum.cu" "src/map_lookup.cu") add_library(cudfjni SHARED ${SOURCE_FILES}) diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp index cec3a1a92a6..c9bafa5abee 100644 --- a/java/src/main/native/src/ColumnViewJni.cpp +++ b/java/src/main/native/src/ColumnViewJni.cpp @@ -64,6 +64,7 @@ #include #include "cudf/types.hpp" +#include "prefix_sum.hpp" #include "cudf_jni_apis.hpp" #include "dtype_utils.hpp" #include "jni.h" @@ -1755,6 +1756,20 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_makeStructView(JNIEnv *en CATCH_STD(env, 0); } +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_prefixSum(JNIEnv *env, jobject j_object, + jlong handle) { + // Ownership of the result column passes to the Java caller as a raw pointer value. + JNI_NULL_CHECK(env, handle, "native view handle is null", 0) + + try { + cudf::jni::auto_set_device(env); + cudf::column_view *view = reinterpret_cast<cudf::column_view *>(handle); + std::unique_ptr<cudf::column> result = cudf::jni::prefix_sum(*view); + return reinterpret_cast<jlong>(result.release()); + } + CATCH_STD(env, 0) +} + JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_nansToNulls(JNIEnv *env, jobject j_object, jlong handle) { @@ -1779,6 +1794,7 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_nansToNulls(JNIEnv *env, CATCH_STD(env, 0) } + JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_isFloat(JNIEnv *env, jobject j_object, jlong handle) { diff --git a/java/src/main/native/src/prefix_sum.cu b/java/src/main/native/src/prefix_sum.cu new file mode 100644 index 
00000000000..e3c53696185 --- /dev/null +++ b/java/src/main/native/src/prefix_sum.cu @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <thrust/scan.h> + +#include <rmm/cuda_stream_view.hpp> +#include <rmm/exec_policy.hpp> + +#include <cudf/column/column_factories.hpp> +#include <cudf/column/column_view.hpp> +#include <cudf/utilities/error.hpp> + + +namespace cudf { +namespace jni { +// Inclusive scan of a non-null INT64 column into a new column of the same type and size. +std::unique_ptr<column> prefix_sum(column_view const &value_column, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource *mr) { + // Defensive checks. + CUDF_EXPECTS(value_column.type().id() == type_id::INT64, "Only longs are supported."); + CUDF_EXPECTS(!value_column.has_nulls(), "NULLS are not supported"); + + auto result = make_numeric_column(value_column.type(), value_column.size(), + mask_state::ALL_VALID, stream, mr); + + thrust::inclusive_scan(rmm::exec_policy(stream), + value_column.begin<int64_t>(), + value_column.end<int64_t>(), + result->mutable_view().begin<int64_t>()); + + return result; +} +} // namespace jni +} // namespace cudf diff --git a/java/src/main/native/src/prefix_sum.hpp b/java/src/main/native/src/prefix_sum.hpp new file mode 100644 index 00000000000..8f39f9a8c69 --- /dev/null +++ b/java/src/main/native/src/prefix_sum.hpp @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <cudf/column/column.hpp> +#include <rmm/cuda_stream_view.hpp> + +namespace cudf { + +namespace jni { + +/** + * @brief Compute the inclusive prefix sum of a column of non-null INT64 (long) values. + */ +std::unique_ptr<column> +prefix_sum(column_view const &value_column, + rmm::cuda_stream_view stream = rmm::cuda_stream_default, + rmm::mr::device_memory_resource *mr = rmm::mr::get_current_device_resource()); + +} // namespace jni + +} // namespace cudf diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index 36123704ae6..59b2174c867 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -1994,6 +1994,34 @@ void testStringConcatSeparators() { } } + @Test + void testPrefixSum() { + try (ColumnVector v1 = ColumnVector.fromLongs(1, 2, 3, 5, 8, 10); + ColumnVector summed = v1.prefixSum(); + ColumnVector expected = ColumnVector.fromLongs(1, 3, 6, 11, 19, 29)) { + assertColumnsAreEqual(expected, summed); + } + } + + @Test + void testPrefixSumErrors() { + try (ColumnVector v1 = ColumnVector.fromBoxedLongs(1L, 2L, 3L, 5L, 8L, null)) { + assertThrows(CudfException.class, () -> { + try(ColumnVector ignored = v1.prefixSum()) { + // empty + } + }); + } + + try (ColumnVector v1 = ColumnVector.fromInts(1, 2, 3, 5, 8, 10)) { + assertThrows(CudfException.class, () -> { + try(ColumnVector ignored = v1.prefixSum()) { + // empty + } + }); + } + } + @Test void testWindowStatic() { WindowOptions options = WindowOptions.builder().window(2, 1)