Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JNI: Refactor the code of making column from scalar [skip ci] #8310

Merged
merged 1 commit into from
May 24, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 4 additions & 39 deletions java/src/main/native/src/ColumnVectorJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,49 +220,14 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_makeList(JNIEnv *env, j
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_fromScalar(JNIEnv *env, jclass,
jlong j_scalar,
jint row_count) {
using ScalarType = cudf::scalar_type_t<cudf::size_type>;
JNI_NULL_CHECK(env, j_scalar, "scalar is null", 0);
try {
cudf::jni::auto_set_device(env);
auto scalar_val = reinterpret_cast<cudf::scalar const *>(j_scalar);
auto dtype = scalar_val->type();
cudf::mask_state mask_state =
scalar_val->is_valid() ? cudf::mask_state::UNALLOCATED : cudf::mask_state::ALL_NULL;
std::unique_ptr<cudf::column> col;
if (dtype.id() == cudf::type_id::LIST) {
// Neither 'cudf::make_empty_column' nor 'cudf::make_column_from_scalar' supports
// LIST type for now (https://github.com/rapidsai/cudf/issues/8088), so the list
// precedes the others and takes care of the empty column itself.
auto s_list = reinterpret_cast<cudf::list_scalar const *>(scalar_val);
cudf::column_view s_val = s_list->view();

// Offsets: [0, list_size, list_size*2, ..., list_szie*row_count]
auto zero = cudf::make_numeric_scalar(cudf::data_type(cudf::type_id::INT32));
auto step = cudf::make_numeric_scalar(cudf::data_type(cudf::type_id::INT32));
zero->set_valid(true);
step->set_valid(true);
static_cast<ScalarType *>(zero.get())->set_value(0);
static_cast<ScalarType *>(step.get())->set_value(s_val.size());
std::unique_ptr<cudf::column> offsets = cudf::sequence(row_count + 1, *zero, *step);
// Data:
// Builds the data column by leveraging `cudf::concatenate` to repeat the 's_val'
// 'row_count' times, because 'cudf::make_column_from_scalar' does not support list
// type.
// (Assumes the `row_count` is not big, otherwise there would be a performance issue.)
// Checks the `row_count` because `cudf::concatenate` does not support no rows.
auto data_col = row_count > 0
? cudf::concatenate(std::vector<cudf::column_view>(row_count, s_val))
: cudf::empty_like(s_val);
jlowe marked this conversation as resolved.
Show resolved Hide resolved
col = cudf::make_lists_column(row_count, std::move(offsets), std::move(data_col),
cudf::state_null_count(mask_state, row_count),
cudf::create_null_mask(row_count, mask_state));
} else if (row_count == 0) {
col = cudf::make_empty_column(dtype);
} else if (cudf::is_fixed_width(dtype)) {
col = cudf::make_fixed_width_column(dtype, row_count, mask_state);
auto mut_view = col->mutable_view();
cudf::fill_in_place(mut_view, 0, row_count, *scalar_val);
} else if (dtype.id() == cudf::type_id::STRING) {
if (scalar_val->type().id() == cudf::type_id::STRING) {
// Tests fail when using the cudf implementation, complaining no child for string column.
// So here take care of the String type itself.
// create a string column of all empty strings to fill (cheapest string column to create)
auto offsets = cudf::make_numeric_column(cudf::data_type{cudf::type_id::INT32}, row_count + 1,
cudf::mask_state::UNALLOCATED);
Expand All @@ -273,7 +238,7 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_fromScalar(JNIEnv *env,

col = cudf::fill(str_col->view(), 0, row_count, *scalar_val);
} else {
JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "Invalid data type", 0);
col = cudf::make_column_from_scalar(*scalar_val, row_count);
}
return reinterpret_cast<jlong>(col.release());
}
Expand Down