diff --git a/cpp/benchmarks/CMakeLists.txt b/cpp/benchmarks/CMakeLists.txt index 6f67cb32b0a..c5ae3345da5 100644 --- a/cpp/benchmarks/CMakeLists.txt +++ b/cpp/benchmarks/CMakeLists.txt @@ -169,7 +169,10 @@ ConfigureNVBench(SEARCH_NVBENCH search/contains.cpp) # ################################################################################################## # * sort benchmark -------------------------------------------------------------------------------- ConfigureBench(SORT_BENCH sort/rank.cpp sort/sort.cpp sort/sort_strings.cpp) -ConfigureNVBench(SORT_NVBENCH sort/segmented_sort.cpp sort/sort_lists.cpp sort/sort_structs.cpp) +ConfigureNVBench( + SORT_NVBENCH sort/rank_lists.cpp sort/rank_structs.cpp sort/segmented_sort.cpp + sort/sort_lists.cpp sort/sort_structs.cpp +) # ################################################################################################## # * quantiles benchmark diff --git a/cpp/benchmarks/sort/nested_types_common.hpp b/cpp/benchmarks/sort/nested_types_common.hpp new file mode 100644 index 00000000000..c4851823534 --- /dev/null +++ b/cpp/benchmarks/sort/nested_types_common.hpp @@ -0,0 +1,85 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include + +#include + +#include + +inline std::unique_ptr create_lists_data(nvbench::state& state) +{ + const size_t size_bytes(state.get_int64("size_bytes")); + const cudf::size_type depth{static_cast(state.get_int64("depth"))}; + auto const null_frequency{state.get_float64("null_frequency")}; + + data_profile table_profile; + table_profile.set_distribution_params(cudf::type_id::LIST, distribution_id::UNIFORM, 0, 5); + table_profile.set_list_depth(depth); + table_profile.set_null_probability(null_frequency); + return create_random_table({cudf::type_id::LIST}, table_size_bytes{size_bytes}, table_profile); +} + +inline std::unique_ptr create_structs_data(nvbench::state& state, + cudf::size_type const n_cols = 1) +{ + using Type = int; + using column_wrapper = cudf::test::fixed_width_column_wrapper; + std::default_random_engine generator; + std::uniform_int_distribution distribution(0, 100); + + const cudf::size_type n_rows{static_cast(state.get_int64("NumRows"))}; + const cudf::size_type depth{static_cast(state.get_int64("Depth"))}; + const bool nulls{static_cast(state.get_int64("Nulls"))}; + + // Create columns with values in the range [0,100) + std::vector columns; + columns.reserve(n_cols); + std::generate_n(std::back_inserter(columns), n_cols, [&]() { + auto const elements = cudf::detail::make_counting_transform_iterator( + 0, [&](auto row) { return distribution(generator); }); + if (!nulls) return column_wrapper(elements, elements + n_rows); + auto valids = + cudf::detail::make_counting_transform_iterator(0, [](auto i) { return i % 10 != 0; }); + return column_wrapper(elements, elements + n_rows, valids); + }); + + std::vector> cols; + std::transform(columns.begin(), columns.end(), std::back_inserter(cols), [](column_wrapper& col) { + return col.release(); + }); + + std::vector> child_cols = std::move(cols); + // Nest the child columns in a struct, then nest that struct column inside another + // struct column up to the desired depth + for (int i = 0; i < depth; i++) { + std::vector struct_validity; + std::uniform_int_distribution bool_distribution(0, 100 * (i + 1)); + std::generate_n( + std::back_inserter(struct_validity), n_rows, [&]() { return bool_distribution(generator); }); + cudf::test::structs_column_wrapper struct_col(std::move(child_cols), struct_validity); + child_cols = std::vector>{}; + child_cols.push_back(struct_col.release()); + } + + // Create table view + return std::make_unique(std::move(child_cols)); +} diff --git a/cpp/benchmarks/sort/rank.cpp b/cpp/benchmarks/sort/rank.cpp index 2c26f4fa15d..6d0a8e5aedd 100644 --- a/cpp/benchmarks/sort/rank.cpp +++ b/cpp/benchmarks/sort/rank.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -33,7 +33,7 @@ static void BM_rank(benchmark::State& state, bool nulls) // Create columns with values in the range [0,100) data_profile profile = data_profile_builder().cardinality(0).distribution( cudf::type_to_id(), distribution_id::UNIFORM, 0, 100); - profile.set_null_probability(nulls ? std::optional{0.01} : std::nullopt); + profile.set_null_probability(nulls ? std::optional{0.2} : std::nullopt); auto keys = create_random_column(cudf::type_to_id(), row_count{n_rows}, profile); for (auto _ : state) { diff --git a/cpp/benchmarks/sort/rank_lists.cpp b/cpp/benchmarks/sort/rank_lists.cpp new file mode 100644 index 00000000000..f467b639810 --- /dev/null +++ b/cpp/benchmarks/sort/rank_lists.cpp @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nested_types_common.hpp" +#include "rank_types_common.hpp" + +#include + +#include + +#include + +template +void nvbench_rank_lists(nvbench::state& state, nvbench::type_list>) +{ + cudf::rmm_pool_raii pool_raii; + + auto const table = create_lists_data(state); + + auto const null_frequency{state.get_float64("null_frequency")}; + + state.exec(nvbench::exec_tag::sync, [&](nvbench::launch& launch) { + cudf::rank(table->view().column(0), + method, + cudf::order::ASCENDING, + null_frequency ? cudf::null_policy::INCLUDE : cudf::null_policy::EXCLUDE, + cudf::null_order::AFTER, + rmm::mr::get_current_device_resource()); + }); +} + +NVBENCH_BENCH_TYPES(nvbench_rank_lists, NVBENCH_TYPE_AXES(methods)) + .set_name("rank_lists") + .add_int64_power_of_two_axis("size_bytes", {10, 18, 24, 28}) + .add_int64_axis("depth", {1, 4}) + .add_float64_axis("null_frequency", {0, 0.2}); diff --git a/cpp/benchmarks/sort/rank_structs.cpp b/cpp/benchmarks/sort/rank_structs.cpp new file mode 100644 index 00000000000..c1e2c5bd7dc --- /dev/null +++ b/cpp/benchmarks/sort/rank_structs.cpp @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nested_types_common.hpp" +#include "rank_types_common.hpp" + +#include + +#include + +template +void nvbench_rank_structs(nvbench::state& state, nvbench::type_list>) +{ + cudf::rmm_pool_raii pool_raii; + + auto const table = create_structs_data(state); + + const bool nulls{static_cast(state.get_int64("Nulls"))}; + + state.exec(nvbench::exec_tag::sync, [&](nvbench::launch& launch) { + cudf::rank(table->view().column(0), + method, + cudf::order::ASCENDING, + nulls ? cudf::null_policy::INCLUDE : cudf::null_policy::EXCLUDE, + cudf::null_order::AFTER, + rmm::mr::get_current_device_resource()); + }); +} + +NVBENCH_BENCH_TYPES(nvbench_rank_structs, NVBENCH_TYPE_AXES(methods)) + .set_name("rank_structs") + .add_int64_power_of_two_axis("NumRows", {10, 18, 26}) + .add_int64_axis("Depth", {0, 1, 8}) + .add_int64_axis("Nulls", {0, 1}); diff --git a/cpp/benchmarks/sort/rank_types_common.hpp b/cpp/benchmarks/sort/rank_types_common.hpp new file mode 100644 index 00000000000..adb58606c42 --- /dev/null +++ b/cpp/benchmarks/sort/rank_types_common.hpp @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include + +enum class rank_method : int32_t {}; + +NVBENCH_DECLARE_ENUM_TYPE_STRINGS( + cudf::rank_method, + [](cudf::rank_method value) { + switch (value) { + case cudf::rank_method::FIRST: return "FIRST"; + case cudf::rank_method::AVERAGE: return "AVERAGE"; + case cudf::rank_method::MIN: return "MIN"; + case cudf::rank_method::MAX: return "MAX"; + case cudf::rank_method::DENSE: return "DENSE"; + default: return "unknown"; + } + }, + [](cudf::rank_method value) { + switch (value) { + case cudf::rank_method::FIRST: return "cudf::rank_method::FIRST"; + case cudf::rank_method::AVERAGE: return "cudf::rank_method::AVERAGE"; + case cudf::rank_method::MIN: return "cudf::rank_method::MIN"; + case cudf::rank_method::MAX: return "cudf::rank_method::MAX"; + case cudf::rank_method::DENSE: return "cudf::rank_method::DENSE"; + default: return "unknown"; + } + }) + +using methods = nvbench::enum_type_list; diff --git a/cpp/benchmarks/sort/sort_lists.cpp b/cpp/benchmarks/sort/sort_lists.cpp index d2124874e7a..b55b60f5ec9 100644 --- a/cpp/benchmarks/sort/sort_lists.cpp +++ b/cpp/benchmarks/sort/sort_lists.cpp @@ -14,8 +14,7 @@ * limitations under the License. */ -#include -#include +#include "nested_types_common.hpp" #include @@ -23,16 +22,7 @@ void nvbench_sort_lists(nvbench::state& state) { - const size_t size_bytes(state.get_int64("size_bytes")); - const cudf::size_type depth{static_cast(state.get_int64("depth"))}; - auto const null_frequency{state.get_float64("null_frequency")}; - - data_profile table_profile; - table_profile.set_distribution_params(cudf::type_id::LIST, distribution_id::UNIFORM, 0, 5); - table_profile.set_list_depth(depth); - table_profile.set_null_probability(null_frequency); - auto const table = - create_random_table({cudf::type_id::LIST}, table_size_bytes{size_bytes}, table_profile); + auto const table = create_lists_data(state); state.exec(nvbench::exec_tag::sync, [&](nvbench::launch& launch) { rmm::cuda_stream_view stream_view{launch.get_stream()}; diff --git a/cpp/benchmarks/sort/sort_structs.cpp b/cpp/benchmarks/sort/sort_structs.cpp index 02af37e47f1..1d54fa42f6f 100644 --- a/cpp/benchmarks/sort/sort_structs.cpp +++ b/cpp/benchmarks/sort/sort_structs.cpp @@ -14,63 +14,19 @@ * limitations under the License. */ -#include - -#include +#include "nested_types_common.hpp" #include #include -#include - void nvbench_sort_struct(nvbench::state& state) { - using Type = int; - using column_wrapper = cudf::test::fixed_width_column_wrapper; - std::default_random_engine generator; - std::uniform_int_distribution distribution(0, 100); - - const cudf::size_type n_rows{static_cast(state.get_int64("NumRows"))}; - const cudf::size_type n_cols{1}; - const cudf::size_type depth{static_cast(state.get_int64("Depth"))}; - const bool nulls{static_cast(state.get_int64("Nulls"))}; - - // Create columns with values in the range [0,100) - std::vector columns; - columns.reserve(n_cols); - std::generate_n(std::back_inserter(columns), n_cols, [&]() { - auto const elements = cudf::detail::make_counting_transform_iterator( - 0, [&](auto row) { return distribution(generator); }); - if (!nulls) return column_wrapper(elements, elements + n_rows); - auto valids = - cudf::detail::make_counting_transform_iterator(0, [](auto i) { return i % 10 != 0; }); - return column_wrapper(elements, elements + n_rows, valids); - }); - - std::vector> cols; - std::transform(columns.begin(), columns.end(), std::back_inserter(cols), [](column_wrapper& col) { - return col.release(); - }); - - std::vector> child_cols = std::move(cols); - // Lets add some layers - for (int i = 0; i < depth; i++) { - std::vector struct_validity; - std::uniform_int_distribution bool_distribution(0, 100 * (i + 1)); - std::generate_n( - std::back_inserter(struct_validity), n_rows, [&]() { return bool_distribution(generator); }); - cudf::test::structs_column_wrapper struct_col(std::move(child_cols), struct_validity); - child_cols = std::vector>{}; - child_cols.push_back(struct_col.release()); - } - - // Create table view - auto const input = cudf::table(std::move(child_cols)); + auto const input = create_structs_data(state); state.exec(nvbench::exec_tag::sync, [&](nvbench::launch& launch) { rmm::cuda_stream_view stream_view{launch.get_stream()}; - cudf::detail::sorted_order(input, {}, {}, stream_view, rmm::mr::get_current_device_resource()); + cudf::detail::sorted_order(*input, {}, {}, stream_view, rmm::mr::get_current_device_resource()); }); } diff --git a/cpp/src/sort/rank.cu b/cpp/src/sort/rank.cu index 99e99704c10..461e978643f 100644 --- a/cpp/src/sort/rank.cu +++ b/cpp/src/sort/rank.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,7 +21,7 @@ #include #include #include -#include +#include #include #include #include @@ -47,37 +47,26 @@ namespace cudf { namespace detail { namespace { -// Functor to identify unique elements in a sorted order table/column -template -struct unique_comparator { - unique_comparator(table_device_view device_table, Iterator const sorted_order, bool has_nulls) - : comparator(nullate::DYNAMIC{has_nulls}, device_table, device_table, null_equality::EQUAL), - permute(sorted_order) - { - } - __device__ ReturnType operator()(size_type index) const noexcept - { - return index == 0 || not comparator(permute[index], permute[index - 1]); - }; - - private: - row_equality_comparator comparator; - Iterator const permute; -}; // Assign rank from 1 to n unique values. Equal values get same rank value. rmm::device_uvector sorted_dense_rank(column_view input_col, column_view sorted_order_view, rmm::cuda_stream_view stream) { - auto device_table = table_device_view::create(table_view{{input_col}}, stream); + auto const t_input = table_view{{input_col}}; + auto const comparator = cudf::experimental::row::equality::self_comparator{t_input, stream}; + auto const device_comparator = comparator.equal_to(nullate::DYNAMIC{has_nested_nulls(t_input)}); + + auto const sorted_index_order = thrust::make_permutation_iterator( + sorted_order_view.begin(), thrust::make_counting_iterator(0)); + auto conv = [permute = sorted_index_order, device_comparator] __device__(size_type index) { + return static_cast(index == 0 || + not device_comparator(permute[index], permute[index - 1])); + }; + auto const unique_it = cudf::detail::make_counting_transform_iterator(0, conv); + auto const input_size = input_col.size(); rmm::device_uvector dense_rank_sorted(input_size, stream); - auto sorted_index_order = thrust::make_permutation_iterator( - sorted_order_view.begin(), thrust::make_counting_iterator(0)); - auto conv = unique_comparator( - *device_table, sorted_index_order, input_col.has_nulls()); - auto unique_it = cudf::detail::make_counting_transform_iterator(0, conv); thrust::inclusive_scan( rmm::exec_policy(stream), unique_it, unique_it + input_size, dense_rank_sorted.data()); diff --git a/cpp/tests/sort/rank_test.cpp b/cpp/tests/sort/rank_test.cpp index 8461b0a1984..2722c1dfdad 100644 --- a/cpp/tests/sort/rank_test.cpp +++ b/cpp/tests/sort/rank_test.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include @@ -30,6 +31,13 @@ #include #include +template +using lists_col = cudf::test::lists_column_wrapper; +using structs_col = cudf::test::structs_column_wrapper; + +using cudf::test::iterators::null_at; +using cudf::test::iterators::nulls_at; + namespace { void run_rank_test(cudf::table_view input, cudf::table_view expected, @@ -50,10 +58,9 @@ void run_rank_test(cudf::table_view input, } using input_arg_t = std::tuple; -input_arg_t asce_keep{cudf::order::ASCENDING, cudf::null_policy::EXCLUDE, cudf::null_order::AFTER}; -input_arg_t asce_top{cudf::order::ASCENDING, cudf::null_policy::INCLUDE, cudf::null_order::BEFORE}; -input_arg_t asce_bottom{ - cudf::order::ASCENDING, cudf::null_policy::INCLUDE, cudf::null_order::AFTER}; +input_arg_t asc_keep{cudf::order::ASCENDING, cudf::null_policy::EXCLUDE, cudf::null_order::AFTER}; +input_arg_t asc_top{cudf::order::ASCENDING, cudf::null_policy::INCLUDE, cudf::null_order::BEFORE}; +input_arg_t asc_bottom{cudf::order::ASCENDING, cudf::null_policy::INCLUDE, cudf::null_order::AFTER}; input_arg_t desc_keep{ cudf::order::DESCENDING, cudf::null_policy::EXCLUDE, cudf::null_order::BEFORE}; @@ -105,7 +112,7 @@ TYPED_TEST_SUITE(Rank, cudf::test::NumericTypes); // fixed_width_column_wrapper col1{{ 5, 4, 3, 5, 8, 5}}; // 3, 2, 1, 4, 6, 5 -TYPED_TEST(Rank, first_asce_keep) +TYPED_TEST(Rank, first_asc_keep) { // ASCENDING cudf::test::fixed_width_column_wrapper col1_rank{{3, 2, 1, 4, 6, 5}}; @@ -113,25 +120,25 @@ TYPED_TEST(Rank, first_asce_keep) {1, 1, 0, 1, 1, 1}}; // KEEP cudf::test::fixed_width_column_wrapper col3_rank{{2, 5, 1, 3, 6, 4}, {1, 1, 1, 1, 1, 1}}; - this->run_all_tests(cudf::rank_method::FIRST, asce_keep, col1_rank, col2_rank, col3_rank); + this->run_all_tests(cudf::rank_method::FIRST, asc_keep, col1_rank, col2_rank, col3_rank); } -TYPED_TEST(Rank, first_asce_top) +TYPED_TEST(Rank, first_asc_top) { cudf::test::fixed_width_column_wrapper col1_rank{{3, 2, 1, 4, 6, 5}}; cudf::test::fixed_width_column_wrapper col2_rank{ {3, 2, 1, 4, 6, 5}}; // BEFORE = TOP cudf::test::fixed_width_column_wrapper col3_rank{{2, 5, 1, 3, 6, 4}}; - this->run_all_tests(cudf::rank_method::FIRST, asce_top, col1_rank, col2_rank, col3_rank); + this->run_all_tests(cudf::rank_method::FIRST, asc_top, col1_rank, col2_rank, col3_rank); } -TYPED_TEST(Rank, first_asce_bottom) +TYPED_TEST(Rank, first_asc_bottom) { cudf::test::fixed_width_column_wrapper col1_rank{{3, 2, 1, 4, 6, 5}}; cudf::test::fixed_width_column_wrapper col2_rank{ {2, 1, 6, 3, 5, 4}}; // AFTER = BOTTOM cudf::test::fixed_width_column_wrapper col3_rank{{2, 5, 1, 3, 6, 4}}; - this->run_all_tests(cudf::rank_method::FIRST, asce_bottom, col1_rank, col2_rank, col3_rank); + this->run_all_tests(cudf::rank_method::FIRST, asc_bottom, col1_rank, col2_rank, col3_rank); } TYPED_TEST(Rank, first_desc_keep) @@ -163,30 +170,30 @@ TYPED_TEST(Rank, first_desc_bottom) this->run_all_tests(cudf::rank_method::FIRST, desc_bottom, col1_rank, col2_rank, col3_rank); } -TYPED_TEST(Rank, dense_asce_keep) +TYPED_TEST(Rank, dense_asc_keep) { cudf::test::fixed_width_column_wrapper col1_rank{{3, 2, 1, 3, 4, 3}}; cudf::test::fixed_width_column_wrapper col2_rank{{2, 1, -1, 2, 3, 2}, {1, 1, 0, 1, 1, 1}}; cudf::test::fixed_width_column_wrapper col3_rank{{2, 3, 1, 2, 4, 2}, {1, 1, 1, 1, 1, 1}}; - this->run_all_tests(cudf::rank_method::DENSE, asce_keep, col1_rank, col2_rank, col3_rank); + this->run_all_tests(cudf::rank_method::DENSE, asc_keep, col1_rank, col2_rank, col3_rank); } -TYPED_TEST(Rank, dense_asce_top) +TYPED_TEST(Rank, dense_asc_top) { cudf::test::fixed_width_column_wrapper col1_rank{{3, 2, 1, 3, 4, 3}}; cudf::test::fixed_width_column_wrapper col2_rank{{3, 2, 1, 3, 4, 3}}; cudf::test::fixed_width_column_wrapper col3_rank{{2, 3, 1, 2, 4, 2}}; - this->run_all_tests(cudf::rank_method::DENSE, asce_top, col1_rank, col2_rank, col3_rank); + this->run_all_tests(cudf::rank_method::DENSE, asc_top, col1_rank, col2_rank, col3_rank); } -TYPED_TEST(Rank, dense_asce_bottom) +TYPED_TEST(Rank, dense_asc_bottom) { cudf::test::fixed_width_column_wrapper col1_rank{{3, 2, 1, 3, 4, 3}}; cudf::test::fixed_width_column_wrapper col2_rank{{2, 1, 4, 2, 3, 2}}; cudf::test::fixed_width_column_wrapper col3_rank{{2, 3, 1, 2, 4, 2}}; - this->run_all_tests(cudf::rank_method::DENSE, asce_bottom, col1_rank, col2_rank, col3_rank); + this->run_all_tests(cudf::rank_method::DENSE, asc_bottom, col1_rank, col2_rank, col3_rank); } TYPED_TEST(Rank, dense_desc_keep) @@ -215,30 +222,30 @@ TYPED_TEST(Rank, dense_desc_bottom) this->run_all_tests(cudf::rank_method::DENSE, desc_bottom, col1_rank, col2_rank, col3_rank); } -TYPED_TEST(Rank, min_asce_keep) +TYPED_TEST(Rank, min_asc_keep) { cudf::test::fixed_width_column_wrapper col1_rank{{3, 2, 1, 3, 6, 3}}; cudf::test::fixed_width_column_wrapper col2_rank{{2, 1, -1, 2, 5, 2}, {1, 1, 0, 1, 1, 1}}; cudf::test::fixed_width_column_wrapper col3_rank{{2, 5, 1, 2, 6, 2}, {1, 1, 1, 1, 1, 1}}; - this->run_all_tests(cudf::rank_method::MIN, asce_keep, col1_rank, col2_rank, col3_rank); + this->run_all_tests(cudf::rank_method::MIN, asc_keep, col1_rank, col2_rank, col3_rank); } -TYPED_TEST(Rank, min_asce_top) +TYPED_TEST(Rank, min_asc_top) { cudf::test::fixed_width_column_wrapper col1_rank{{3, 2, 1, 3, 6, 3}}; cudf::test::fixed_width_column_wrapper col2_rank{{3, 2, 1, 3, 6, 3}}; cudf::test::fixed_width_column_wrapper col3_rank{{2, 5, 1, 2, 6, 2}}; - this->run_all_tests(cudf::rank_method::MIN, asce_top, col1_rank, col2_rank, col3_rank); + this->run_all_tests(cudf::rank_method::MIN, asc_top, col1_rank, col2_rank, col3_rank); } -TYPED_TEST(Rank, min_asce_bottom) +TYPED_TEST(Rank, min_asc_bottom) { cudf::test::fixed_width_column_wrapper col1_rank{{3, 2, 1, 3, 6, 3}}; cudf::test::fixed_width_column_wrapper col2_rank{{2, 1, 6, 2, 5, 2}}; cudf::test::fixed_width_column_wrapper col3_rank{{2, 5, 1, 2, 6, 2}}; - this->run_all_tests(cudf::rank_method::MIN, asce_bottom, col1_rank, col2_rank, col3_rank); + this->run_all_tests(cudf::rank_method::MIN, asc_bottom, col1_rank, col2_rank, col3_rank); } TYPED_TEST(Rank, min_desc_keep) @@ -267,30 +274,30 @@ TYPED_TEST(Rank, min_desc_bottom) this->run_all_tests(cudf::rank_method::MIN, desc_bottom, col1_rank, col2_rank, col3_rank); } -TYPED_TEST(Rank, max_asce_keep) +TYPED_TEST(Rank, max_asc_keep) { cudf::test::fixed_width_column_wrapper col1_rank{{5, 2, 1, 5, 6, 5}}; cudf::test::fixed_width_column_wrapper col2_rank{{4, 1, -1, 4, 5, 4}, {1, 1, 0, 1, 1, 1}}; cudf::test::fixed_width_column_wrapper col3_rank{{4, 5, 1, 4, 6, 4}, {1, 1, 1, 1, 1, 1}}; - this->run_all_tests(cudf::rank_method::MAX, asce_keep, col1_rank, col2_rank, col3_rank); + this->run_all_tests(cudf::rank_method::MAX, asc_keep, col1_rank, col2_rank, col3_rank); } -TYPED_TEST(Rank, max_asce_top) +TYPED_TEST(Rank, max_asc_top) { cudf::test::fixed_width_column_wrapper col1_rank{{5, 2, 1, 5, 6, 5}}; cudf::test::fixed_width_column_wrapper col2_rank{{5, 2, 1, 5, 6, 5}}; cudf::test::fixed_width_column_wrapper col3_rank{{4, 5, 1, 4, 6, 4}}; - this->run_all_tests(cudf::rank_method::MAX, asce_top, col1_rank, col2_rank, col3_rank); + this->run_all_tests(cudf::rank_method::MAX, asc_top, col1_rank, col2_rank, col3_rank); } -TYPED_TEST(Rank, max_asce_bottom) +TYPED_TEST(Rank, max_asc_bottom) { cudf::test::fixed_width_column_wrapper col1_rank{{5, 2, 1, 5, 6, 5}}; cudf::test::fixed_width_column_wrapper col2_rank{{4, 1, 6, 4, 5, 4}}; cudf::test::fixed_width_column_wrapper col3_rank{{4, 5, 1, 4, 6, 4}}; - this->run_all_tests(cudf::rank_method::MAX, asce_bottom, col1_rank, col2_rank, col3_rank); + this->run_all_tests(cudf::rank_method::MAX, asc_bottom, col1_rank, col2_rank, col3_rank); } TYPED_TEST(Rank, max_desc_keep) @@ -319,28 +326,28 @@ TYPED_TEST(Rank, max_desc_bottom) this->run_all_tests(cudf::rank_method::MAX, desc_bottom, col1_rank, col2_rank, col3_rank); } -TYPED_TEST(Rank, average_asce_keep) +TYPED_TEST(Rank, average_asc_keep) { cudf::test::fixed_width_column_wrapper col1_rank{{4, 2, 1, 4, 6, 4}}; cudf::test::fixed_width_column_wrapper col2_rank{{3, 1, -1, 3, 5, 3}, {1, 1, 0, 1, 1, 1}}; cudf::test::fixed_width_column_wrapper col3_rank{{3, 5, 1, 3, 6, 3}, {1, 1, 1, 1, 1, 1}}; - this->run_all_tests(cudf::rank_method::AVERAGE, asce_keep, col1_rank, col2_rank, col3_rank); + this->run_all_tests(cudf::rank_method::AVERAGE, asc_keep, col1_rank, col2_rank, col3_rank); } -TYPED_TEST(Rank, average_asce_top) +TYPED_TEST(Rank, average_asc_top) { cudf::test::fixed_width_column_wrapper col1_rank{{4, 2, 1, 4, 6, 4}}; cudf::test::fixed_width_column_wrapper col2_rank{{4, 2, 1, 4, 6, 4}}; cudf::test::fixed_width_column_wrapper col3_rank{{3, 5, 1, 3, 6, 3}}; - this->run_all_tests(cudf::rank_method::AVERAGE, asce_top, col1_rank, col2_rank, col3_rank); + this->run_all_tests(cudf::rank_method::AVERAGE, asc_top, col1_rank, col2_rank, col3_rank); } -TYPED_TEST(Rank, average_asce_bottom) +TYPED_TEST(Rank, average_asc_bottom) { cudf::test::fixed_width_column_wrapper col1_rank{{4, 2, 1, 4, 6, 4}}; cudf::test::fixed_width_column_wrapper col2_rank{{3, 1, 6, 3, 5, 3}}; cudf::test::fixed_width_column_wrapper col3_rank{{3, 5, 1, 3, 6, 3}}; - this->run_all_tests(cudf::rank_method::AVERAGE, asce_bottom, col1_rank, col2_rank, col3_rank); + this->run_all_tests(cudf::rank_method::AVERAGE, asc_bottom, col1_rank, col2_rank, col3_rank); } TYPED_TEST(Rank, average_desc_keep) @@ -368,30 +375,30 @@ TYPED_TEST(Rank, average_desc_bottom) } // percentage==true (dense, not-dense) -TYPED_TEST(Rank, dense_asce_keep_pct) +TYPED_TEST(Rank, dense_asc_keep_pct) { cudf::test::fixed_width_column_wrapper col1_rank{{0.75, 0.5, 0.25, 0.75, 1., 0.75}}; cudf::test::fixed_width_column_wrapper col2_rank{ {2.0 / 3.0, 1.0 / 3.0, -1., 2.0 / 3.0, 1., 2.0 / 3.0}, {1, 1, 0, 1, 1, 1}}; cudf::test::fixed_width_column_wrapper col3_rank{{0.5, 0.75, 0.25, 0.5, 1., 0.5}, {1, 1, 1, 1, 1, 1}}; - this->run_all_tests(cudf::rank_method::DENSE, asce_keep, col1_rank, col2_rank, col3_rank, true); + this->run_all_tests(cudf::rank_method::DENSE, asc_keep, col1_rank, col2_rank, col3_rank, true); } -TYPED_TEST(Rank, dense_asce_top_pct) +TYPED_TEST(Rank, dense_asc_top_pct) { cudf::test::fixed_width_column_wrapper col1_rank{{0.75, 0.5, 0.25, 0.75, 1., 0.75}}; cudf::test::fixed_width_column_wrapper col2_rank{{0.75, 0.5, 0.25, 0.75, 1., 0.75}}; cudf::test::fixed_width_column_wrapper col3_rank{{0.5, 0.75, 0.25, 0.5, 1., 0.5}}; - this->run_all_tests(cudf::rank_method::DENSE, asce_top, col1_rank, col2_rank, col3_rank, true); + this->run_all_tests(cudf::rank_method::DENSE, asc_top, col1_rank, col2_rank, col3_rank, true); } -TYPED_TEST(Rank, dense_asce_bottom_pct) +TYPED_TEST(Rank, dense_asc_bottom_pct) { cudf::test::fixed_width_column_wrapper col1_rank{{0.75, 0.5, 0.25, 0.75, 1., 0.75}}; cudf::test::fixed_width_column_wrapper col2_rank{{0.5, 0.25, 1., 0.5, 0.75, 0.5}}; cudf::test::fixed_width_column_wrapper col3_rank{{0.5, 0.75, 0.25, 0.5, 1., 0.5}}; - this->run_all_tests(cudf::rank_method::DENSE, asce_bottom, col1_rank, col2_rank, col3_rank, true); + this->run_all_tests(cudf::rank_method::DENSE, asc_bottom, col1_rank, col2_rank, col3_rank, true); } TYPED_TEST(Rank, min_desc_keep_pct) @@ -444,3 +451,472 @@ TEST_F(RankLarge, average_large) cudf::test::fixed_width_column_wrapper expected(iter + 1, iter + 10559); CUDF_TEST_EXPECT_COLUMNS_EQUAL(result->view(), expected); } + +template +struct RankListAndStruct : public cudf::test::BaseFixture { + void run_all_tests(cudf::rank_method method, + input_arg_t input_arg, + cudf::column_view const list_rank, + cudf::column_view const struct_rank, + bool percentage = false) + { + if constexpr (std::is_same_v) { return; } + /* + [ + [], + [1], + [2, 2], + [2, 3], + [2, 2], + [1], + [], + NULL + [2], + NULL, + [1] + ] + */ + auto list_col = + lists_col{{{}, {1}, {2, 2}, {2, 3}, {2, 2}, {1}, {}, {} /*NULL*/, {2}, {} /*NULL*/, {1}}, + nulls_at({7, 9})}; + + // clang-format off + /* + +------------+ + | s| + +------------+ + 0 | {0, null}| + 1 | {1, null}| + 2 | null| + 3 |{null, null}| + 4 | null| + 5 |{null, null}| + 6 | {null, 1}| + 7 | {null, 0}| + +------------+ + */ + std::vector struct_valids{1, 1, 0, 1, 0, 1, 1, 1}; + auto col1 = cudf::test::fixed_width_column_wrapper{{ 0, 1, 9, -1, 9, -1, -1, -1}, {1, 1, 1, 0, 1, 0, 0, 0}}; + auto col2 = cudf::test::fixed_width_column_wrapper{{-1, -1, 9, -1, 9, -1, 1, 0}, {0, 0, 1, 0, 1, 0, 1, 1}}; + auto struct_col = cudf::test::structs_column_wrapper{{col1, col2}, struct_valids}.release(); + // clang-format on + + for (auto const& test_case : { + // Non-null column + test_case_t{cudf::table_view{{list_col}}, cudf::table_view{{list_rank}}}, + // Null column + test_case_t{cudf::table_view{{struct_col->view()}}, cudf::table_view{{struct_rank}}}, + }) { + auto [input, output] = test_case; + + run_rank_test(input, + output, + method, + std::get<0>(input_arg), + std::get<1>(input_arg), + std::get<2>(input_arg), + percentage); + } + } +}; + +TYPED_TEST_SUITE(RankListAndStruct, cudf::test::NumericTypes); + +TYPED_TEST(RankListAndStruct, first_asc_keep) +{ + // ASCENDING and null_order::AFTER + cudf::test::fixed_width_column_wrapper list_rank{ + {1, 3, 7, 9, 8, 4, 2, -1, 6, -1, 5}, nulls_at({7, 9})}; + cudf::test::fixed_width_column_wrapper struct_rank{{1, 2, -1, 5, -1, 6, 4, 3}, + nulls_at({2, 4})}; + this->run_all_tests(cudf::rank_method::FIRST, asc_keep, list_rank, struct_rank); +} + +TYPED_TEST(RankListAndStruct, first_asc_top) +{ + // ASCENDING and null_order::BEFORE + cudf::test::fixed_width_column_wrapper col_rank{ + 3, 5, 9, 11, 10, 6, 4, 1, 8, 2, 7}; + cudf::test::fixed_width_column_wrapper struct_rank{7, 8, 1, 3, 2, 4, 6, 5}; + this->run_all_tests(cudf::rank_method::FIRST, asc_top, col_rank, struct_rank); +} + +TYPED_TEST(RankListAndStruct, first_asc_bottom) +{ + // ASCENDING and null_order::AFTER + cudf::test::fixed_width_column_wrapper col_rank{ + 1, 3, 7, 9, 8, 4, 2, 10, 6, 11, 5}; + cudf::test::fixed_width_column_wrapper struct_rank{1, 2, 7, 5, 8, 6, 4, 3}; + this->run_all_tests(cudf::rank_method::FIRST, asc_bottom, col_rank, struct_rank); +} + +TYPED_TEST(RankListAndStruct, first_desc_keep) +{ + // DESCENDING and null_order::BEFORE + cudf::test::fixed_width_column_wrapper col_rank{ + {8, 5, 2, 1, 3, 6, 9, -1, 4, -1, 7}, nulls_at({7, 9})}; + cudf::test::fixed_width_column_wrapper struct_rank{{2, 1, -1, 5, -1, 6, 3, 4}, + nulls_at({2, 4})}; + this->run_all_tests(cudf::rank_method::FIRST, desc_keep, col_rank, struct_rank); +} + +TYPED_TEST(RankListAndStruct, first_desc_top) +{ + // DESCENDING and null_order::AFTER + cudf::test::fixed_width_column_wrapper col_rank{ + 10, 7, 4, 3, 5, 8, 11, 1, 6, 2, 9}; + cudf::test::fixed_width_column_wrapper struct_rank{8, 7, 1, 3, 2, 4, 5, 6}; + this->run_all_tests(cudf::rank_method::FIRST, desc_top, col_rank, struct_rank); +} + +TYPED_TEST(RankListAndStruct, first_desc_bottom) +{ + // DESCENDING and null_order::BEFORE + cudf::test::fixed_width_column_wrapper col_rank{ + 8, 5, 2, 1, 3, 6, 9, 10, 4, 11, 7}; + cudf::test::fixed_width_column_wrapper struct_rank{2, 1, 7, 5, 8, 6, 3, 4}; + this->run_all_tests(cudf::rank_method::FIRST, desc_bottom, col_rank, struct_rank); +} + +TYPED_TEST(RankListAndStruct, dense_asc_keep) +{ + // ASCENDING and null_order::AFTER + cudf::test::fixed_width_column_wrapper col_rank{ + {1, 2, 4, 5, 4, 2, 1, -1, 3, -1, 2}, nulls_at({7, 9})}; + cudf::test::fixed_width_column_wrapper struct_rank{{1, 2, -1, 5, -1, 5, 4, 3}, + nulls_at({2, 4})}; + this->run_all_tests(cudf::rank_method::DENSE, asc_keep, col_rank, struct_rank); +} + +TYPED_TEST(RankListAndStruct, dense_asc_top) +{ + // ASCENDING and null_order::BEFORE + cudf::test::fixed_width_column_wrapper col_rank{2, 3, 5, 6, 5, 3, 2, 1, 4, 1, 3}; + cudf::test::fixed_width_column_wrapper struct_rank{5, 6, 1, 2, 1, 2, 4, 3}; + this->run_all_tests(cudf::rank_method::DENSE, asc_top, col_rank, struct_rank); +} + +TYPED_TEST(RankListAndStruct, dense_asc_bottom) +{ + // ASCENDING and null_order::AFTER + cudf::test::fixed_width_column_wrapper col_rank{1, 2, 4, 5, 4, 2, 1, 6, 3, 6, 2}; + cudf::test::fixed_width_column_wrapper struct_rank{1, 2, 6, 5, 6, 5, 4, 3}; + this->run_all_tests(cudf::rank_method::DENSE, asc_bottom, col_rank, struct_rank); +} + +TYPED_TEST(RankListAndStruct, dense_desc_keep) +{ + // DESCENDING and null_order::BEFORE + cudf::test::fixed_width_column_wrapper col_rank{ + {5, 4, 2, 1, 2, 4, 5, -1, 3, -1, 4}, nulls_at({7, 9})}; + cudf::test::fixed_width_column_wrapper struct_rank{{2, 1, -1, 5, -1, 5, 3, 4}, + nulls_at({2, 4})}; + this->run_all_tests(cudf::rank_method::DENSE, desc_keep, col_rank, struct_rank); +} + +TYPED_TEST(RankListAndStruct, dense_desc_top) +{ + // DESCENDING and null_order::AFTER + cudf::test::fixed_width_column_wrapper col_rank{6, 5, 3, 2, 3, 5, 6, 1, 4, 1, 5}; + cudf::test::fixed_width_column_wrapper struct_rank{6, 5, 1, 2, 1, 2, 3, 4}; + this->run_all_tests(cudf::rank_method::DENSE, desc_top, col_rank, struct_rank); +} + +TYPED_TEST(RankListAndStruct, dense_desc_bottom) +{ + // DESCENDING and null_order::BEFORE + cudf::test::fixed_width_column_wrapper col_rank{5, 4, 2, 1, 2, 4, 5, 6, 3, 6, 4}; + cudf::test::fixed_width_column_wrapper struct_rank{2, 1, 6, 5, 6, 5, 3, 4}; + this->run_all_tests(cudf::rank_method::DENSE, desc_bottom, col_rank, struct_rank); +} + +TYPED_TEST(RankListAndStruct, min_asc_keep) +{ + // ASCENDING and null_order::AFTER + cudf::test::fixed_width_column_wrapper col_rank{ + {1, 3, 7, 9, 7, 3, 1, -1, 6, -1, 3}, nulls_at({7, 9})}; + cudf::test::fixed_width_column_wrapper struct_rank{{1, 2, -1, 5, -1, 5, 4, 3}, + nulls_at({2, 4})}; + this->run_all_tests(cudf::rank_method::MIN, asc_keep, col_rank, struct_rank); +} + +TYPED_TEST(RankListAndStruct, min_asc_top) +{ + // ASCENDING and null_order::BEFORE + cudf::test::fixed_width_column_wrapper col_rank{ + 3, 5, 9, 11, 9, 5, 3, 1, 8, 1, 5}; + cudf::test::fixed_width_column_wrapper struct_rank{7, 8, 1, 3, 1, 3, 6, 5}; + this->run_all_tests(cudf::rank_method::MIN, asc_top, col_rank, struct_rank); +} + +TYPED_TEST(RankListAndStruct, min_asc_bottom) +{ + // ASCENDING and null_order::AFTER + cudf::test::fixed_width_column_wrapper col_rank{ + 1, 3, 7, 9, 7, 3, 1, 10, 6, 10, 3}; + cudf::test::fixed_width_column_wrapper struct_rank{1, 2, 7, 5, 7, 5, 4, 3}; + this->run_all_tests(cudf::rank_method::MIN, asc_bottom, col_rank, struct_rank); +} + +TYPED_TEST(RankListAndStruct, min_desc_keep) +{ + // DESCENDING and null_order::BEFORE + cudf::test::fixed_width_column_wrapper col_rank{ + {8, 5, 2, 1, 2, 5, 8, -1, 4, -1, 5}, nulls_at({7, 9})}; + cudf::test::fixed_width_column_wrapper struct_rank{{2, 1, -1, 5, -1, 5, 3, 4}, + nulls_at({2, 4})}; + this->run_all_tests(cudf::rank_method::MIN, desc_keep, col_rank, struct_rank); +} + +TYPED_TEST(RankListAndStruct, min_desc_top) +{ + // DESCENDING and null_order::AFTER + cudf::test::fixed_width_column_wrapper col_rank{ + 10, 7, 4, 3, 4, 7, 10, 1, 6, 1, 7}; + cudf::test::fixed_width_column_wrapper struct_rank{8, 7, 1, 3, 1, 3, 5, 6}; + this->run_all_tests(cudf::rank_method::MIN, desc_top, col_rank, struct_rank); +} + +TYPED_TEST(RankListAndStruct, min_desc_bottom) +{ + // DESCENDING and null_order::BEFORE + cudf::test::fixed_width_column_wrapper col_rank{ + 8, 5, 2, 1, 2, 5, 8, 10, 4, 10, 5}; + cudf::test::fixed_width_column_wrapper struct_rank{2, 1, 7, 5, 7, 5, 3, 4}; + this->run_all_tests(cudf::rank_method::MIN, desc_bottom, col_rank, struct_rank); +} + +TYPED_TEST(RankListAndStruct, max_asc_keep) +{ + // ASCENDING and null_order::AFTER + cudf::test::fixed_width_column_wrapper col_rank{ + {2, 5, 8, 9, 8, 5, 2, -1, 6, -1, 5}, nulls_at({7, 9})}; + cudf::test::fixed_width_column_wrapper struct_rank{{1, 2, -1, 6, -1, 6, 4, 3}, + nulls_at({2, 4})}; + this->run_all_tests(cudf::rank_method::MAX, asc_keep, col_rank, struct_rank); +} + +TYPED_TEST(RankListAndStruct, max_asc_top) +{ + // ASCENDING and null_order::BEFORE + cudf::test::fixed_width_column_wrapper col_rank{ + 4, 7, 10, 11, 10, 7, 4, 2, 8, 2, 7}; + cudf::test::fixed_width_column_wrapper struct_rank{7, 8, 2, 4, 2, 4, 6, 5}; + this->run_all_tests(cudf::rank_method::MAX, asc_top, col_rank, struct_rank); +} + +TYPED_TEST(RankListAndStruct, max_asc_bottom) +{ + // ASCENDING and null_order::AFTER + cudf::test::fixed_width_column_wrapper col_rank{ + 2, 5, 8, 9, 8, 5, 2, 11, 6, 11, 5}; + cudf::test::fixed_width_column_wrapper struct_rank{1, 2, 8, 6, 8, 6, 4, 3}; + this->run_all_tests(cudf::rank_method::MAX, asc_bottom, col_rank, struct_rank); +} + +TYPED_TEST(RankListAndStruct, max_desc_keep) +{ + // DESCENDING and null_order::BEFORE + cudf::test::fixed_width_column_wrapper col_rank{ + {9, 7, 3, 1, 3, 7, 9, -1, 4, -1, 7}, nulls_at({7, 9})}; + cudf::test::fixed_width_column_wrapper struct_rank{{2, 1, -1, 6, -1, 6, 3, 4}, + nulls_at({2, 4})}; + this->run_all_tests(cudf::rank_method::MAX, desc_keep, col_rank, struct_rank); +} + +TYPED_TEST(RankListAndStruct, max_desc_top) +{ + // DESCENDING and null_order::AFTER + cudf::test::fixed_width_column_wrapper col_rank{ + 11, 9, 5, 3, 5, 9, 11, 2, 6, 2, 9}; + cudf::test::fixed_width_column_wrapper struct_rank{8, 7, 2, 4, 2, 4, 5, 6}; + this->run_all_tests(cudf::rank_method::MAX, desc_top, col_rank, struct_rank); +} + +TYPED_TEST(RankListAndStruct, max_desc_bottom) +{ + // DESCENDING and null_order::BEFORE + cudf::test::fixed_width_column_wrapper col_rank{ + 9, 7, 3, 1, 3, 7, 9, 11, 4, 11, 7}; + cudf::test::fixed_width_column_wrapper struct_rank{2, 1, 8, 6, 8, 6, 3, 4}; + this->run_all_tests(cudf::rank_method::MAX, desc_bottom, col_rank, struct_rank); +} + +TYPED_TEST(RankListAndStruct, average_asc_keep) +{ + // ASCENDING and null_order::AFTER + cudf::test::fixed_width_column_wrapper col_rank{ + {1.5, 4.0, 7.5, 9.0, 7.5, 4.0, 1.5, -1.0, 6.0, -1.0, 4.0}, nulls_at({7, 9})}; + cudf::test::fixed_width_column_wrapper struct_rank{ + {1.0, 2.0, -1.0, 5.5, -1.0, 5.5, 4.0, 3.0}, nulls_at({2, 4})}; + this->run_all_tests(cudf::rank_method::AVERAGE, asc_keep, col_rank, struct_rank); +} + +TYPED_TEST(RankListAndStruct, average_asc_top) +{ + // ASCENDING and null_order::BEFORE + cudf::test::fixed_width_column_wrapper col_rank{ + 3.5, 6.0, 9.5, 11.0, 9.5, 6.0, 3.5, 1.5, 8.0, 1.5, 6.0}; + cudf::test::fixed_width_column_wrapper struct_rank{ + 7.0, 8.0, 1.5, 3.5, 1.5, 3.5, 6.0, 5.0}; + this->run_all_tests(cudf::rank_method::AVERAGE, asc_top, col_rank, struct_rank); +} + +TYPED_TEST(RankListAndStruct, average_asc_bottom) +{ + // ASCENDING and null_order::AFTER + cudf::test::fixed_width_column_wrapper col_rank{ + 1.5, 4.0, 7.5, 9.0, 7.5, 4.0, 1.5, 10.5, 6.0, 10.5, 4.0}; + cudf::test::fixed_width_column_wrapper struct_rank{ + 1.0, 2.0, 7.5, 5.5, 7.5, 5.5, 4.0, 3.0}; + this->run_all_tests(cudf::rank_method::AVERAGE, asc_bottom, col_rank, struct_rank); +} + +TYPED_TEST(RankListAndStruct, average_desc_keep) +{ + // DESCENDING and null_order::BEFORE + cudf::test::fixed_width_column_wrapper col_rank{ + {8.5, 6.0, 2.5, 1.0, 2.5, 6.0, 8.5, -1.0, 4.0, -1.0, 6.0}, nulls_at({7, 9})}; + cudf::test::fixed_width_column_wrapper struct_rank{ + {2.0, 1.0, -1.0, 5.5, -1.0, 5.5, 3.0, 4.0}, nulls_at({2, 4})}; + this->run_all_tests(cudf::rank_method::AVERAGE, desc_keep, col_rank, struct_rank); +} + +TYPED_TEST(RankListAndStruct, average_desc_top) +{ + // DESCENDING and null_order::AFTER + cudf::test::fixed_width_column_wrapper col_rank{ + 10.5, 8.0, 4.5, 3.0, 4.5, 8.0, 10.5, 1.5, 6.0, 1.5, 8.0}; + cudf::test::fixed_width_column_wrapper struct_rank{ + 8.0, 7.0, 1.5, 3.5, 1.5, 3.5, 5.0, 6.0}; + this->run_all_tests(cudf::rank_method::AVERAGE, desc_top, col_rank, struct_rank); +} + +TYPED_TEST(RankListAndStruct, average_desc_bottom) +{ + // DESCENDING and null_order::BEFORE + cudf::test::fixed_width_column_wrapper col_rank{ + 8.5, 6.0, 2.5, 1.0, 2.5, 6.0, 8.5, 10.5, 4.0, 10.5, 6.0}; + cudf::test::fixed_width_column_wrapper struct_rank{ + 2.0, 1.0, 7.5, 5.5, 7.5, 5.5, 3.0, 4.0}; + this->run_all_tests(cudf::rank_method::AVERAGE, desc_bottom, col_rank, struct_rank); +} + +TYPED_TEST(RankListAndStruct, dense_asc_keep_pct) +{ + // ASCENDING and null_order::AFTER + cudf::test::fixed_width_column_wrapper col_rank{{1.0 / 5.0, + 2.0 / 5.0, + 4.0 / 5.0, + 1.0, + 4.0 / 5.0, + 2.0 / 5.0, + 1.0 / 5.0, + -1.0, + 3.0 / 5.0, + -1.0, + 2.0 / 5.0}, + nulls_at({7, 9})}; + + cudf::test::fixed_width_column_wrapper struct_rank{ + {1.0 / 5.0, 2.0 / 5.0, -1.0, 1.0, -1.0, 1.0, 4.0 / 5.0, 3.0 / 5.0}, nulls_at({2, 4})}; + + this->run_all_tests(cudf::rank_method::DENSE, asc_keep, col_rank, struct_rank, true); +} + +TYPED_TEST(RankListAndStruct, dense_asc_top_pct) +{ + // ASCENDING and null_order::BEFORE + cudf::test::fixed_width_column_wrapper col_rank{1.0 / 3.0, + 1.0 / 2.0, + 5.0 / 6.0, + 1.0, + 5.0 / 6.0, + 1.0 / 2.0, + 1.0 / 3.0, + 1.0 / 6.0, + 2.0 / 3.0, + 1.0 / 6.0, + 1.0 / 2.0}; + cudf::test::fixed_width_column_wrapper struct_rank{ + 5.0 / 6.0, 1.0, 1.0 / 6.0, 2.0 / 6.0, 1.0 / 6.0, 2.0 / 6.0, 4.0 / 6.0, 3.0 / 6.0}; + this->run_all_tests(cudf::rank_method::DENSE, asc_top, col_rank, struct_rank, true); +} + +TYPED_TEST(RankListAndStruct, dense_asc_bottom_pct) +{ + // ASCENDING and null_order::AFTER + cudf::test::fixed_width_column_wrapper col_rank{1.0 / 6.0, + 1.0 / 3.0, + 2.0 / 3.0, + 5.0 / 6.0, + 2.0 / 3.0, + 1.0 / 3.0, + 1.0 / 6.0, + 1.0, + 1.0 / 2.0, + 1.0, + 1.0 / 3.0}; + cudf::test::fixed_width_column_wrapper struct_rank{ + 1.0 / 6.0, 2.0 / 6.0, 1.0, 5.0 / 6.0, 1.0, 5.0 / 6.0, 4.0 / 6.0, 3.0 / 6.0}; + this->run_all_tests(cudf::rank_method::DENSE, asc_bottom, col_rank, struct_rank, true); +} + +TYPED_TEST(RankListAndStruct, min_desc_keep_pct) +{ + // DESCENDING and null_order::BEFORE + cudf::test::fixed_width_column_wrapper col_rank{{8.0 / 9.0, + 5.0 / 9.0, + 2.0 / 9.0, + 1.0 / 9.0, + 2.0 / 9.0, + 5.0 / 9.0, + 8.0 / 9.0, + -1.0, + 4.0 / 9.0, + -1.0, + 5.0 / 9.0}, + nulls_at({7, 9})}; + cudf::test::fixed_width_column_wrapper struct_rank{ + {2.0 / 6.0, 1.0 / 6.0, -1.0, 5.0 / 6.0, -1.0, 5.0 / 6.0, 3.0 / 6.0, 4.0 / 6.0}, + nulls_at({2, 4})}; + this->run_all_tests(cudf::rank_method::MIN, desc_keep, col_rank, struct_rank, true); +} + +TYPED_TEST(RankListAndStruct, min_desc_top_pct) +{ + // DESCENDING and null_order::AFTER + cudf::test::fixed_width_column_wrapper col_rank{10.0 / 11.0, + 7.0 / 11.0, + 4.0 / 11.0, + 3.0 / 11.0, + 4.0 / 11.0, + 7.0 / 11.0, + 10.0 / 11.0, + 1.0 / 11.0, + 6.0 / 11.0, + 1.0 / 11.0, + 7.0 / 11.0}; + cudf::test::fixed_width_column_wrapper struct_rank{ + 1.0, 7.0 / 8.0, 1.0 / 8.0, 3.0 / 8.0, 1.0 / 8.0, 3.0 / 8.0, 5.0 / 8.0, 6.0 / 8.0}; + this->run_all_tests(cudf::rank_method::MIN, desc_top, col_rank, struct_rank, true); +} + +TYPED_TEST(RankListAndStruct, min_desc_bottom_pct) +{ + // DESCENDING and null_order::BEFORE + cudf::test::fixed_width_column_wrapper col_rank{8.0 / 11.0, + 5.0 / 11.0, + 2.0 / 11.0, + 1.0 / 11.0, + 2.0 / 11.0, + 5.0 / 11.0, + 8.0 / 11.0, + 10.0 / 11.0, + 4.0 / 11.0, + 10.0 / 11.0, + 5.0 / 11.0}; + cudf::test::fixed_width_column_wrapper struct_rank{ + 2.0 / 8.0, 1.0 / 8.0, 7.0 / 8.0, 5.0 / 8.0, 7.0 / 8.0, 5.0 / 8.0, 3.0 / 8.0, 4.0 / 8.0}; + this->run_all_tests(cudf::rank_method::MIN, desc_bottom, col_rank, struct_rank, true); +} diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index 4daa3c17cfc..2d0bf28225f 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -2531,12 +2531,34 @@ public final ColumnVector stringLocate(Scalar substring, int start, int end) { * regular expression pattern or just by a string literal delimiter. * @return list of strings columns as a table. */ + @Deprecated public final Table stringSplit(String pattern, int limit, boolean splitByRegex) { + if (splitByRegex) { + return stringSplit(new RegexProgram(pattern, CaptureGroups.NON_CAPTURE), limit); + } else { + return stringSplit(pattern, limit); + } + } + + /** + * Returns a list of columns by splitting each string using the specified regex program pattern. + * The number of rows in the output columns will be the same as the input column. Null entries + * are added for the rows where split results have been exhausted. Null input entries result in + * all nulls in the corresponding rows of the output columns. + * + * @param regexProg the regex program with UTF-8 encoded string identifying the split pattern + * for each input string. + * @param limit the maximum size of the list resulting from splitting each input string, + * or -1 for all possible splits. Note that limit = 0 (all possible splits without + * trailing empty strings) and limit = 1 (no split at all) are not supported. + * @return list of strings columns as a table. + */ + public final Table stringSplit(RegexProgram regexProg, int limit) { assert type.equals(DType.STRING) : "column type must be a String"; - assert pattern != null : "pattern is null"; - assert pattern.length() > 0 : "empty pattern is not supported"; + assert regexProg != null : "regex program is null"; assert limit != 0 && limit != 1 : "split limit == 0 and limit == 1 are not supported"; - return new Table(stringSplit(this.getNativeView(), pattern, limit, splitByRegex)); + return new Table(stringSplitRe(this.getNativeView(), regexProg.pattern(), regexProg.combinedFlags(), + regexProg.capture().nativeId, limit)); } /** @@ -2550,6 +2572,7 @@ public final Table stringSplit(String pattern, int limit, boolean splitByRegex) * regular expression pattern or just by a string literal delimiter. * @return list of strings columns as a table. */ + @Deprecated public final Table stringSplit(String pattern, boolean splitByRegex) { return stringSplit(pattern, -1, splitByRegex); } @@ -2567,7 +2590,10 @@ public final Table stringSplit(String pattern, boolean splitByRegex) { * @return list of strings columns as a table. */ public final Table stringSplit(String delimiter, int limit) { - return stringSplit(delimiter, limit, false); + assert type.equals(DType.STRING) : "column type must be a String"; + assert delimiter != null : "delimiter is null"; + assert limit != 0 && limit != 1 : "split limit == 0 and limit == 1 are not supported"; + return new Table(stringSplit(this.getNativeView(), delimiter, limit)); } /** @@ -2580,7 +2606,21 @@ public final Table stringSplit(String delimiter, int limit) { * @return list of strings columns as a table. */ public final Table stringSplit(String delimiter) { - return stringSplit(delimiter, -1, false); + return stringSplit(delimiter, -1); + } + + /** + * Returns a list of columns by splitting each string using the specified regex program pattern. + * The number of rows in the output columns will be the same as the input column. Null entries + * are added for the rows where split results have been exhausted. Null input entries result in + * all nulls in the corresponding rows of the output columns. + * + * @param regexProg the regex program with UTF-8 encoded string identifying the split pattern + * for each input string. + * @return list of strings columns as a table. + */ + public final Table stringSplit(RegexProgram regexProg) { + return stringSplit(regexProg, -1); } /** @@ -2595,13 +2635,33 @@ public final Table stringSplit(String delimiter) { * regular expression pattern or just by a string literal delimiter. * @return a LIST column of string elements. */ + @Deprecated public final ColumnVector stringSplitRecord(String pattern, int limit, boolean splitByRegex) { + if (splitByRegex) { + return stringSplitRecord(new RegexProgram(pattern, CaptureGroups.NON_CAPTURE), limit); + } else { + return stringSplitRecord(pattern, limit); + } + } + + /** + * Returns a column that are lists of strings in which each list is made by splitting the + * corresponding input string using the specified regex program pattern. + * + * @param regexProg the regex program with UTF-8 encoded string identifying the split pattern + * for each input string. + * @param limit the maximum size of the list resulting from splitting each input string, + * or -1 for all possible splits. Note that limit = 0 (all possible splits without + * trailing empty strings) and limit = 1 (no split at all) are not supported. + * @return a LIST column of string elements. + */ + public final ColumnVector stringSplitRecord(RegexProgram regexProg, int limit) { assert type.equals(DType.STRING) : "column type must be String"; - assert pattern != null : "pattern is null"; - assert pattern.length() > 0 : "empty pattern is not supported"; + assert regexProg != null : "regex program is null"; assert limit != 0 && limit != 1 : "split limit == 0 and limit == 1 are not supported"; return new ColumnVector( - stringSplitRecord(this.getNativeView(), pattern, limit, splitByRegex)); + stringSplitRecordRe(this.getNativeView(), regexProg.pattern(), regexProg.combinedFlags(), + regexProg.capture().nativeId, limit)); } /** @@ -2613,6 +2673,7 @@ public final ColumnVector stringSplitRecord(String pattern, int limit, boolean s * regular expression pattern or just by a string literal delimiter. * @return a LIST column of string elements. */ + @Deprecated public final ColumnVector stringSplitRecord(String pattern, boolean splitByRegex) { return stringSplitRecord(pattern, -1, splitByRegex); } @@ -2628,7 +2689,10 @@ public final ColumnVector stringSplitRecord(String pattern, boolean splitByRegex * @return a LIST column of string elements. */ public final ColumnVector stringSplitRecord(String delimiter, int limit) { - return stringSplitRecord(delimiter, limit, false); + assert type.equals(DType.STRING) : "column type must be String"; + assert delimiter != null : "delimiter is null"; + assert limit != 0 && limit != 1 : "split limit == 0 and limit == 1 are not supported"; + return new ColumnVector(stringSplitRecord(this.getNativeView(), delimiter, limit)); } /** @@ -2639,7 +2703,19 @@ public final ColumnVector stringSplitRecord(String delimiter, int limit) { * @return a LIST column of string elements. */ public final ColumnVector stringSplitRecord(String delimiter) { - return stringSplitRecord(delimiter, -1, false); + return stringSplitRecord(delimiter, -1); + } + + /** + * Returns a column that are lists of strings in which each list is made by splitting the + * corresponding input string using the specified regex program pattern. + * + * @param regexProg the regex program with UTF-8 encoded string identifying the split pattern + * for each input string. + * @return a LIST column of string elements. + */ + public final ColumnVector stringSplitRecord(RegexProgram regexProg) { + return stringSplitRecord(regexProg, -1); } /** @@ -3958,36 +4034,64 @@ private static native long repeatStringsWithColumnRepeatTimes(long stringsHandle private static native long substringLocate(long columnView, long substringScalar, int start, int end); /** - * Returns a list of columns by splitting each string using the specified pattern. The number of - * rows in the output columns will be the same as the input column. Null entries are added for a - * row where split results have been exhausted. Null input entries result in all nulls in the - * corresponding rows of the output columns. + * Returns a list of columns by splitting each string using the specified string literal + * delimiter. The number of rows in the output columns will be the same as the input column. + * Null entries are added for the rows where split results have been exhausted. Null input entries + * result in all nulls in the corresponding rows of the output columns. * * @param nativeHandle native handle of the input strings column that being operated on. - * @param pattern UTF-8 encoded string identifying the split pattern for each input string. + * @param delimiter UTF-8 encoded string identifying the split delimiter for each input string. + * @param limit the maximum size of the list resulting from splitting each input string, + * or -1 for all possible splits. Note that limit = 0 (all possible splits without + * trailing empty strings) and limit = 1 (no split at all) are not supported. + */ + private static native long[] stringSplit(long nativeHandle, String delimiter, int limit); + + /** + * Returns a list of columns by splitting each string using the specified regular expression + * pattern. The number of rows in the output columns will be the same as the input column. + * Null entries are added for the rows where split results have been exhausted. Null input entries + * result in all nulls in the corresponding rows of the output columns. + * + * @param nativeHandle native handle of the input strings column that being operated on. + * @param pattern UTF-8 encoded string identifying the split regular expression pattern for + * each input string. + * @param flags regex flags setting. + * @param capture capture groups setting. * @param limit the maximum size of the list resulting from splitting each input string, * or -1 for all possible splits. Note that limit = 0 (all possible splits without * trailing empty strings) and limit = 1 (no split at all) are not supported. - * @param splitByRegex a boolean flag indicating whether the input strings will be split by a - * regular expression pattern or just by a string literal delimiter. */ - private static native long[] stringSplit(long nativeHandle, String pattern, int limit, - boolean splitByRegex); + private static native long[] stringSplitRe(long nativeHandle, String pattern, int flags, + int capture, int limit); /** * Returns a column that are lists of strings in which each list is made by splitting the * corresponding input string using the specified string literal delimiter. * * @param nativeHandle native handle of the input strings column that being operated on. - * @param pattern UTF-8 encoded string identifying the split pattern for each input string. + * @param delimiter UTF-8 encoded string identifying the split delimiter for each input string. + * @param limit the maximum size of the list resulting from splitting each input string, + * or -1 for all possible splits. Note that limit = 0 (all possible splits without + * trailing empty strings) and limit = 1 (no split at all) are not supported. + */ + private static native long stringSplitRecord(long nativeHandle, String delimiter, int limit); + + /** + * Returns a column that are lists of strings in which each list is made by splitting the + * corresponding input string using the specified regular expression pattern. + * + * @param nativeHandle native handle of the input strings column that being operated on. + * @param pattern UTF-8 encoded string identifying the split regular expression pattern for + * each input string. + * @param flags regex flags setting. + * @param capture capture groups setting. * @param limit the maximum size of the list resulting from splitting each input string, * or -1 for all possible splits. Note that limit = 0 (all possible splits without * trailing empty strings) and limit = 1 (no split at all) are not supported. - * @param splitByRegex a boolean flag indicating whether the input strings will be split by a - * regular expression pattern or just by a string literal delimiter. */ - private static native long stringSplitRecord(long nativeHandle, String pattern, int limit, - boolean splitByRegex); + private static native long stringSplitRecordRe(long nativeHandle, String pattern, int flags, + int capture, int limit); /** * Native method to calculate substring from a given string column. 0 indexing. diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp index bfa3fa0a522..958efd364ed 100644 --- a/java/src/main/native/src/ColumnViewJni.cpp +++ b/java/src/main/native/src/ColumnViewJni.cpp @@ -681,9 +681,8 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_reverseStringsOrLists(JNI JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_ColumnView_stringSplit(JNIEnv *env, jclass, jlong input_handle, - jstring pattern_obj, - jint limit, - jboolean split_by_regex) { + jstring delimiter_obj, + jint limit) { JNI_NULL_CHECK(env, input_handle, "input_handle is null", 0); if (limit == 0 || limit == 1) { @@ -697,21 +696,42 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_ColumnView_stringSplit(JNIEnv * try { cudf::jni::auto_set_device(env); - auto const input = reinterpret_cast(input_handle); - auto const strs_input = cudf::strings_column_view{*input}; + auto const input = reinterpret_cast(input_handle); + auto const strings_column = cudf::strings_column_view{*input}; + auto const delimiter_jstr = cudf::jni::native_jstring(env, delimiter_obj); + auto const delimiter = std::string(delimiter_jstr.get(), delimiter_jstr.size_bytes()); + auto const max_split = limit > 1 ? limit - 1 : limit; + auto result = cudf::strings::split(strings_column, cudf::string_scalar{delimiter}, max_split); + return cudf::jni::convert_table_for_return(env, std::move(result)); + } + CATCH_STD(env, 0); +} - auto const pattern_jstr = cudf::jni::native_jstring(env, pattern_obj); - if (pattern_jstr.is_empty()) { - // Java's split API produces different behaviors than cudf when splitting with empty - // pattern. - JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "Empty pattern is not supported", 0); - } +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_ColumnView_stringSplitRe( + JNIEnv *env, jclass, jlong input_handle, jstring pattern_obj, jint regex_flags, + jint capture_groups, jint limit) { + JNI_NULL_CHECK(env, input_handle, "input_handle is null", 0); + if (limit == 0 || limit == 1) { + // Cannot achieve the results of splitting with limit == 0 or limit == 1. + // This is because cudf operates on a different parameter (`max_split`) which is converted from + // limit. When limit == 0 or limit == 1, max_split will be non-positive and will result in an + // unlimited split. + JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", + "limit == 0 and limit == 1 are not supported", 0); + } + + try { + cudf::jni::auto_set_device(env); + auto const input = reinterpret_cast(input_handle); + auto const strings_column = cudf::strings_column_view{*input}; + auto const pattern_jstr = cudf::jni::native_jstring(env, pattern_obj); auto const pattern = std::string(pattern_jstr.get(), pattern_jstr.size_bytes()); auto const max_split = limit > 1 ? limit - 1 : limit; - auto result = split_by_regex ? - cudf::strings::split_re(strs_input, pattern, max_split) : - cudf::strings::split(strs_input, cudf::string_scalar{pattern}, max_split); + auto const flags = static_cast(regex_flags); + auto const groups = static_cast(capture_groups); + auto const regex_prog = cudf::strings::regex_program::create(pattern, flags, groups); + auto result = cudf::strings::split_re(strings_column, *regex_prog, max_split); return cudf::jni::convert_table_for_return(env, std::move(result)); } CATCH_STD(env, 0); @@ -719,9 +739,8 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_ColumnView_stringSplit(JNIEnv * JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringSplitRecord(JNIEnv *env, jclass, jlong input_handle, - jstring pattern_obj, - jint limit, - jboolean split_by_regex) { + jstring delimiter_obj, + jint limit) { JNI_NULL_CHECK(env, input_handle, "input_handle is null", 0); if (limit == 0 || limit == 1) { @@ -735,22 +754,43 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringSplitRecord(JNIEnv try { cudf::jni::auto_set_device(env); - auto const input = reinterpret_cast(input_handle); - auto const strs_input = cudf::strings_column_view{*input}; + auto const input = reinterpret_cast(input_handle); + auto const strings_column = cudf::strings_column_view{*input}; + auto const delimiter_jstr = cudf::jni::native_jstring(env, delimiter_obj); + auto const delimiter = std::string(delimiter_jstr.get(), delimiter_jstr.size_bytes()); + auto const max_split = limit > 1 ? limit - 1 : limit; + auto result = + cudf::strings::split_record(strings_column, cudf::string_scalar{delimiter}, max_split); + return release_as_jlong(result); + } + CATCH_STD(env, 0); +} - auto const pattern_jstr = cudf::jni::native_jstring(env, pattern_obj); - if (pattern_jstr.is_empty()) { - // Java's split API produces different behaviors than cudf when splitting with empty - // pattern. - JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "Empty pattern is not supported", 0); - } +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringSplitRecordRe( + JNIEnv *env, jclass, jlong input_handle, jstring pattern_obj, jint regex_flags, + jint capture_groups, jint limit) { + JNI_NULL_CHECK(env, input_handle, "input_handle is null", 0); + if (limit == 0 || limit == 1) { + // Cannot achieve the results of splitting with limit == 0 or limit == 1. + // This is because cudf operates on a different parameter (`max_split`) which is converted from + // limit. When limit == 0 or limit == 1, max_split will be non-positive and will result in an + // unlimited split. + JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", + "limit == 0 and limit == 1 are not supported", 0); + } + + try { + cudf::jni::auto_set_device(env); + auto const input = reinterpret_cast(input_handle); + auto const strings_column = cudf::strings_column_view{*input}; + auto const pattern_jstr = cudf::jni::native_jstring(env, pattern_obj); auto const pattern = std::string(pattern_jstr.get(), pattern_jstr.size_bytes()); auto const max_split = limit > 1 ? limit - 1 : limit; - auto result = - split_by_regex ? - cudf::strings::split_record_re(strs_input, pattern, max_split) : - cudf::strings::split_record(strs_input, cudf::string_scalar{pattern}, max_split); + auto const flags = static_cast(regex_flags); + auto const groups = static_cast(capture_groups); + auto const regex_prog = cudf::strings::regex_program::create(pattern, flags, groups); + auto result = cudf::strings::split_record_re(strings_column, *regex_prog, max_split); return release_as_jlong(result); } CATCH_STD(env, 0); diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index 46264b7d668..ab4baf74277 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -4990,28 +4990,29 @@ void testReverseList() { void testStringSplit() { String pattern = " "; try (ColumnVector v = ColumnVector.fromStrings("Héllo there all", "thésé", null, "", - "ARé some things", "test strings here"); + "ARé some things", "test strings here"); Table expectedSplitLimit2 = new Table.TestBuilder() - .column("Héllo", "thésé", null, "", "ARé", "test") - .column("there all", null, null, null, "some things", "strings here") - .build(); + .column("Héllo", "thésé", null, "", "ARé", "test") + .column("there all", null, null, null, "some things", "strings here") + .build(); Table expectedSplitAll = new Table.TestBuilder() - .column("Héllo", "thésé", null, "", "ARé", "test") - .column("there", null, null, null, "some", "strings") - .column("all", null, null, null, "things", "here") - .build(); + .column("Héllo", "thésé", null, "", "ARé", "test") + .column("there", null, null, null, "some", "strings") + .column("all", null, null, null, "things", "here") + .build(); Table resultSplitLimit2 = v.stringSplit(pattern, 2); Table resultSplitAll = v.stringSplit(pattern)) { - assertTablesAreEqual(expectedSplitLimit2, resultSplitLimit2); - assertTablesAreEqual(expectedSplitAll, resultSplitAll); + assertTablesAreEqual(expectedSplitLimit2, resultSplitLimit2); + assertTablesAreEqual(expectedSplitAll, resultSplitAll); } } @Test void testStringSplitByRegularExpression() { String pattern = "[_ ]"; + RegexProgram regexProg = new RegexProgram(pattern, CaptureGroups.NON_CAPTURE); try (ColumnVector v = ColumnVector.fromStrings("Héllo_there all", "thésé", null, "", - "ARé some_things", "test_strings_here"); + "ARé some_things", "test_strings_here"); Table expectedSplitLimit2 = new Table.TestBuilder() .column("Héllo", "thésé", null, "", "ARé", "test") .column("there all", null, null, null, "some_things", "strings_here") @@ -5020,11 +5021,17 @@ void testStringSplitByRegularExpression() { .column("Héllo", "thésé", null, "", "ARé", "test") .column("there", null, null, null, "some", "strings") .column("all", null, null, null, "things", "here") - .build(); - Table resultSplitLimit2 = v.stringSplit(pattern, 2, true); - Table resultSplitAll = v.stringSplit(pattern, true)) { - assertTablesAreEqual(expectedSplitLimit2, resultSplitLimit2); - assertTablesAreEqual(expectedSplitAll, resultSplitAll); + .build()) { + try (Table resultSplitLimit2 = v.stringSplit(pattern, 2, true); + Table resultSplitAll = v.stringSplit(pattern, true)) { + assertTablesAreEqual(expectedSplitLimit2, resultSplitLimit2); + assertTablesAreEqual(expectedSplitAll, resultSplitAll); + } + try (Table resultSplitLimit2 = v.stringSplit(regexProg, 2); + Table resultSplitAll = v.stringSplit(regexProg)) { + assertTablesAreEqual(expectedSplitLimit2, resultSplitLimit2); + assertTablesAreEqual(expectedSplitAll, resultSplitAll); + } } } @@ -5032,7 +5039,7 @@ void testStringSplitByRegularExpression() { void testStringSplitRecord() { String pattern = " "; try (ColumnVector v = ColumnVector.fromStrings("Héllo there all", "thésé", null, "", - "ARé some things", "test strings here"); + "ARé some things", "test strings here"); ColumnVector expectedSplitLimit2 = ColumnVector.fromLists( new HostColumnVector.ListType(true, new HostColumnVector.BasicType(true, DType.STRING)), @@ -5061,8 +5068,9 @@ void testStringSplitRecord() { @Test void testStringSplitRecordByRegularExpression() { String pattern = "[_ ]"; + RegexProgram regexProg = new RegexProgram(pattern, CaptureGroups.NON_CAPTURE); try (ColumnVector v = ColumnVector.fromStrings("Héllo_there all", "thésé", null, "", - "ARé some_things", "test_strings_here"); + "ARé some_things", "test_strings_here"); ColumnVector expectedSplitLimit2 = ColumnVector.fromLists( new HostColumnVector.ListType(true, new HostColumnVector.BasicType(true, DType.STRING)), @@ -5080,11 +5088,17 @@ void testStringSplitRecordByRegularExpression() { null, Arrays.asList(""), Arrays.asList("ARé", "some", "things"), - Arrays.asList("test", "strings", "here")); - ColumnVector resultSplitLimit2 = v.stringSplitRecord(pattern, 2, true); - ColumnVector resultSplitAll = v.stringSplitRecord(pattern, true)) { - assertColumnsAreEqual(expectedSplitLimit2, resultSplitLimit2); - assertColumnsAreEqual(expectedSplitAll, resultSplitAll); + Arrays.asList("test", "strings", "here"))) { + try (ColumnVector resultSplitLimit2 = v.stringSplitRecord(pattern, 2, true); + ColumnVector resultSplitAll = v.stringSplitRecord(pattern, true)) { + assertColumnsAreEqual(expectedSplitLimit2, resultSplitLimit2); + assertColumnsAreEqual(expectedSplitAll, resultSplitAll); + } + try (ColumnVector resultSplitLimit2 = v.stringSplitRecord(regexProg, 2); + ColumnVector resultSplitAll = v.stringSplitRecord(regexProg)) { + assertColumnsAreEqual(expectedSplitLimit2, resultSplitLimit2); + assertColumnsAreEqual(expectedSplitAll, resultSplitAll); + } } }