From 7c69daec637bbd3924f319e3f58823a6a2551859 Mon Sep 17 00:00:00 2001 From: MithunR Date: Thu, 27 Jan 2022 15:08:11 -0800 Subject: [PATCH] Accept r-value references in convert_table_for_return(): (#10131) `cudf::jni::convert_table_for_return()` is usually used on tables returned from a libcudf API call. It currently requires an l-value reference for its table argument. This necessitates parking the result of libcudf call in an avoidable temp variable. This commit adds the option to use an r-value reference. This allows table expressions to be used directly, reducing clutter. Note: 1. The previous signature is retained, because not all call sites can use the r-value interface cleanly. (E.g. when the libcudf call is complex.) 2. The third argument (vector>) has been converted from l-ref to r-ref, so that an empty default can be introduced. This commit also includes minor code cleanup in the periphery of calls to `convert_table_for_return()`. Authors: - MithunR (https://github.com/mythrocks) Approvers: - Jason Lowe (https://github.com/jlowe) URL: https://github.com/rapidsai/cudf/pull/10131 --- java/src/main/native/include/jni_utils.hpp | 3 + java/src/main/native/src/ColumnViewJni.cpp | 22 +- java/src/main/native/src/TableJni.cpp | 243 ++++++++------------- java/src/main/native/src/cudf_jni_apis.hpp | 23 +- 4 files changed, 131 insertions(+), 160 deletions(-) diff --git a/java/src/main/native/include/jni_utils.hpp b/java/src/main/native/include/jni_utils.hpp index 317ef152492..a45716a89b3 100644 --- a/java/src/main/native/include/jni_utils.hpp +++ b/java/src/main/native/include/jni_utils.hpp @@ -395,6 +395,9 @@ template class native_jpointerArray { T **data() { return reinterpret_cast(wrapped.data()); } + T *const *begin() const { return data(); } + T *const *end() const { return data() + size(); } + const jlongArray get_jArray() const { return wrapped.get_jArray(); } jlongArray get_jArray() { return wrapped.get_jArray(); } diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp index fe1173e417e..0fce27bc130 100644 --- a/java/src/main/native/src/ColumnViewJni.cpp +++ b/java/src/main/native/src/ColumnViewJni.cpp @@ -561,18 +561,17 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_listSortRows(JNIEnv *env, JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_ColumnView_stringSplit(JNIEnv *env, jclass, jlong column_view, - jlong delimiter, + jlong delimiter_ptr, jint max_split) { JNI_NULL_CHECK(env, column_view, "column is null", 0); - JNI_NULL_CHECK(env, delimiter, "string scalar delimiter is null", 0); + JNI_NULL_CHECK(env, delimiter_ptr, "string scalar delimiter is null", 0); try { cudf::jni::auto_set_device(env); - cudf::column_view *cv = reinterpret_cast(column_view); - cudf::strings_column_view scv(*cv); - cudf::string_scalar *ss_scalar = reinterpret_cast(delimiter); + cudf::strings_column_view const scv{*reinterpret_cast(column_view)}; + auto delimiter = reinterpret_cast(delimiter_ptr); - std::unique_ptr table_result = cudf::strings::split(scv, *ss_scalar, max_split); - return cudf::jni::convert_table_for_return(env, table_result); + return cudf::jni::convert_table_for_return(env, + cudf::strings::split(scv, *delimiter, max_split)); } CATCH_STD(env, 0); } @@ -1410,13 +1409,12 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_ColumnView_extractRe(JNIEnv *en try { cudf::jni::auto_set_device(env); - cudf::column_view *column_view = reinterpret_cast(j_view_handle); - cudf::strings_column_view strings_column(*column_view); + cudf::strings_column_view const strings_column{ + *reinterpret_cast(j_view_handle)}; cudf::jni::native_jstring pattern(env, patternObj); - std::unique_ptr table_result = - cudf::strings::extract(strings_column, pattern.get()); - return cudf::jni::convert_table_for_return(env, table_result); + return cudf::jni::convert_table_for_return( + env, cudf::strings::extract(strings_column, pattern.get())); } CATCH_STD(env, 0); } diff --git a/java/src/main/native/src/TableJni.cpp b/java/src/main/native/src/TableJni.cpp index 10f295e27bf..aeac1856db0 100644 --- a/java/src/main/native/src/TableJni.cpp +++ b/java/src/main/native/src/TableJni.cpp @@ -599,37 +599,27 @@ class native_arrow_ipc_reader_handle final { void close() { source->Close(); } }; -/** - * Take a table returned by some operation and turn it into an array of column* so we can track them - * ourselves in java instead of having their life tied to the table. - * @param table_result the table to convert for return - * @param extra_columns columns not in the table that will be added to the result at the end. - */ -static jlongArray -convert_table_for_return(JNIEnv *env, std::unique_ptr &table_result, - std::vector> &extra_columns) { +jlongArray convert_table_for_return(JNIEnv *env, std::unique_ptr &&table_result, + std::vector> &&extra_columns) { std::vector> ret = table_result->release(); int table_cols = ret.size(); int num_columns = table_cols + extra_columns.size(); cudf::jni::native_jlongArray outcol_handles(env, num_columns); - for (int i = 0; i < table_cols; i++) { - outcol_handles[i] = release_as_jlong(ret[i]); - } - for (size_t i = 0; i < extra_columns.size(); i++) { - outcol_handles[i + table_cols] = release_as_jlong(extra_columns[i]); - } + std::transform(ret.begin(), ret.end(), outcol_handles.begin(), + [](auto &col) { return release_as_jlong(col); }); + std::transform(extra_columns.begin(), extra_columns.end(), outcol_handles.begin() + table_cols, + [](auto &col) { return release_as_jlong(col); }); return outcol_handles.get_jArray(); } -jlongArray convert_table_for_return(JNIEnv *env, std::unique_ptr &table_result) { - std::vector> extra; - return convert_table_for_return(env, table_result, extra); +jlongArray convert_table_for_return(JNIEnv *env, std::unique_ptr &table_result, + std::vector> &&extra_columns) { + return convert_table_for_return(env, std::move(table_result), std::move(extra_columns)); } jlongArray convert_table_for_return(JNIEnv *env, std::unique_ptr &first_table, std::unique_ptr &second_table) { - std::vector> second_tmp = second_table->release(); - return convert_table_for_return(env, first_table, second_tmp); + return convert_table_for_return(env, first_table, second_table->release()); } // Convert the JNI boolean array of key column sort order to a vector of cudf::order @@ -1068,6 +1058,7 @@ cudf::table_view remove_validity_if_needed(cudf::table_view *input_table_view) { } // namespace jni } // namespace cudf +using cudf::jni::convert_table_for_return; using cudf::jni::ptr_as_jlong; using cudf::jni::release_as_jlong; @@ -1223,9 +1214,8 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_orderBy(JNIEnv *env, jcla std::vector sort_keys = n_sort_keys_columns.get_dereferenced(); auto sorted_col = cudf::sorted_order(cudf::table_view{sort_keys}, order, null_order); - cudf::table_view *input_table = reinterpret_cast(j_input_table); - std::unique_ptr result = cudf::gather(*input_table, sorted_col->view()); - return cudf::jni::convert_table_for_return(env, result); + auto const input_table = reinterpret_cast(j_input_table); + return convert_table_for_return(env, cudf::gather(*input_table, sorted_col->view())); } CATCH_STD(env, NULL); } @@ -1267,8 +1257,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_merge(JNIEnv *env, jclass n_are_nulls_smallest.transform_if_else(cudf::null_order::BEFORE, cudf::null_order::AFTER); std::vector tables = n_table_handles.get_dereferenced(); - std::unique_ptr result = cudf::merge(tables, indexes, order, null_order); - return cudf::jni::convert_table_for_return(env, result); + return convert_table_for_return(env, cudf::merge(tables, indexes, order, null_order)); } CATCH_STD(env, NULL); } @@ -1344,8 +1333,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_readCSV( .comment(comment) .build(); - cudf::io::table_with_metadata result = cudf::io::read_csv(opts); - return cudf::jni::convert_table_for_return(env, result.tbl); + return convert_table_for_return(env, cudf::io::read_csv(opts).tbl); } CATCH_STD(env, NULL); } @@ -1425,7 +1413,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_readJSON( // there is no need to re-order columns when inferring schema if (result.metadata.column_names.empty() || n_col_names.size() <= 0) { - return cudf::jni::convert_table_for_return(env, result.tbl); + return convert_table_for_return(env, result.tbl); } else { // json reader will not return the correct column order, // so we need to re-order the column of table according to table meta. @@ -1453,11 +1441,11 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_readJSON( if (!match) { // can't find some input column names in table meta, return what json reader reads. - return cudf::jni::convert_table_for_return(env, result.tbl); + return convert_table_for_return(env, result.tbl); } else { auto tbv = result.tbl->view().select(std::move(indices)); auto table = std::make_unique(tbv); - return cudf::jni::convert_table_for_return(env, table); + return convert_table_for_return(env, table); } } } @@ -1501,8 +1489,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_readParquet(JNIEnv *env, .convert_strings_to_categories(false) .timestamp_type(cudf::data_type(static_cast(unit))) .build(); - cudf::io::table_with_metadata result = cudf::io::read_parquet(opts); - return cudf::jni::convert_table_for_return(env, result.tbl); + return convert_table_for_return(env, cudf::io::read_parquet(opts).tbl); } CATCH_STD(env, NULL); } @@ -1672,8 +1659,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_readORC( .timestamp_type(cudf::data_type(static_cast(unit))) .decimal128_columns(n_dec128_col_names.as_cpp_vector()) .build(); - cudf::io::table_with_metadata result = cudf::io::read_orc(opts); - return cudf::jni::convert_table_for_return(env, result.tbl); + return convert_table_for_return(env, cudf::io::read_orc(opts).tbl); } CATCH_STD(env, NULL); } @@ -1956,8 +1942,7 @@ Java_ai_rapids_cudf_Table_convertArrowTableToCudf(JNIEnv *env, jclass, jlong arr try { cudf::jni::auto_set_device(env); - std::unique_ptr result = cudf::from_arrow(*(handle->get())); - return cudf::jni::convert_table_for_return(env, result); + return convert_table_for_return(env, cudf::from_arrow(*(handle->get()))); } CATCH_STD(env, 0) } @@ -2142,7 +2127,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_leftSemiJoin( static_cast(compare_nulls_equal) ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL); - return cudf::jni::convert_table_for_return(env, result); + return convert_table_for_return(env, result); } CATCH_STD(env, NULL); } @@ -2171,7 +2156,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_leftAntiJoin( static_cast(compare_nulls_equal) ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL); - return cudf::jni::convert_table_for_return(env, result); + return convert_table_for_return(env, result); } CATCH_STD(env, NULL); } @@ -2706,12 +2691,9 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_crossJoin(JNIEnv *env, jc try { cudf::jni::auto_set_device(env); - cudf::table_view *n_left_table = reinterpret_cast(left_table); - cudf::table_view *n_right_table = reinterpret_cast(right_table); - - std::unique_ptr result = cudf::cross_join(*n_left_table, *n_right_table); - - return cudf::jni::convert_table_for_return(env, result); + auto const left = reinterpret_cast(left_table); + auto const right = reinterpret_cast(right_table); + return convert_table_for_return(env, cudf::cross_join(*left, *right)); } CATCH_STD(env, NULL); } @@ -2734,18 +2716,8 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_concatenate(JNIEnv *env, try { cudf::jni::auto_set_device(env); cudf::jni::native_jpointerArray tables(env, table_handles); - - int num_tables = tables.size(); - // There are some issues with table_view and std::vector. We cannot give the - // vector a size or it will not compile. - std::vector to_concat; - to_concat.reserve(num_tables); - for (int i = 0; i < num_tables; i++) { - JNI_NULL_CHECK(env, tables[i], "input table included a null", NULL); - to_concat.push_back(*tables[i]); - } - std::unique_ptr table_result = cudf::concatenate(to_concat); - return cudf::jni::convert_table_for_return(env, table_result); + std::vector const to_concat = tables.get_dereferenced(); + return convert_table_for_return(env, cudf::concatenate(to_concat)); } CATCH_STD(env, NULL); } @@ -2763,20 +2735,19 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_partition(JNIEnv *env, jc try { cudf::jni::auto_set_device(env); - cudf::table_view *n_input_table = reinterpret_cast(input_table); - cudf::column_view *n_part_column = reinterpret_cast(partition_column); - cudf::jni::native_jintArray n_output_offsets(env, output_offsets); + auto const n_input_table = reinterpret_cast(input_table); + auto const n_part_column = reinterpret_cast(partition_column); - auto result = cudf::partition(*n_input_table, *n_part_column, number_of_partitions); + auto [partitioned_table, partition_offsets] = + cudf::partition(*n_input_table, *n_part_column, number_of_partitions); - for (size_t i = 0; i < result.second.size() - 1; i++) { - // for what ever reason partition returns the length of the result at then - // end and hash partition/round robin do not, so skip the last entry for - // consistency - n_output_offsets[i] = result.second[i]; - } + // for what ever reason partition returns the length of the result at then + // end and hash partition/round robin do not, so skip the last entry for + // consistency + cudf::jni::native_jintArray n_output_offsets(env, output_offsets); + std::copy(partition_offsets.begin(), partition_offsets.end() - 1, n_output_offsets.begin()); - return cudf::jni::convert_table_for_return(env, result.first); + return convert_table_for_return(env, partitioned_table); } CATCH_STD(env, NULL); } @@ -2792,26 +2763,21 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_hashPartition( try { cudf::jni::auto_set_device(env); - cudf::hash_id hash_func = static_cast(hash_function); - cudf::table_view *n_input_table = reinterpret_cast(input_table); + auto const hash_func = static_cast(hash_function); + auto const n_input_table = reinterpret_cast(input_table); cudf::jni::native_jintArray n_columns_to_hash(env, columns_to_hash); - cudf::jni::native_jintArray n_output_offsets(env, output_offsets); - JNI_ARG_CHECK(env, n_columns_to_hash.size() > 0, "columns_to_hash is zero", NULL); - std::vector columns_to_hash_vec(n_columns_to_hash.size()); - for (int i = 0; i < n_columns_to_hash.size(); i++) { - columns_to_hash_vec[i] = n_columns_to_hash[i]; - } + std::vector columns_to_hash_vec(n_columns_to_hash.begin(), + n_columns_to_hash.end()); - std::pair, std::vector> result = + auto [partitioned_table, partition_offsets] = cudf::hash_partition(*n_input_table, columns_to_hash_vec, number_of_partitions, hash_func); - for (size_t i = 0; i < result.second.size(); i++) { - n_output_offsets[i] = result.second[i]; - } + cudf::jni::native_jintArray n_output_offsets(env, output_offsets); + std::copy(partition_offsets.begin(), partition_offsets.end(), n_output_offsets.begin()); - return cudf::jni::convert_table_for_return(env, result.first); + return convert_table_for_return(env, partitioned_table); } CATCH_STD(env, NULL); } @@ -2827,15 +2793,14 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_roundRobinPartition( try { cudf::jni::auto_set_device(env); auto n_input_table = reinterpret_cast(input_table); - cudf::jni::native_jintArray n_output_offsets(env, output_offsets); - auto result = cudf::round_robin_partition(*n_input_table, num_partitions, start_partition); + auto [partitioned_table, partition_offsets] = + cudf::round_robin_partition(*n_input_table, num_partitions, start_partition); - for (size_t i = 0; i < result.second.size(); i++) { - n_output_offsets[i] = result.second[i]; - } + cudf::jni::native_jintArray n_output_offsets(env, output_offsets); + std::copy(partition_offsets.begin(), partition_offsets.end(), n_output_offsets.begin()); - return cudf::jni::convert_table_for_return(env, result.first); + return convert_table_for_return(env, partitioned_table); } CATCH_STD(env, NULL); } @@ -2905,7 +2870,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_groupByAggregate( result_columns.push_back(std::move(result.second[agg_result_index].results[col_agg_index])); } } - return cudf::jni::convert_table_for_return(env, result.first, result_columns); + return convert_table_for_return(env, result.first, std::move(result_columns)); } CATCH_STD(env, NULL); } @@ -2975,7 +2940,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_groupByScan( result_columns.push_back(std::move(result.second[agg_result_index].results[col_agg_index])); } } - return cudf::jni::convert_table_for_return(env, result.first, result_columns); + return convert_table_for_return(env, result.first, std::move(result_columns)); } CATCH_STD(env, NULL); } @@ -3020,10 +2985,8 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_groupByReplaceNulls( std::vector policies = n_is_preceding.transform_if_else( cudf::replace_policy::PRECEDING, cudf::replace_policy::FOLLOWING); - std::pair, std::unique_ptr> result = - grouper.replace_nulls(n_replace_table, policies); - - return cudf::jni::convert_table_for_return(env, result.first, result.second); + auto [keys, results] = grouper.replace_nulls(n_replace_table, policies); + return convert_table_for_return(env, keys, results); } CATCH_STD(env, NULL); } @@ -3034,10 +2997,9 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_filter(JNIEnv *env, jclas JNI_NULL_CHECK(env, mask_jcol, "mask column is null", 0); try { cudf::jni::auto_set_device(env); - cudf::table_view *input = reinterpret_cast(input_jtable); - cudf::column_view *mask = reinterpret_cast(mask_jcol); - std::unique_ptr result = cudf::apply_boolean_mask(*input, *mask); - return cudf::jni::convert_table_for_return(env, result); + auto const input = reinterpret_cast(input_jtable); + auto const mask = reinterpret_cast(mask_jcol); + return convert_table_for_return(env, cudf::apply_boolean_mask(*input, *mask)); } CATCH_STD(env, 0); } @@ -3063,7 +3025,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_dropDuplicates( nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL, nulls_before ? cudf::null_order::BEFORE : cudf::null_order::AFTER, rmm::mr::get_current_device_resource()); - return cudf::jni::convert_table_for_return(env, result); + return convert_table_for_return(env, result); } CATCH_STD(env, 0); } @@ -3074,12 +3036,11 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_gather(JNIEnv *env, jclas JNI_NULL_CHECK(env, j_map, "map column is null", 0); try { cudf::jni::auto_set_device(env); - cudf::table_view *input = reinterpret_cast(j_input); - cudf::column_view *map = reinterpret_cast(j_map); + auto const input = reinterpret_cast(j_input); + auto const map = reinterpret_cast(j_map); auto bounds_policy = check_bounds ? cudf::out_of_bounds_policy::NULLIFY : cudf::out_of_bounds_policy::DONT_CHECK; - std::unique_ptr result = cudf::gather(*input, *map, bounds_policy); - return cudf::jni::convert_table_for_return(env, result); + return convert_table_for_return(env, cudf::gather(*input, *map, bounds_policy)); } CATCH_STD(env, 0); } @@ -3090,7 +3051,7 @@ Java_ai_rapids_cudf_Table_convertToRowsFixedWidthOptimized(JNIEnv *env, jclass, try { cudf::jni::auto_set_device(env); - cudf::table_view *n_input_table = reinterpret_cast(input_table); + auto const n_input_table = reinterpret_cast(input_table); std::vector> cols = cudf::jni::convert_to_rows_fixed_width_optimized(*n_input_table); int num_columns = cols.size(); @@ -3114,8 +3075,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_scatterTable(JNIEnv *env, auto const input = reinterpret_cast(j_input); auto const map = reinterpret_cast(j_map); auto const target = reinterpret_cast(j_target); - auto result = cudf::scatter(*input, *map, *target, check_bounds); - return cudf::jni::convert_table_for_return(env, result); + return convert_table_for_return(env, cudf::scatter(*input, *map, *target, check_bounds)); } CATCH_STD(env, 0); } @@ -3131,13 +3091,11 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_scatterScalars(JNIEnv *en cudf::jni::auto_set_device(env); auto const scalars_array = cudf::jni::native_jpointerArray(env, j_input); std::vector> input; - for (int i = 0; i < scalars_array.size(); ++i) { - input.emplace_back(*scalars_array[i]); - } + std::transform(scalars_array.begin(), scalars_array.end(), std::back_inserter(input), + [](auto &scalar) { return std::ref(*scalar); }); auto const map = reinterpret_cast(j_map); auto const target = reinterpret_cast(j_target); - auto result = cudf::scatter(input, *map, *target, check_bounds); - return cudf::jni::convert_table_for_return(env, result); + return convert_table_for_return(env, cudf::scatter(input, *map, *target, check_bounds)); } CATCH_STD(env, 0); } @@ -3148,7 +3106,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_convertToRows(JNIEnv *env try { cudf::jni::auto_set_device(env); - cudf::table_view *n_input_table = reinterpret_cast(input_table); + auto const n_input_table = reinterpret_cast(input_table); std::vector> cols = cudf::jni::convert_to_rows(*n_input_table); int num_columns = cols.size(); cudf::jni::native_jlongArray outcol_handles(env, num_columns); @@ -3166,8 +3124,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_convertFromRowsFixedWidth try { cudf::jni::auto_set_device(env); - cudf::column_view *input = reinterpret_cast(input_column); - cudf::lists_column_view list_input(*input); + cudf::lists_column_view const list_input{*reinterpret_cast(input_column)}; cudf::jni::native_jintArray n_types(env, types); cudf::jni::native_jintArray n_scale(env, scale); if (n_types.size() != n_scale.size()) { @@ -3179,7 +3136,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_convertFromRowsFixedWidth [](jint type, jint scale) { return cudf::jni::make_data_type(type, scale); }); std::unique_ptr result = cudf::jni::convert_from_rows_fixed_width_optimized(list_input, types_vec); - return cudf::jni::convert_table_for_return(env, result); + return convert_table_for_return(env, result); } CATCH_STD(env, 0); } @@ -3193,8 +3150,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_convertFromRows(JNIEnv *e try { cudf::jni::auto_set_device(env); - cudf::column_view *input = reinterpret_cast(input_column); - cudf::lists_column_view list_input(*input); + cudf::lists_column_view const list_input{*reinterpret_cast(input_column)}; cudf::jni::native_jintArray n_types(env, types); cudf::jni::native_jintArray n_scale(env, scale); if (n_types.size() != n_scale.size()) { @@ -3205,7 +3161,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_convertFromRows(JNIEnv *e std::transform(n_types.begin(), n_types.end(), n_scale.begin(), std::back_inserter(types_vec), [](jint type, jint scale) { return cudf::jni::make_data_type(type, scale); }); std::unique_ptr result = cudf::jni::convert_from_rows(list_input, types_vec); - return cudf::jni::convert_table_for_return(env, result); + return convert_table_for_return(env, result); } CATCH_STD(env, 0); } @@ -3216,9 +3172,8 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_repeatStaticCount(JNIEnv JNI_NULL_CHECK(env, input_jtable, "input table is null", 0); try { cudf::jni::auto_set_device(env); - cudf::table_view *input = reinterpret_cast(input_jtable); - std::unique_ptr result = cudf::repeat(*input, count); - return cudf::jni::convert_table_for_return(env, result); + auto const input = reinterpret_cast(input_jtable); + return convert_table_for_return(env, cudf::repeat(*input, count)); } CATCH_STD(env, 0); } @@ -3231,10 +3186,9 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_repeatColumnCount(JNIEnv JNI_NULL_CHECK(env, count_jcol, "count column is null", 0); try { cudf::jni::auto_set_device(env); - cudf::table_view *input = reinterpret_cast(input_jtable); - cudf::column_view *count = reinterpret_cast(count_jcol); - std::unique_ptr result = cudf::repeat(*input, *count, check_count); - return cudf::jni::convert_table_for_return(env, result); + auto const input = reinterpret_cast(input_jtable); + auto const count = reinterpret_cast(count_jcol); + return convert_table_for_return(env, cudf::repeat(*input, *count, check_count)); } CATCH_STD(env, 0); } @@ -3351,7 +3305,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_rollingWindowAggregate( } auto result_table = std::make_unique(std::move(result_columns)); - return cudf::jni::convert_table_for_return(env, result_table); + return convert_table_for_return(env, result_table); } CATCH_STD(env, NULL); } @@ -3444,7 +3398,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_rangeRollingWindowAggrega } auto result_table = std::make_unique(std::move(result_columns)); - return cudf::jni::convert_table_for_return(env, result_table); + return convert_table_for_return(env, result_table); } CATCH_STD(env, NULL); } @@ -3455,10 +3409,9 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_explode(JNIEnv *env, jcla JNI_NULL_CHECK(env, input_jtable, "explode: input table is null", 0); try { cudf::jni::auto_set_device(env); - cudf::table_view *input_table = reinterpret_cast(input_jtable); - cudf::size_type col_index = static_cast(column_index); - std::unique_ptr exploded = cudf::explode(*input_table, col_index); - return cudf::jni::convert_table_for_return(env, exploded); + auto const input_table = reinterpret_cast(input_jtable); + auto const col_index = static_cast(column_index); + return convert_table_for_return(env, cudf::explode(*input_table, col_index)); } CATCH_STD(env, 0); } @@ -3469,10 +3422,9 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_explodePosition(JNIEnv *e JNI_NULL_CHECK(env, input_jtable, "explode: input table is null", 0); try { cudf::jni::auto_set_device(env); - cudf::table_view *input_table = reinterpret_cast(input_jtable); - cudf::size_type col_index = static_cast(column_index); - std::unique_ptr exploded = cudf::explode_position(*input_table, col_index); - return cudf::jni::convert_table_for_return(env, exploded); + auto const input_table = reinterpret_cast(input_jtable); + auto const col_index = static_cast(column_index); + return convert_table_for_return(env, cudf::explode_position(*input_table, col_index)); } CATCH_STD(env, 0); } @@ -3483,10 +3435,9 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_explodeOuter(JNIEnv *env, JNI_NULL_CHECK(env, input_jtable, "explode: input table is null", 0); try { cudf::jni::auto_set_device(env); - cudf::table_view *input_table = reinterpret_cast(input_jtable); - cudf::size_type col_index = static_cast(column_index); - std::unique_ptr exploded = cudf::explode_outer(*input_table, col_index); - return cudf::jni::convert_table_for_return(env, exploded); + auto const input_table = reinterpret_cast(input_jtable); + auto const col_index = static_cast(column_index); + return convert_table_for_return(env, cudf::explode_outer(*input_table, col_index)); } CATCH_STD(env, 0); } @@ -3497,10 +3448,9 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_explodeOuterPosition(JNIE JNI_NULL_CHECK(env, input_jtable, "explode: input table is null", 0); try { cudf::jni::auto_set_device(env); - cudf::table_view *input_table = reinterpret_cast(input_jtable); - cudf::size_type col_index = static_cast(column_index); - std::unique_ptr exploded = cudf::explode_outer_position(*input_table, col_index); - return cudf::jni::convert_table_for_return(env, exploded); + auto const input_table = reinterpret_cast(input_jtable); + auto const col_index = static_cast(column_index); + return convert_table_for_return(env, cudf::explode_outer_position(*input_table, col_index)); } CATCH_STD(env, 0); } @@ -3509,8 +3459,8 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_rowBitCount(JNIEnv *env, jclas JNI_NULL_CHECK(env, j_table, "table is null", 0); try { cudf::jni::auto_set_device(env); - auto t = reinterpret_cast(j_table); - return release_as_jlong(cudf::row_bit_count(*t)); + auto const input_table = reinterpret_cast(j_table); + return release_as_jlong(cudf::row_bit_count(*input_table)); } CATCH_STD(env, 0); } @@ -3528,7 +3478,7 @@ JNIEXPORT jobjectArray JNICALL Java_ai_rapids_cudf_Table_contiguousSplitGroups( try { cudf::jni::auto_set_device(env); cudf::jni::native_jintArray n_key_indices(env, jkey_indices); - cudf::table_view *input_table = reinterpret_cast(jinput_table); + auto const input_table = reinterpret_cast(jinput_table); // Prepares arguments for the groupby: // (keys, null_handling, keys_are_sorted, column_order, null_precedence) @@ -3622,11 +3572,10 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_sample(JNIEnv *env, jclas JNI_NULL_CHECK(env, j_input, "input table is null", 0); try { cudf::jni::auto_set_device(env); - cudf::table_view *input = reinterpret_cast(j_input); + auto const input = reinterpret_cast(j_input); auto sample_with_replacement = replacement ? cudf::sample_with_replacement::TRUE : cudf::sample_with_replacement::FALSE; - std::unique_ptr result = cudf::sample(*input, n, sample_with_replacement, seed); - return cudf::jni::convert_table_for_return(env, result); + return convert_table_for_return(env, cudf::sample(*input, n, sample_with_replacement, seed)); } CATCH_STD(env, 0); } diff --git a/java/src/main/native/src/cudf_jni_apis.hpp b/java/src/main/native/src/cudf_jni_apis.hpp index fbcca0c82ee..12fd45b831a 100644 --- a/java/src/main/native/src/cudf_jni_apis.hpp +++ b/java/src/main/native/src/cudf_jni_apis.hpp @@ -23,7 +23,28 @@ namespace cudf { namespace jni { -jlongArray convert_table_for_return(JNIEnv *env, std::unique_ptr &table_result); +/** + * @brief Detach all columns from the specified table, and pointers to them as an array. + * + * This function takes a table (presumably returned by some operation), and turns it into an + * array of column* (as jlongs). + * The lifetime of the columns is decoupled from that of the table, and is managed by the caller. + * + * @param env The JNI environment + * @param table_result the table to convert for return + * @param extra_columns columns not in the table that will be appended to the result. + */ +jlongArray +convert_table_for_return(JNIEnv *env, std::unique_ptr &table_result, + std::vector> &&extra_columns = {}); + +/** + * @copydoc convert_table_for_return(JNIEnv*, std::unique_ptr&, + * std::vector>&&) + */ +jlongArray +convert_table_for_return(JNIEnv *env, std::unique_ptr &&table_result, + std::vector> &&extra_columns = {}); // // ContiguousTable APIs