Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some APIs to help with out of core joins in Spark [skip ci] #8118

Merged
merged 3 commits into from
May 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -1303,6 +1303,18 @@ public final ColumnVector rollingWindow(Aggregation op, WindowOptions options) {
}
}

/**
 * Produce the running total (inclusive prefix sum) of the values in this column.
 * This behaves like a rolling window SUM whose frame is unbounded preceding with
 * nothing following.
 * Example: input values 1, 2, 3 give output values 1, 3, 6.
 * The implementation is deliberately minimal for now: only long columns that are not
 * nullable are accepted. It may be broadened later if the need arises.
 * @return a new ColumnVector holding the cumulative sums.
 * @throws CudfException if the column is not of long type or contains nulls.
 */
public final ColumnVector prefixSum() {
  // The native call hands back a handle to a freshly created column; wrap it so the
  // caller owns (and must close) the result.
  final long summedHandle = prefixSum(getNativeView());
  return new ColumnVector(summedHandle);
}

/////////////////////////////////////////////////////////////////////////////
// LOGICAL
/////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -2910,6 +2922,8 @@ private static native long rollingWindow(
long preceding_col,
long following_col);

// Native side of prefixSum(): takes the native column view handle and returns the native
// handle of the newly created result column, which the caller wraps in a ColumnVector.
private static native long prefixSum(long viewHandle) throws CudfException;

private static native long nansToNulls(long viewHandle) throws CudfException;

