Skip to content

Commit

Permalink
Replace thrust::max_element with thrust::reduce in strings findall_re (#7428)
Browse files Browse the repository at this point in the history

This is a cleanup of `findall.cu` that replaces `thrust::max_element` with the more efficient `thrust::reduce` using a `thrust::maximum` operator. It also changes `rmm::device_vector` usage to `rmm::device_uvector` and adds `const` to more variable declarations.

Authors:
  - David (@davidwendt)

Approvers:
  - Karthikeyan (@karthikeyann)
  - Conor Hoekstra (@codereport)

URL: #7428
  • Loading branch information
davidwendt authored Feb 25, 2021
1 parent d4583ec commit 49f6857
Showing 1 changed file with 32 additions and 29 deletions.
61 changes: 32 additions & 29 deletions cpp/src/strings/findall.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@

#include <cudf/column/column.hpp>
#include <cudf/column/column_device_view.cuh>
#include <cudf/column/column_factories.hpp>
#include <cudf/detail/null_mask.hpp>
#include <cudf/detail/nvtx/ranges.hpp>
#include <cudf/strings/detail/strings_column_factories.cuh>
#include <cudf/strings/detail/utilities.hpp>
#include <cudf/strings/findall.hpp>
#include <cudf/strings/string_view.cuh>
Expand Down Expand Up @@ -118,42 +118,43 @@ std::unique_ptr<table> findall_re(
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource(),
rmm::cuda_stream_view stream = rmm::cuda_stream_default)
{
auto strings_count = strings.size();
auto strings_column = column_device_view::create(strings.parent(), stream);
auto d_strings = *strings_column;
auto const strings_count = strings.size();
auto const d_strings = column_device_view::create(strings.parent(), stream);

auto d_flags = detail::get_character_flags_table();
auto const d_flags = detail::get_character_flags_table();
// compile regex into device object
auto prog = reprog_device::create(pattern, d_flags, strings_count, stream);
auto d_prog = *prog;
int regex_insts = prog->insts_counts();
auto const d_prog = reprog_device::create(pattern, d_flags, strings_count, stream);
auto const regex_insts = d_prog->insts_counts();

rmm::device_vector<size_type> find_counts(strings_count);
auto d_find_counts = find_counts.data().get();
rmm::device_uvector<size_type> find_counts(strings_count, stream);
auto d_find_counts = find_counts.data();

if ((regex_insts > MAX_STACK_INSTS) || (regex_insts <= RX_SMALL_INSTS))
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
d_find_counts,
findall_count_fn<RX_STACK_SMALL>{d_strings, d_prog});
findall_count_fn<RX_STACK_SMALL>{*d_strings, *d_prog});
else if (regex_insts <= RX_MEDIUM_INSTS)
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
d_find_counts,
findall_count_fn<RX_STACK_MEDIUM>{d_strings, d_prog});
findall_count_fn<RX_STACK_MEDIUM>{*d_strings, *d_prog});
else
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
d_find_counts,
findall_count_fn<RX_STACK_LARGE>{d_strings, d_prog});
findall_count_fn<RX_STACK_LARGE>{*d_strings, *d_prog});

std::vector<std::unique_ptr<column>> results;

size_type columns =
*thrust::max_element(rmm::exec_policy(stream), find_counts.begin(), find_counts.end());
size_type const columns = thrust::reduce(rmm::exec_policy(stream),
find_counts.begin(),
find_counts.end(),
0,
thrust::maximum<size_type>{});
// boundary case: if no columns, return all nulls column (issue #119)
if (columns == 0)
results.emplace_back(std::make_unique<column>(
Expand All @@ -164,30 +165,32 @@ std::unique_ptr<table> findall_re(
strings_count));

for (int32_t column_index = 0; column_index < columns; ++column_index) {
rmm::device_vector<string_index_pair> indices(strings_count);
string_index_pair* d_indices = indices.data().get();
rmm::device_uvector<string_index_pair> indices(strings_count, stream);
auto d_indices = indices.data();

if ((regex_insts > MAX_STACK_INSTS) || (regex_insts <= RX_SMALL_INSTS))
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
d_indices,
findall_fn<RX_STACK_SMALL>{d_strings, d_prog, column_index, d_find_counts});
thrust::transform(
rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
d_indices,
findall_fn<RX_STACK_SMALL>{*d_strings, *d_prog, column_index, d_find_counts});
else if (regex_insts <= RX_MEDIUM_INSTS)
thrust::transform(
rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
d_indices,
findall_fn<RX_STACK_MEDIUM>{d_strings, d_prog, column_index, d_find_counts});
findall_fn<RX_STACK_MEDIUM>{*d_strings, *d_prog, column_index, d_find_counts});
else
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
d_indices,
findall_fn<RX_STACK_LARGE>{d_strings, d_prog, column_index, d_find_counts});
thrust::transform(
rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
d_indices,
findall_fn<RX_STACK_LARGE>{*d_strings, *d_prog, column_index, d_find_counts});
//
results.emplace_back(make_strings_column(indices, stream, mr));
results.emplace_back(make_strings_column(indices.begin(), indices.end(), stream, mr));
}
return std::make_unique<table>(std::move(results));
}
Expand Down

0 comments on commit 49f6857

Please sign in to comment.