Skip to content

Commit

Permalink
Use single kernel to extract all groups in cudf::strings::extract (#9358)
Browse files Browse the repository at this point in the history

This is a less ambitious version of #8460 which had to be reverted in #8575 because it did not work with greedy quantifiers. The change here involves calling the underlying `reprog_device::extract` to retrieve each group result within a single kernel rather than launching a kernel for each group. The output is placed contiguously in a 2d span (wrapped uvector) and a permutation iterator is used to build the output columns (one column per group).

Like its predecessor, the performance improvement is mostly when specifying more than one group in the regex pattern. The benchmark results showed no change for single groups but were 2x faster for multiple groups over long (8K) strings and up to 4x faster for multiple groups over many (16M) strings.

The benchmark test for extract was also updated to better report the number of groups being used when measuring results.

Authors:
  - David Wendt (https://github.com/davidwendt)

Approvers:
  - Mark Harris (https://github.com/harrism)
  - Nghia Truong (https://github.com/ttnghia)

URL: #9358
  • Loading branch information
davidwendt authored Oct 13, 2021
1 parent df27da2 commit a4f6c6d
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 61 deletions.
11 changes: 6 additions & 5 deletions cpp/benchmarks/string/extract_benchmark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ static void BM_extract(benchmark::State& state, int groups)
return row;
});

std::string pattern;
while (static_cast<int>(pattern.size()) < groups) {
std::string pattern{""};
while (groups--) {
pattern += "(\\d+) ";
}

Expand Down Expand Up @@ -86,6 +86,7 @@ static void generate_bench_args(benchmark::internal::Benchmark* b)
->UseManualTime() \
->Unit(benchmark::kMillisecond);

STRINGS_BENCHMARK_DEFINE(small, 2)
STRINGS_BENCHMARK_DEFINE(medium, 10)
STRINGS_BENCHMARK_DEFINE(large, 30)
STRINGS_BENCHMARK_DEFINE(one, 1)
STRINGS_BENCHMARK_DEFINE(two, 2)
STRINGS_BENCHMARK_DEFINE(four, 4)
STRINGS_BENCHMARK_DEFINE(eight, 8)
128 changes: 72 additions & 56 deletions cpp/src/strings/extract.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@

#include <cudf/column/column.hpp>
#include <cudf/column/column_device_view.cuh>
#include <cudf/column/column_factories.hpp>
#include <cudf/detail/iterator.cuh>
#include <cudf/detail/nvtx/ranges.hpp>
#include <cudf/null_mask.hpp>
#include <cudf/strings/detail/utilities.hpp>
#include <cudf/strings/detail/strings_column_factories.cuh>
#include <cudf/strings/extract.hpp>
#include <cudf/strings/string_view.cuh>
#include <cudf/strings/strings_column_view.hpp>
#include <cudf/utilities/span.hpp>

#include <rmm/cuda_stream_view.hpp>

Expand All @@ -47,29 +47,36 @@ using string_index_pair = thrust::pair<const char*, size_type>;
template <int stack_size>
struct extract_fn {
reprog_device prog;
column_device_view d_strings;
size_type column_index;
column_device_view const d_strings;
cudf::detail::device_2dspan<string_index_pair> d_indices;

__device__ string_index_pair operator()(size_type idx)
__device__ void operator()(size_type idx)
{
if (d_strings.is_null(idx)) return string_index_pair{nullptr, 0};
string_view d_str = d_strings.element<string_view>(idx);
string_index_pair result{nullptr, 0};
int32_t begin = 0;
int32_t end = -1; // handles empty strings automatically
if (prog.find<stack_size>(idx, d_str, begin, end) > 0) {
auto extracted = prog.extract<stack_size>(idx, d_str, begin, end, column_index);
if (extracted) {
auto const offset = d_str.byte_offset(extracted.value().first);
// build index-pair
result = string_index_pair{d_str.data() + offset,
d_str.byte_offset(extracted.value().second) - offset};
auto const groups = prog.group_counts();
auto d_output = d_indices[idx];

if (d_strings.is_valid(idx)) {
auto const d_str = d_strings.element<string_view>(idx);
int32_t begin = 0;
int32_t end = -1; // handles empty strings automatically
if (prog.find<stack_size>(idx, d_str, begin, end) > 0) {
for (auto col_idx = 0; col_idx < groups; ++col_idx) {
auto const extracted = prog.extract<stack_size>(idx, d_str, begin, end, col_idx);
d_output[col_idx] = [&] {
if (!extracted) return string_index_pair{nullptr, 0};
auto const offset = d_str.byte_offset((*extracted).first);
return string_index_pair{d_str.data() + offset,
d_str.byte_offset((*extracted).second) - offset};
}();
}
return;
}
}
return result;

// if null row or no match found, fill the output with null entries
thrust::fill(thrust::seq, d_output.begin(), d_output.end(), string_index_pair{nullptr, 0});
}
};

} // namespace

//
Expand All @@ -79,9 +86,9 @@ std::unique_ptr<table> extract(
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource())
{
auto strings_count = strings.size();
auto strings_column = column_device_view::create(strings.parent(), stream);
auto d_strings = *strings_column;
auto const strings_count = strings.size();
auto const strings_column = column_device_view::create(strings.parent(), stream);
auto const d_strings = *strings_column;

// compile regex into device object
auto prog = reprog_device::create(pattern, get_character_flags_table(), strings_count, stream);
Expand All @@ -90,41 +97,50 @@ std::unique_ptr<table> extract(
auto const groups = d_prog.group_counts();
CUDF_EXPECTS(groups > 0, "Group indicators not found in regex pattern");

rmm::device_uvector<string_index_pair> indices(strings_count * groups, stream);
cudf::detail::device_2dspan<string_index_pair> d_indices(indices.data(), strings_count, groups);

auto const regex_insts = d_prog.insts_counts();
if (regex_insts <= RX_SMALL_INSTS) {
thrust::for_each_n(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
strings_count,
extract_fn<RX_STACK_SMALL>{d_prog, d_strings, d_indices});
} else if (regex_insts <= RX_MEDIUM_INSTS) {
thrust::for_each_n(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
strings_count,
extract_fn<RX_STACK_MEDIUM>{d_prog, d_strings, d_indices});
} else if (regex_insts <= RX_LARGE_INSTS) {
thrust::for_each_n(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
strings_count,
extract_fn<RX_STACK_LARGE>{d_prog, d_strings, d_indices});
} else {
thrust::for_each_n(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
strings_count,
extract_fn<RX_STACK_ANY>{d_prog, d_strings, d_indices});
}

// build a result column for each group
std::vector<std::unique_ptr<column>> results;
auto regex_insts = d_prog.insts_counts();

for (int32_t column_index = 0; column_index < groups; ++column_index) {
rmm::device_uvector<string_index_pair> indices(strings_count, stream);

if (regex_insts <= RX_SMALL_INSTS) {
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
indices.begin(),
extract_fn<RX_STACK_SMALL>{d_prog, d_strings, column_index});
} else if (regex_insts <= RX_MEDIUM_INSTS) {
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
indices.begin(),
extract_fn<RX_STACK_MEDIUM>{d_prog, d_strings, column_index});
} else if (regex_insts <= RX_LARGE_INSTS) {
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
indices.begin(),
extract_fn<RX_STACK_LARGE>{d_prog, d_strings, column_index});
} else {
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
indices.begin(),
extract_fn<RX_STACK_ANY>{d_prog, d_strings, column_index});
}
std::vector<std::unique_ptr<column>> results(groups);
auto make_strings_lambda = [&](size_type column_index) {
// this iterator transposes the extract results into column order
auto indices_itr =
thrust::make_permutation_iterator(indices.begin(),
cudf::detail::make_counting_transform_iterator(
0, [column_index, groups] __device__(size_type idx) {
return (idx * groups) + column_index;
}));
return make_strings_column(indices_itr, indices_itr + strings_count, stream, mr);
};

std::transform(thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(groups),
results.begin(),
make_strings_lambda);

results.emplace_back(make_strings_column(indices, stream, mr));
}
return std::make_unique<table>(std::move(results));
}

Expand Down

0 comments on commit a4f6c6d

Please sign in to comment.