Skip to content

Commit

Permalink
JNI support for scalar of list (#8077)
Browse files Browse the repository at this point in the history
This PR is to add the JNI support for scalar of list, along with building a ColumnVector from a list scalar.

Since the PR #7584 inroduced the list scalar in cpp.

Signed-off-by: Firestarman <[email protected]>

Authors:
  - Liangcai Li (https://github.com/firestarman)

Approvers:
  - Jason Lowe (https://github.com/jlowe)

URL: #8077
  • Loading branch information
firestarman authored Apr 29, 2021
1 parent 1757d10 commit 04d6e5a
Show file tree
Hide file tree
Showing 8 changed files with 283 additions and 12 deletions.
11 changes: 10 additions & 1 deletion java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public class ColumnView implements AutoCloseable, BinaryOperable {
* Constructs a Column View given a native view address
* @param address the view handle
*/
protected ColumnView(long address) {
ColumnView(long address) {
this.viewHandle = address;
this.type = DType.fromNative(ColumnView.getNativeTypeId(viewHandle), ColumnView.getNativeTypeScale(viewHandle));
this.rows = ColumnView.getNativeRowCount(viewHandle);
Expand Down Expand Up @@ -211,6 +211,15 @@ public void close() {
viewHandle = 0;
}

@Override
public String toString() {
return "ColumnView{" +
"rows=" + rows +
", type=" + type +
", nullCount=" + nullCount +
'}';
}

/**
* Used for string strip function.
* Indicates characters to be stripped from the beginning, end, or both of each string.
Expand Down
46 changes: 46 additions & 0 deletions java/src/main/java/ai/rapids/cudf/Scalar.java
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ public static Scalar fromNull(DType type) {
return new Scalar(type, makeDecimal32Scalar(0, type.getScale(), false));
case DECIMAL64:
return new Scalar(type, makeDecimal64Scalar(0L, type.getScale(), false));
case LIST:
return new Scalar(type, makeListScalar(0L, false));
default:
throw new IllegalArgumentException("Unexpected type: " + type);
}
Expand Down Expand Up @@ -333,6 +335,19 @@ public static Scalar fromString(String value) {
return new Scalar(DType.STRING, makeStringScalar(value.getBytes(StandardCharsets.UTF_8), true));
}

/**
* Creates a scalar of list from a ColumnView.
*
* All the rows in the ColumnView will be copied into the Scalar. So the ColumnView
* can be closed after this call completes.
*/
public static Scalar listFromColumnView(ColumnView list) {
if (list == null) {
return Scalar.fromNull(DType.LIST);
}
return new Scalar(DType.LIST, makeListScalar(list.getNativeView(), true));
}

private static native void closeScalar(long scalarHandle);
private static native boolean isScalarValid(long scalarHandle);
private static native byte getByte(long scalarHandle);
Expand All @@ -342,6 +357,7 @@ public static Scalar fromString(String value) {
private static native float getFloat(long scalarHandle);
private static native double getDouble(long scalarHandle);
private static native byte[] getUTF8(long scalarHandle);
private static native long getListAsColumnView(long scalarHandle);
private static native long makeBool8Scalar(boolean isValid, boolean value);
private static native long makeInt8Scalar(byte value, boolean isValid);
private static native long makeUint8Scalar(byte value, boolean isValid);
Expand All @@ -360,6 +376,7 @@ public static Scalar fromString(String value) {
private static native long makeTimestampTimeScalar(int dtypeNativeId, long value, boolean isValid);
private static native long makeDecimal32Scalar(int value, int scale, boolean isValid);
private static native long makeDecimal64Scalar(long value, int scale, boolean isValid);
private static native long makeListScalar(long viewHandle, boolean isValid);


Scalar(DType type, long scalarHandle) {
Expand Down Expand Up @@ -484,6 +501,19 @@ public byte[] getUTF8() {
return getUTF8(getScalarHandle());
}

/**
* Returns the scalar value as a ColumnView. Callers should close the returned ColumnView to
* avoid memory leak.
*
* The returned ColumnView is only valid as long as the Scalar remains valid. If the Scalar
* is closed before this ColumnView is closed, using this ColumnView will result in undefined
* behavior.
*/
public ColumnView getListAsColumnView() {
assert DType.LIST.equals(type) : "Cannot get list for the vector of type " + type;
return new ColumnView(getListAsColumnView(getScalarHandle()));
}

@Override
public ColumnVector binaryOp(BinaryOp op, BinaryOperable rhs, DType outType) {
if (rhs instanceof ColumnView) {
Expand Down Expand Up @@ -541,6 +571,11 @@ public boolean equals(Object o) {
return getLong() == other.getLong();
case STRING:
return Arrays.equals(getUTF8(), other.getUTF8());
case LIST:
try (ColumnView viewMe = getListAsColumnView();
ColumnView viewO = other.getListAsColumnView()) {
return viewMe.equals(viewO);
}
default:
throw new IllegalStateException("Unexpected type: " + type);
}
Expand Down Expand Up @@ -589,6 +624,11 @@ public int hashCode() {
case STRING:
valueHash = Arrays.hashCode(getUTF8());
break;
case LIST:
try (ColumnView v = getListAsColumnView()) {
valueHash = v.hashCode();
}
break;
default:
throw new IllegalStateException("Unknown scalar type: " + type);
}
Expand Down Expand Up @@ -651,6 +691,12 @@ public String toString() {
case DECIMAL64:
sb.append(getBigDecimal());
break;
case LIST:
try (ColumnView v = getListAsColumnView()) {
// It's not easy to pull out the elements so just a simple string of some metadata.
sb.append(v.toString());
}
break;
default:
throw new IllegalArgumentException("Unknown scalar type: " + type);
}
Expand Down
79 changes: 78 additions & 1 deletion java/src/main/native/src/ColumnVectorJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,55 @@
#include "cudf_jni_apis.hpp"
#include "dtype_utils.hpp"

namespace cudf {
namespace jni {
/**
* @brief Creates an empty column according to the type tree specified
* in @p view
*
* An empty column contains zero elements and no validity mask.
*
* Unlike the 'cudf::make_empty_column', it takes care of the nested type by
* iterating the children columns in the @p view
*
* @param[in] view The input column view
* @return An empty column for the input column view
*/
std::unique_ptr<cudf::column> make_empty_column(JNIEnv *env, cudf::column_view const& view) {
auto tid = view.type().id();
if (tid == cudf::type_id::LIST) {
// List
if (view.num_children() != 2) {
throw jni_exception("List type requires two children(offset, data).");
}
// Only needs the second child.
auto data_view = view.child(1);
// offsets: [0]
auto offsets_buffer = rmm::device_buffer(sizeof(cudf::size_type));
device_memset_async(env, offsets_buffer, 0);
auto offsets = std::make_unique<column>(cudf::data_type{cudf::type_id::INT32}, 1,
std::move(offsets_buffer),
rmm::device_buffer(), 0);
auto data_col = make_empty_column(env, data_view);
return cudf::make_lists_column(0, std::move(offsets), std::move(data_col),
0, rmm::device_buffer());
} else if (tid == cudf::type_id::STRUCT) {
// Struct
std::vector<std::unique_ptr<column>> children(view.num_children());
std::transform(view.child_begin(), view.child_end(), children.begin(),
[env](auto const& child_v) {
return make_empty_column(env, child_v);
});
return cudf::make_structs_column(0, std::move(children), 0, rmm::device_buffer());
} else {
// Non nested types
return cudf::make_empty_column(view.type());
}
}

} // namespace jni
} // namespace cudf

extern "C" {

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_sequence(JNIEnv *env, jclass,
Expand Down Expand Up @@ -168,6 +217,7 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_makeList(JNIEnv *env, j
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_fromScalar(JNIEnv *env, jclass,
jlong j_scalar,
jint row_count) {
using ScalarType = cudf::scalar_type_t<cudf::size_type>;
JNI_NULL_CHECK(env, j_scalar, "scalar is null", 0);
try {
cudf::jni::auto_set_device(env);
Expand All @@ -176,7 +226,34 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_fromScalar(JNIEnv *env,
cudf::mask_state mask_state =
scalar_val->is_valid() ? cudf::mask_state::UNALLOCATED : cudf::mask_state::ALL_NULL;
std::unique_ptr<cudf::column> col;
if (row_count == 0) {
if (dtype.id() == cudf::type_id::LIST) {
// Neither 'cudf::make_empty_column' nor 'cudf::make_column_from_scalar' supports
// LIST type for now (https://github.com/rapidsai/cudf/issues/8088), so the list
// precedes the others and takes care of the empty column itself.
auto s_list = reinterpret_cast<cudf::list_scalar const *>(scalar_val);
cudf::column_view s_val = s_list->view();

// Offsets: [0, list_size, list_size*2, ..., list_szie*row_count]
auto zero = cudf::make_numeric_scalar(cudf::data_type(cudf::type_id::INT32));
auto step = cudf::make_numeric_scalar(cudf::data_type(cudf::type_id::INT32));
zero->set_valid(true);
step->set_valid(true);
static_cast<ScalarType *>(zero.get())->set_value(0);
static_cast<ScalarType *>(step.get())->set_value(s_val.size());
std::unique_ptr<cudf::column> offsets = cudf::sequence(row_count + 1, *zero, *step);
// Data:
// Builds the data column by leveraging `cudf::concatenate` to repeat the 's_val'
// 'row_count' times, because 'cudf::make_column_from_scalar' does not support list
// type.
// (Assumes the `row_count` is not big, otherwise there would be a performance issue.)
// Checks the `row_count` because `cudf::concatenate` does not support no columns.
auto data_col = row_count > 0
? cudf::concatenate(std::vector<cudf::column_view>(row_count, s_val))
: cudf::jni::make_empty_column(env, s_val);
col = cudf::make_lists_column(row_count, std::move(offsets), std::move(data_col),
cudf::state_null_count(mask_state, row_count),
cudf::create_null_mask(row_count, mask_state));
} else if (row_count == 0) {
col = cudf::make_empty_column(dtype);
} else if (cudf::is_fixed_width(dtype)) {
col = cudf::make_fixed_width_column(dtype, row_count, mask_state);
Expand Down
9 changes: 8 additions & 1 deletion java/src/main/native/src/CudaJni.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2020, NVIDIA CORPORATION.
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -14,6 +14,7 @@
* limitations under the License.
*/

#include <rmm/device_buffer.hpp>
#include "jni_utils.hpp"

namespace {
Expand Down Expand Up @@ -47,6 +48,12 @@ void auto_set_device(JNIEnv *env) {
}
}

/** Fills all the bytes in the buffer 'buf' with 'value'. */
void device_memset_async(JNIEnv *env, rmm::device_buffer& buf, char value) {
cudaError_t cuda_status = cudaMemsetAsync((void *)buf.data(), value, buf.size());
jni_cuda_check(env, cuda_status);
}

} // namespace jni
} // namespace cudf

Expand Down
34 changes: 33 additions & 1 deletion java/src/main/native/src/ScalarJni.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2020, NVIDIA CORPORATION.
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -125,6 +125,19 @@ JNIEXPORT jbyteArray JNICALL Java_ai_rapids_cudf_Scalar_getUTF8(JNIEnv *env, jcl
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Scalar_getListAsColumnView(JNIEnv *env, jclass,
jlong scalar_handle) {
JNI_NULL_CHECK(env, scalar_handle, "scalar handle is null", 0);
try {
cudf::jni::auto_set_device(env);
auto s = reinterpret_cast<cudf::list_scalar *>(scalar_handle);
// Creates a column view in heap with the stack one, to let JVM take care of its
// life cycle.
return reinterpret_cast<jlong>(new cudf::column_view(s->view()));
}
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Scalar_makeBool8Scalar(JNIEnv *env, jclass,
jboolean value,
jboolean is_valid) {
Expand Down Expand Up @@ -445,4 +458,23 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Scalar_binaryOpSV(JNIEnv *env, jclas
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Scalar_makeListScalar(JNIEnv *env, jclass,
jlong view_handle,
jboolean is_valid) {
if (is_valid) {
JNI_NULL_CHECK(env, view_handle,
"list scalar is set to `valid` but column view is null", 0);
}
try {
cudf::jni::auto_set_device(env);
auto col_view = reinterpret_cast<cudf::column_view *>(view_handle);

cudf::scalar* s = is_valid ? new cudf::list_scalar(*col_view)
: new cudf::list_scalar();
s->set_valid(is_valid);
return reinterpret_cast<jlong>(s);
}
CATCH_STD(env, 0);
}

} // extern "C"
7 changes: 7 additions & 0 deletions java/src/main/native/src/cudf_jni_apis.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,5 +70,12 @@ void set_cudf_device(int device);
*/
void auto_set_device(JNIEnv *env);

/**
* Fills all the bytes in the buffer 'buf' with 'value'.
* The operation has not necessarily completed when this returns, but it could overlap with
* operations occurring on other streams.
*/
void device_memset_async(JNIEnv *env, rmm::device_buffer& buf, char value);

} // namespace jni
} // namespace cudf
Loading

0 comments on commit 04d6e5a

Please sign in to comment.