Skip to content

Commit

Permalink
Some APIs to help with out of core joins in Spark (#8118)
Browse files Browse the repository at this point in the history
It adds a simple prefix sum for doing size calculations. It also changes a few APIs so that they take a ColumnView instead of a ColumnVector.

Authors:
  - Robert (Bobby) Evans (https://github.com/revans2)

Approvers:
  - Jason Lowe (https://github.com/jlowe)

URL: #8118
  • Loading branch information
revans2 authored May 1, 2021
1 parent 4869c23 commit cf8c73a
Show file tree
Hide file tree
Showing 7 changed files with 149 additions and 6 deletions.
14 changes: 14 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -1312,6 +1312,18 @@ public final ColumnVector rollingWindow(Aggregation op, WindowOptions options) {
}
}

/**
 * Produce the running total (inclusive prefix sum) of this column: each output row holds
 * the sum of the input rows up to and including that position. It is comparable to a rolling
 * window SUM whose frame is unbounded preceding and zero following.
 * <pre>
 *   input:  1, 2, 3
 *   output: 1, 3, 6
 * </pre>
 * The implementation is intentionally minimal for now: only long columns that contain no
 * nulls are accepted. Support for more types may be added later if it is needed.
 *
 * @return a new ColumnVector holding the cumulative sums.
 * @throws CudfException if the native call rejects the column.
 */
public final ColumnVector prefixSum() {
  final long summedHandle = prefixSum(getNativeView());
  return new ColumnVector(summedHandle);
}

/////////////////////////////////////////////////////////////////////////////
// LOGICAL
/////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -2919,6 +2931,8 @@ private static native long rollingWindow(
long preceding_col,
long following_col);

// Native prefix sum over the column view behind viewHandle. The returned long is wrapped in a
// new ColumnVector by the public prefixSum() method.
private static native long prefixSum(long viewHandle) throws CudfException;

private static native long nansToNulls(long viewHandle) throws CudfException;

private static native long charLengths(long viewHandle) throws CudfException;
Expand Down
12 changes: 6 additions & 6 deletions java/src/main/java/ai/rapids/cudf/Table.java
Original file line number Diff line number Diff line change
Expand Up @@ -1261,7 +1261,7 @@ public Table repeat(int count) {
* @return the new Table.
* @throws CudfException on any error.
*/
public Table repeat(ColumnVector counts) {
public Table repeat(ColumnView counts) {
return repeat(counts, true);
}

Expand All @@ -1276,7 +1276,7 @@ public Table repeat(ColumnVector counts) {
* @return the new Table.
* @throws CudfException on any error.
*/
public Table repeat(ColumnVector counts, boolean checkCount) {
public Table repeat(ColumnView counts, boolean checkCount) {
return new Table(repeatColumnCount(this.nativeHandle, counts.getNativeView(), checkCount));
}

Expand Down Expand Up @@ -1719,7 +1719,7 @@ private int[] copyAndValidate(int[] indices) {
* @return table containing copy of all elements of this table passing
* the filter defined by the boolean mask
*/
public Table filter(ColumnVector mask) {
public Table filter(ColumnView mask) {
assert mask.getType().equals(DType.BOOL8) : "Mask column must be of type BOOL8";
assert getRowCount() == 0 || getRowCount() == mask.getRowCount() : "Mask column has incorrect size";
return new Table(filter(nativeHandle, mask.getNativeView()));
Expand Down Expand Up @@ -1955,7 +1955,7 @@ public ColumnVector rowBitCount() {
* @param gatherMap the map of indexes. Must be non-nullable and integral type.
* @return the resulting Table.
*/
public Table gather(ColumnVector gatherMap) {
public Table gather(ColumnView gatherMap) {
return gather(gatherMap, true);
}

Expand All @@ -1973,7 +1973,7 @@ public Table gather(ColumnVector gatherMap) {
* when setting this to false.
* @return the resulting Table.
*/
public Table gather(ColumnVector gatherMap, boolean checkBounds) {
public Table gather(ColumnView gatherMap, boolean checkBounds) {
return new Table(gather(nativeHandle, gatherMap.getNativeView(), checkBounds));
}

Expand Down Expand Up @@ -2191,7 +2191,7 @@ public ColumnVector[] convertToRows() {
* @param schema the types of each column.
* @return the parsed table.
*/
public static Table convertFromRows(ColumnVector vec, DType ... schema) {
public static Table convertFromRows(ColumnView vec, DType ... schema) {
// TODO at some point we need a schema that support nesting so we can support nested types
// TODO we will need scale at some point very soon too
int[] types = new int[schema.length];
Expand Down
1 change: 1 addition & 0 deletions java/src/main/native/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ set(SOURCE_FILES
"src/RmmJni.cpp"
"src/ScalarJni.cpp"
"src/TableJni.cpp"
"src/prefix_sum.cu"
"src/map_lookup.cu")
add_library(cudfjni SHARED ${SOURCE_FILES})

Expand Down
16 changes: 16 additions & 0 deletions java/src/main/native/src/ColumnViewJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
#include <map_lookup.hpp>
#include "cudf/types.hpp"

#include "prefix_sum.hpp"
#include "cudf_jni_apis.hpp"
#include "dtype_utils.hpp"
#include "jni.h"
Expand Down Expand Up @@ -1755,6 +1756,20 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_makeStructView(JNIEnv *en
CATCH_STD(env, 0);
}

// JNI binding for ai.rapids.cudf.ColumnView.prefixSum(long).
//
// `handle` is reinterpreted as a pointer to a cudf::column_view. The scan itself is delegated to
// cudf::jni::prefix_sum, and ownership of the resulting column is released to the caller as a
// jlong. The null check and CATCH_STD are both given 0 as their fallback value.
// NOTE(review): presumably both macros raise a Java exception before returning that 0 - confirm
// against their definitions.
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_prefixSum(JNIEnv *env, jobject j_object,
                                                                 jlong handle) {
  JNI_NULL_CHECK(env, handle, "native view handle is null", 0)

  try {
    cudf::jni::auto_set_device(env);
    auto const *input = reinterpret_cast<cudf::column_view const *>(handle);
    auto summed = cudf::jni::prefix_sum(*input);
    // Hand the raw column pointer to Java; the unique_ptr no longer owns it after release().
    return reinterpret_cast<jlong>(summed.release());
  }
  CATCH_STD(env, 0)
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_nansToNulls(JNIEnv *env, jobject j_object,
jlong handle) {

Expand All @@ -1779,6 +1794,7 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_nansToNulls(JNIEnv *env,
CATCH_STD(env, 0)
}


JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_isFloat(JNIEnv *env, jobject j_object,
jlong handle) {

Expand Down
48 changes: 48 additions & 0 deletions java/src/main/native/src/prefix_sum.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <thrust/scan.h>

#include <cudf/column/column.hpp>
#include <cudf/column/column_factories.hpp>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>
#include <rmm/exec_policy.hpp>


namespace cudf {
namespace jni {

/**
 * @brief Compute the inclusive prefix (cumulative) sum of a column of 64-bit integers.
 *
 * Example: input {1, 2, 3} produces {1, 3, 6}.
 *
 * @param value_column Input column; must be of type INT64 and must not contain nulls,
 *                     otherwise a CUDF_EXPECTS check fails.
 * @param stream CUDA stream on which the output allocation and the scan are issued.
 * @param mr Device memory resource used to allocate the returned column.
 * @return A new INT64 column with the same number of rows as `value_column` and no
 *         validity mask.
 */
std::unique_ptr<column> prefix_sum(column_view const &value_column,
                                   rmm::cuda_stream_view stream,
                                   rmm::mr::device_memory_resource *mr) {
  // Defensive checks.
  CUDF_EXPECTS(value_column.type().id() == type_id::INT64, "Only longs are supported.");
  CUDF_EXPECTS(!value_column.has_nulls(), "NULLS are not supported");

  // The input was just verified to be null-free, so every output row is valid. Leave the
  // validity mask unallocated instead of allocating and filling an all-valid bitmask that
  // carries no information.
  auto result = make_numeric_column(value_column.type(), value_column.size(),
                                    mask_state::UNALLOCATED, stream, mr);

  // Scan directly into the freshly allocated output buffer on the caller's stream.
  auto output = result->mutable_view();
  thrust::inclusive_scan(rmm::exec_policy(stream),
                         value_column.begin<int64_t>(),
                         value_column.end<int64_t>(),
                         output.begin<int64_t>());

  return result;
}
} // namespace jni
} // namespace cudf
36 changes: 36 additions & 0 deletions java/src/main/native/src/prefix_sum.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cudf/column/column.hpp>
#include <rmm/cuda_stream_view.hpp>

namespace cudf {

namespace jni {

/**
 * @brief Compute the inclusive prefix sum (running total) of a column of longs.
 *
 * For example, input values {1, 2, 3} yield {1, 3, 6}.
 *
 * @param value_column Column to scan. It must be of type INT64 and must not contain nulls.
 * @param stream CUDA stream used for the output allocation and the scan.
 * @param mr Device memory resource used to allocate the returned column.
 * @return A newly allocated column with one cumulative sum per input row.
 */
std::unique_ptr<column>
prefix_sum(column_view const &value_column,
rmm::cuda_stream_view stream = rmm::cuda_stream_default,
rmm::mr::device_memory_resource *mr = rmm::mr::get_current_device_resource());

} // namespace jni

} // namespace cudf
28 changes: 28 additions & 0 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -2058,6 +2058,34 @@ void testStringConcatSeparators() {
}
}

// Happy path: a non-null long column produces its running totals.
@Test
void testPrefixSum() {
  try (ColumnVector input = ColumnVector.fromLongs(1, 2, 3, 5, 8, 10);
       ColumnVector actual = input.prefixSum();
       ColumnVector wanted = ColumnVector.fromLongs(1, 3, 6, 11, 19, 29)) {
    assertColumnsAreEqual(wanted, actual);
  }
}

// Unsupported inputs must be rejected with a CudfException.
@Test
void testPrefixSumErrors() {
  // A long column that contains a null entry.
  try (ColumnVector longsWithNull = ColumnVector.fromBoxedLongs(1L, 2L, 3L, 5L, 8L, null)) {
    assertThrows(CudfException.class, () -> {
      // Close the result if one is unexpectedly produced so the test does not leak it.
      try (ColumnVector unexpected = longsWithNull.prefixSum()) {
        // empty
      }
    });
  }

  // A column whose element type is int rather than long.
  try (ColumnVector intColumn = ColumnVector.fromInts(1, 2, 3, 5, 8, 10)) {
    assertThrows(CudfException.class, () -> {
      // Close the result if one is unexpectedly produced so the test does not leak it.
      try (ColumnVector unexpected = intColumn.prefixSum()) {
        // empty
      }
    });
  }
}

@Test
void testWindowStatic() {
WindowOptions options = WindowOptions.builder().window(2, 1)
Expand Down

0 comments on commit cf8c73a

Please sign in to comment.