diff --git a/cpp/src/strings/findall.cu b/cpp/src/strings/findall.cu index b9f2f7046a3..2c26875b5d6 100644 --- a/cpp/src/strings/findall.cu +++ b/cpp/src/strings/findall.cu @@ -16,9 +16,9 @@ #include #include -#include #include #include +#include #include #include #include @@ -118,42 +118,43 @@ std::unique_ptr findall_re( rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource(), rmm::cuda_stream_view stream = rmm::cuda_stream_default) { - auto strings_count = strings.size(); - auto strings_column = column_device_view::create(strings.parent(), stream); - auto d_strings = *strings_column; + auto const strings_count = strings.size(); + auto const d_strings = column_device_view::create(strings.parent(), stream); - auto d_flags = detail::get_character_flags_table(); + auto const d_flags = detail::get_character_flags_table(); // compile regex into device object - auto prog = reprog_device::create(pattern, d_flags, strings_count, stream); - auto d_prog = *prog; - int regex_insts = prog->insts_counts(); + auto const d_prog = reprog_device::create(pattern, d_flags, strings_count, stream); + auto const regex_insts = d_prog->insts_counts(); - rmm::device_vector find_counts(strings_count); - auto d_find_counts = find_counts.data().get(); + rmm::device_uvector find_counts(strings_count, stream); + auto d_find_counts = find_counts.data(); if ((regex_insts > MAX_STACK_INSTS) || (regex_insts <= RX_SMALL_INSTS)) thrust::transform(rmm::exec_policy(stream), thrust::make_counting_iterator(0), thrust::make_counting_iterator(strings_count), d_find_counts, - findall_count_fn{d_strings, d_prog}); + findall_count_fn{*d_strings, *d_prog}); else if (regex_insts <= RX_MEDIUM_INSTS) thrust::transform(rmm::exec_policy(stream), thrust::make_counting_iterator(0), thrust::make_counting_iterator(strings_count), d_find_counts, - findall_count_fn{d_strings, d_prog}); + findall_count_fn{*d_strings, *d_prog}); else thrust::transform(rmm::exec_policy(stream), thrust::make_counting_iterator(0), thrust::make_counting_iterator(strings_count), d_find_counts, - findall_count_fn{d_strings, d_prog}); + findall_count_fn{*d_strings, *d_prog}); std::vector> results; - size_type columns = - *thrust::max_element(rmm::exec_policy(stream), find_counts.begin(), find_counts.end()); + size_type const columns = thrust::reduce(rmm::exec_policy(stream), + find_counts.begin(), + find_counts.end(), + 0, + thrust::maximum{}); // boundary case: if no columns, return all nulls column (issue #119) if (columns == 0) results.emplace_back(std::make_unique( @@ -164,30 +165,32 @@ std::unique_ptr
findall_re( strings_count)); for (int32_t column_index = 0; column_index < columns; ++column_index) { - rmm::device_vector indices(strings_count); - string_index_pair* d_indices = indices.data().get(); + rmm::device_uvector indices(strings_count, stream); + auto d_indices = indices.data(); if ((regex_insts > MAX_STACK_INSTS) || (regex_insts <= RX_SMALL_INSTS)) - thrust::transform(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(strings_count), - d_indices, - findall_fn{d_strings, d_prog, column_index, d_find_counts}); + thrust::transform( + rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(strings_count), + d_indices, + findall_fn{*d_strings, *d_prog, column_index, d_find_counts}); else if (regex_insts <= RX_MEDIUM_INSTS) thrust::transform( rmm::exec_policy(stream), thrust::make_counting_iterator(0), thrust::make_counting_iterator(strings_count), d_indices, - findall_fn{d_strings, d_prog, column_index, d_find_counts}); + findall_fn{*d_strings, *d_prog, column_index, d_find_counts}); else - thrust::transform(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(strings_count), - d_indices, - findall_fn{d_strings, d_prog, column_index, d_find_counts}); + thrust::transform( + rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(strings_count), + d_indices, + findall_fn{*d_strings, *d_prog, column_index, d_find_counts}); // - results.emplace_back(make_strings_column(indices, stream, mr)); + results.emplace_back(make_strings_column(indices.begin(), indices.end(), stream, mr)); } return std::make_unique
(std::move(results)); }