private static native long charLengths(long viewHandle) throws CudfException;
Expand Down
12 changes: 6 additions & 6 deletions java/src/main/java/ai/rapids/cudf/Table.java
Original file line number Diff line number Diff line change
Expand Up @@ -1261,7 +1261,7 @@ public Table repeat(int count) {
* @return the new Table.
* @throws CudfException on any error.
*/
public Table repeat(ColumnVector counts) {
public Table repeat(ColumnView counts) {
return repeat(counts, true);
}

Expand All @@ -1276,7 +1276,7 @@ public Table repeat(ColumnVector counts) {
* @return the new Table.
* @throws CudfException on any error.
*/
public Table repeat(ColumnVector counts, boolean checkCount) {
public Table repeat(ColumnView counts, boolean checkCount) {
return new Table(repeatColumnCount(this.nativeHandle, counts.getNativeView(), checkCount));
}

Expand Down Expand Up @@ -1719,7 +1719,7 @@ private int[] copyAndValidate(int[] indices) {
* @return table containing copy of all elements of this table passing
* the filter defined by the boolean mask
*/
public Table filter(ColumnVector mask) {
public Table filter(ColumnView mask) {
assert mask.getType().equals(DType.BOOL8) : "Mask column must be of type BOOL8";
assert getRowCount() == 0 || getRowCount() == mask.getRowCount() : "Mask column has incorrect size";
return new Table(filter(nativeHandle, mask.getNativeView()));
Expand Down Expand Up @@ -1955,7 +1955,7 @@ public ColumnVector rowBitCount() {
* @param gatherMap the map of indexes. Must be non-nullable and integral type.
* @return the resulting Table.
*/
public Table gather(ColumnVector gatherMap) {
public Table gather(ColumnView gatherMap) {
return gather(gatherMap, true);
}

Expand All @@ -1973,7 +1973,7 @@ public Table gather(ColumnVector gatherMap) {
* when setting this to false.
* @return the resulting Table.
*/
public Table gather(ColumnVector gatherMap, boolean checkBounds) {
public Table gather(ColumnView gatherMap, boolean checkBounds) {
return new Table(gather(nativeHandle, gatherMap.getNativeView(), checkBounds));
}

Expand Down Expand Up @@ -2191,7 +2191,7 @@ public ColumnVector[] convertToRows() {
* @param schema the types of each column.
* @return the parsed table.
*/
public static Table convertFromRows(ColumnVector vec, DType ... schema) {
public static Table convertFromRows(ColumnView vec, DType ... schema) {
// TODO at some point we need a schema that support nesting so we can support nested types
// TODO we will need scale at some point very soon too
int[] types = new int[schema.length];
Expand Down
1 change: 1 addition & 0 deletions java/src/main/native/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ set(SOURCE_FILES
"src/RmmJni.cpp"
"src/ScalarJni.cpp"
"src/TableJni.cpp"
"src/prefix_sum.cu"
"src/map_lookup.cu")
add_library(cudfjni SHARED ${SOURCE_FILES})

Expand Down
16 changes: 16 additions & 0 deletions java/src/main/native/src/ColumnViewJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
#include <map_lookup.hpp>
#include "cudf/types.hpp"

#include "prefix_sum.hpp"
#include "cudf_jni_apis.hpp"
#include "dtype_utils.hpp"
#include "jni.h"
Expand Down Expand Up @@ -1755,6 +1756,20 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_makeStructView(JNIEnv *en
CATCH_STD(env, 0);
}

// JNI binding for ColumnView.prefixSum(long viewHandle).
// Interprets `handle` as a pointer to a cudf::column_view, runs cudf::jni::prefix_sum on it
// and returns the resulting column to Java as a raw handle.
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_prefixSum(JNIEnv *env, jobject j_object,
                                                                 jlong handle) {
  JNI_NULL_CHECK(env, handle, "native view handle is null", 0)

  try {
    cudf::jni::auto_set_device(env);
    auto const *input = reinterpret_cast<cudf::column_view const *>(handle);
    // release() gives up C++ ownership: from here on the Java ColumnVector wrapping this
    // handle is responsible for the column's lifetime.
    return reinterpret_cast<jlong>(cudf::jni::prefix_sum(*input).release());
  }
  CATCH_STD(env, 0)
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_nansToNulls(JNIEnv *env, jobject j_object,
jlong handle) {

Expand All @@ -1779,6 +1794,7 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_nansToNulls(JNIEnv *env,
CATCH_STD(env, 0)
}


JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_isFloat(JNIEnv *env, jobject j_object,
jlong handle) {

Expand Down
48 changes: 48 additions & 0 deletions java/src/main/native/src/prefix_sum.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <thrust/scan.h>

#include <cudf/column/column.hpp>
#include <cudf/column/column_factories.hpp>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>
#include <rmm/exec_policy.hpp>


namespace cudf {
namespace jni {

/**
 * @brief Compute the inclusive prefix sum (running total) of a column of longs.
 *
 * For input values 1, 2, 3 the output is 1, 3, 6. The output has the same type and
 * number of rows as the input.
 *
 * @param value_column Column to scan. Must be of type INT64 and must not contain nulls;
 *                     either violation fails a CUDF_EXPECTS check.
 * @param stream       CUDA stream used for the output allocation and for the scan.
 * @param mr           Device memory resource used to allocate the returned column.
 * @return A newly allocated INT64 column holding the cumulative sums.
 */
std::unique_ptr<column> prefix_sum(column_view const &value_column,
                                   rmm::cuda_stream_view stream,
                                   rmm::mr::device_memory_resource *mr) {
  // Defensive checks.
  CUDF_EXPECTS(value_column.type().id() == type_id::INT64, "Only longs are supported.");
  CUDF_EXPECTS(!value_column.has_nulls(), "NULLS are not supported");

  // The input was just verified to be null-free, so every output row is valid. Leaving the
  // mask unallocated expresses "no nulls" without allocating and filling a validity buffer
  // that would never carry any information.
  auto result = make_numeric_column(value_column.type(), value_column.size(),
                                    mask_state::UNALLOCATED, stream, mr);

  // Inclusive scan: output[i] = input[0] + ... + input[i].
  thrust::inclusive_scan(rmm::exec_policy(stream),
                         value_column.begin<int64_t>(),
                         value_column.end<int64_t>(),
                         result->mutable_view().begin<int64_t>());

  return result;
}
} // namespace jni
} // namespace cudf
36 changes: 36 additions & 0 deletions java/src/main/native/src/prefix_sum.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cudf/column/column.hpp>
#include <rmm/cuda_stream_view.hpp>

namespace cudf {

namespace jni {

/**
 * @brief compute the prefix sum of a column of longs
 *
 * The result is an inclusive running total: input values 1, 2, 3 produce 1, 3, 6.
 * The implementation rejects (via CUDF_EXPECTS) any input that is not of type INT64
 * or that contains nulls.
 *
 * @param value_column the column of INT64 values to sum; must not contain nulls
 * @param stream CUDA stream on which the allocation and the scan are issued
 * @param mr device memory resource used to allocate the returned column
 * @return a new column with the same type and size as `value_column` holding the
 *         cumulative sums
 */
std::unique_ptr<column>
prefix_sum(column_view const &value_column,
rmm::cuda_stream_view stream = rmm::cuda_stream_default,
rmm::mr::device_memory_resource *mr = rmm::mr::get_current_device_resource());

} // namespace jni

} // namespace cudf
28 changes: 28 additions & 0 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -1994,6 +1994,34 @@ void testStringConcatSeparators() {
}
}

@Test
void testPrefixSum() {
  // Each expected entry is the sum of all input entries up to and including that row.
  try (ColumnVector input = ColumnVector.fromLongs(1, 2, 3, 5, 8, 10);
       ColumnVector runningTotal = input.prefixSum();
       ColumnVector expectedTotals = ColumnVector.fromLongs(1, 3, 6, 11, 19, 29)) {
    assertColumnsAreEqual(expectedTotals, runningTotal);
  }
}

@Test
void testPrefixSumErrors() {
  // A long column that contains a null must be rejected.
  try (ColumnVector longsWithNull = ColumnVector.fromBoxedLongs(1L, 2L, 3L, 5L, 8L, null)) {
    // Closing the result keeps the test leak-free should the call unexpectedly succeed.
    assertThrows(CudfException.class, () -> longsWithNull.prefixSum().close());
  }

  // A column of ints (not longs) must be rejected as well.
  try (ColumnVector ints = ColumnVector.fromInts(1, 2, 3, 5, 8, 10)) {
    assertThrows(CudfException.class, () -> ints.prefixSum().close());
  }
}

@Test
void testWindowStatic() {
WindowOptions options = WindowOptions.builder().window(2, 1)
Expand Down