Skip to content

Commit

Permalink
Modify reprog_device::extract to return groups in a single pass (#8460)
Browse files Browse the repository at this point in the history
This PR modifies the internal regex `reprog_device::extract` function to return all matching groups in a single call. Previously, retrieving each group range required individual calls to this `extract` function resulted in re-matching the entire given pattern for each group. The code logic would identify each group but only return the range for the specified group. 

The code change here passes a pre-allocated global memory array to capture each group range in a single pass. The extract is an all-or-nothing process. In fact, a `find` function must first be executed to retrieve the bounds of the given pattern. So if any of the groups are missing or do not match, no groups are returned for that row. Retrieving the last group would always require processing the previous groups and the code logic now records those positions in the global memory array. The memory array can then be used directly to build the output columns.

This simplifies the code around extract and also improves performance especially for long strings or patterns with many groups. For small strings and a small number of groups, the gbenchmark showed equivalent performance to the previous implementation. For larger strings and more groups, the gbenchmark showed a 2-3x improvement.

Authors:
  - David Wendt (https://github.com/davidwendt)

Approvers:
  - Vukasin Milovanovic (https://github.com/vuule)
  - Robert Maynard (https://github.com/robertmaynard)
  - Christopher Harris (https://github.com/cwharris)

URL: #8460
  • Loading branch information
davidwendt authored Jun 18, 2021
1 parent 0099f11 commit d183d50
Show file tree
Hide file tree
Showing 11 changed files with 173 additions and 228 deletions.
2 changes: 0 additions & 2 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -365,8 +365,6 @@ add_library(cudf
src/strings/regex/regexec.cu
src/strings/repeat_strings.cu
src/strings/replace/backref_re.cu
src/strings/replace/backref_re_large.cu
src/strings/replace/backref_re_medium.cu
src/strings/replace/multi_re.cu
src/strings/replace/replace.cu
src/strings/replace/replace_re.cu
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/strings/contains.cu
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ namespace {
* Small to medium instruction lengths can use the stack effectively though smaller executes faster.
* Longer patterns require global memory.
*/
template <size_t stack_size>
template <int stack_size>
struct contains_fn {
reprog_device prog;
column_device_view d_strings;
Expand Down Expand Up @@ -163,7 +163,7 @@ namespace {
/**
* @brief This counts the number of times the regex pattern matches in each string.
*/
template <size_t stack_size>
template <int stack_size>
struct count_fn {
reprog_device prog;
column_device_view d_strings;
Expand Down
97 changes: 53 additions & 44 deletions cpp/src/strings/extract.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,20 @@
#include <cudf/column/column_factories.hpp>
#include <cudf/detail/nvtx/ranges.hpp>
#include <cudf/null_mask.hpp>
#include <cudf/strings/detail/strings_column_factories.cuh>
#include <cudf/strings/detail/utilities.hpp>
#include <cudf/strings/extract.hpp>
#include <cudf/strings/string_view.cuh>
#include <cudf/strings/strings_column_view.hpp>
#include <cudf/utilities/span.hpp>

#include <rmm/cuda_stream_view.hpp>

#include <thrust/fill.h>

namespace cudf {
namespace strings {
namespace detail {
using string_index_pair = thrust::pair<const char*, size_type>;

namespace {
/**
Expand All @@ -42,26 +45,27 @@ namespace {
* @tparam stack_size Correlates to the regex instructions state to maintain for each string.
* Each instruction requires a fixed amount of overhead data.
*/
template <size_t stack_size>
template <int stack_size>
struct extract_fn {
reprog_device prog;
column_device_view d_strings;
size_type column_index;
cudf::detail::device_2dspan<string_index_pair> d_indices;

__device__ string_index_pair operator()(size_type idx)
__device__ void operator()(size_type idx)
{
if (d_strings.is_null(idx)) return string_index_pair{nullptr, 0};
string_view d_str = d_strings.element<string_view>(idx);
string_index_pair result{nullptr, 0};
int32_t begin = 0;
int32_t end = -1; // handles empty strings automatically
if ((prog.find<stack_size>(idx, d_str, begin, end) > 0) &&
(prog.extract<stack_size>(idx, d_str, begin, end, column_index) > 0)) {
auto offset = d_str.byte_offset(begin);
// build index-pair
result = string_index_pair{d_str.data() + offset, d_str.byte_offset(end) - offset};
auto groups = prog.group_counts();
auto d_output = d_indices[idx];
if (d_strings.is_valid(idx)) {
string_view d_str = d_strings.element<string_view>(idx);
int32_t begin = 0;
int32_t end = -1; // handles empty strings automatically
if ((prog.find<stack_size>(idx, d_str, begin, end) > 0) &&
prog.extract<stack_size>(idx, d_str, begin, end, d_output)) {
return;
}
}
return result;
// fill output with null entries
thrust::fill(thrust::seq, d_output.begin(), d_output.end(), string_index_pair{nullptr, 0});
}
};

Expand All @@ -82,43 +86,48 @@ std::unique_ptr<table> extract(
auto prog = reprog_device::create(pattern, get_character_flags_table(), strings_count, stream);
auto d_prog = *prog;
// extract should include groups
int groups = d_prog.group_counts();
auto const groups = d_prog.group_counts();
CUDF_EXPECTS(groups > 0, "Group indicators not found in regex pattern");

// build a result column for each group
std::vector<std::unique_ptr<column>> results;
auto regex_insts = d_prog.insts_counts();

rmm::device_uvector<string_index_pair> indices(strings_count * groups, stream);
cudf::detail::device_2dspan<string_index_pair> d_indices(indices.data(), strings_count, groups);

if (regex_insts <= RX_SMALL_INSTS) {
thrust::for_each(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
extract_fn<RX_STACK_SMALL>{d_prog, d_strings, d_indices});
} else if (regex_insts <= RX_MEDIUM_INSTS) {
thrust::for_each(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
extract_fn<RX_STACK_MEDIUM>{d_prog, d_strings, d_indices});
} else if (regex_insts <= RX_LARGE_INSTS) {
thrust::for_each(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
extract_fn<RX_STACK_LARGE>{d_prog, d_strings, d_indices});
} else { // supports any number of instructions
thrust::for_each(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
extract_fn<RX_STACK_ANY>{d_prog, d_strings, d_indices});
}

for (int32_t column_index = 0; column_index < groups; ++column_index) {
rmm::device_uvector<string_index_pair> indices(strings_count, stream);

if (regex_insts <= RX_SMALL_INSTS)
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
indices.begin(),
extract_fn<RX_STACK_SMALL>{d_prog, d_strings, column_index});
else if (regex_insts <= RX_MEDIUM_INSTS)
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
indices.begin(),
extract_fn<RX_STACK_MEDIUM>{d_prog, d_strings, column_index});
else if (regex_insts <= RX_LARGE_INSTS)
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
indices.begin(),
extract_fn<RX_STACK_LARGE>{d_prog, d_strings, column_index});
else
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
indices.begin(),
extract_fn<RX_STACK_ANY>{d_prog, d_strings, column_index});

results.emplace_back(make_strings_column(indices, stream, mr));
auto indices_itr = thrust::make_permutation_iterator(
indices.begin(),
thrust::make_transform_iterator(thrust::make_counting_iterator<size_type>(0),
[column_index, groups] __device__(size_type idx) {
return (idx * groups) + column_index;
}));
results.emplace_back(make_strings_column(indices_itr, indices_itr + strings_count, stream, mr));
}

return std::make_unique<table>(std::move(results));
}

Expand Down
39 changes: 29 additions & 10 deletions cpp/src/strings/regex/regex.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,13 @@

#include <strings/regex/regcomp.h>

#include <cudf/types.hpp>
#include <cudf/utilities/span.hpp>

#include <rmm/cuda_stream_view.hpp>

#include <thrust/pair.h>

#include <functional>
#include <memory>

Expand All @@ -33,6 +38,8 @@ struct reljunk;
struct reinst;
class reprog;

using string_index_pair = thrust::pair<const char*, cudf::size_type>;

constexpr int32_t RX_STACK_SMALL = 112; ///< fastest stack size
constexpr int32_t RX_STACK_MEDIUM = 1104; ///< faster stack size
constexpr int32_t RX_STACK_LARGE = 10128; ///< fast stack size
Expand Down Expand Up @@ -99,6 +106,7 @@ class reprog_device {
const uint8_t* cp_flags,
int32_t strings_count,
rmm::cuda_stream_view stream);

/**
* @brief Called automatically by the unique_ptr returned from create().
*/
Expand Down Expand Up @@ -157,21 +165,26 @@ class reprog_device {
* @brief Does an extract evaluation using the compiled expression on the given string.
*
* This will find a specific match within the string when more than match occurs.
* The find() function should be called first to locate the begin/end bounds of the
* the matched section.
*
* @tparam stack_size One of the `RX_STACK_` values based on the `insts_count`.
* @param idx The string index used for mapping the state memory for this string in global memory
* (if necessary).
* @param d_str The string to search.
* @param[in,out] begin Position index to begin the search. If found, returns the position found
* @param begin Position index to begin the search. If found, returns the position found
* in the string.
* @param[in,out] end Position index to end the search. If found, returns the last position
* @param end Position index to end the search. If found, returns the last position
* matching in the string.
* @param group_id The specific instance to return if more than one match is found.
* @return Returns 0 if no match is found.
* @param indices All extracted groups
* @return Returns true if successful.
*/
template <int stack_size>
__device__ inline int32_t extract(
int32_t idx, string_view const& d_str, int32_t& begin, int32_t& end, int32_t group_id);
__device__ inline bool extract(int32_t idx,
string_view const& d_str,
int32_t begin,
int32_t end,
device_span<string_index_pair> indices);

private:
int32_t _startinst_id, _num_capturing_groups;
Expand All @@ -185,15 +198,21 @@ class reprog_device {
/**
* @brief Executes the regex pattern on the given string.
*/
__device__ inline int32_t regexec(
string_view const& d_str, reljunk& jnk, int32_t& begin, int32_t& end, int32_t groupid = 0);
__device__ inline int32_t regexec(string_view const& d_str,
reljunk& jnk,
int32_t& begin,
int32_t& end,
string_index_pair* indices = nullptr);

/**
* @brief Utility wrapper to setup state memory structures for calling regexec
*/
template <int stack_size>
__device__ inline int32_t call_regexec(
int32_t idx, string_view const& d_str, int32_t& begin, int32_t& end, int32_t groupid = 0);
__device__ inline int32_t call_regexec(int32_t idx,
string_view const& d_str,
int32_t& begin,
int32_t& end,
string_index_pair* indices = nullptr);

reprog_device(reprog&); // must use create()
};
Expand Down
47 changes: 29 additions & 18 deletions cpp/src/strings/regex/regex.inl
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ __device__ inline int32_t* reprog_device::startinst_ids() const { return _starti
* @return >0 if match found
*/
__device__ inline int32_t reprog_device::regexec(
string_view const& dstr, reljunk& jnk, int32_t& begin, int32_t& end, int32_t group_id)
string_view const& dstr, reljunk& jnk, int32_t& begin, int32_t& end, string_index_pair* indices)
{
int32_t match = 0;
auto checkstart = jnk.starttype;
Expand Down Expand Up @@ -229,10 +229,9 @@ __device__ inline int32_t reprog_device::regexec(
}

if (((eos < 0) || (pos < eos)) && match == 0) {
// jnk.list1->activate(startinst_id, pos, 0);
int32_t i = 0;
auto ids = startinst_ids();
while (ids[i] >= 0) jnk.list1->activate(ids[i++], (group_id == 0 ? pos : -1), -1);
while (ids[i] >= 0) jnk.list1->activate(ids[i++], (indices == nullptr ? pos : -1), -1);
}

c = static_cast<char32_t>(pos >= txtlen ? 0 : *itr);
Expand All @@ -257,14 +256,20 @@ __device__ inline int32_t reprog_device::regexec(
case NCCLASS:
case END: id_activate = inst_id; break;
case LBRA:
if (inst->u1.subid == group_id) range.x = pos;
if (indices && inst->u1.subid == _num_capturing_groups) range.x = pos;
id_activate = inst->u2.next_id;
expanded = true;
if (indices) { indices[inst->u1.subid - 1].first = dstr.data() + itr.byte_offset(); }
break;
case RBRA:
if (inst->u1.subid == group_id) range.y = pos;
if (indices && inst->u1.subid == _num_capturing_groups) range.y = pos;
id_activate = inst->u2.next_id;
expanded = true;
if (indices) {
auto const ptr_offset = indices[inst->u1.subid - 1].first - dstr.data();
indices[inst->u1.subid - 1].second =
itr.byte_offset() - static_cast<cudf::size_type>(ptr_offset);
}
break;
case BOL:
if ((pos == 0) ||
Expand Down Expand Up @@ -318,8 +323,9 @@ __device__ inline int32_t reprog_device::regexec(
} while (expanded);

// execute
bool continue_execute = true;
jnk.list2->reset();
for (int16_t i = 0; i < jnk.list1->size; i++) {
for (int16_t i = 0; continue_execute && i < jnk.list1->size; i++) {
int32_t inst_id = static_cast<int32_t>(jnk.list1->inst_ids[i]);
int2& range = jnk.list1->ranges[i];
const reinst* inst = get_inst(inst_id);
Expand All @@ -346,18 +352,21 @@ __device__ inline int32_t reprog_device::regexec(
case END:
match = 1;
begin = range.x;
end = group_id == 0 ? pos : range.y;
goto BreakFor;
end = indices == nullptr ? pos : range.y;

continue_execute = false;
break;
}
if (id_activate >= 0) jnk.list2->activate(id_activate, range.x, range.y);
if (continue_execute && (id_activate >= 0))
jnk.list2->activate(id_activate, range.x, range.y);
}

BreakFor:
++pos;
++itr;
swaplist(jnk.list1, jnk.list2);
checkstart = jnk.list1->size > 0 ? 0 : 1;
} while (c && (jnk.list1->size > 0 || match == 0));

return match;
}

Expand All @@ -373,16 +382,19 @@ __device__ inline int32_t reprog_device::find(int32_t idx,
}

template <int stack_size>
__device__ inline int32_t reprog_device::extract(
int32_t idx, string_view const& dstr, int32_t& begin, int32_t& end, int32_t group_id)
__device__ inline bool reprog_device::extract(int32_t idx,
string_view const& dstr,
int32_t begin,
int32_t end,
device_span<string_index_pair> indices)
{
end = begin + 1;
return call_regexec<stack_size>(idx, dstr, begin, end, group_id + 1);
return call_regexec<stack_size>(idx, dstr, begin, end, indices.data()) > 0;
}

template <int stack_size>
__device__ inline int32_t reprog_device::call_regexec(
int32_t idx, string_view const& dstr, int32_t& begin, int32_t& end, int32_t group_id)
int32_t idx, string_view const& dstr, int32_t& begin, int32_t& end, string_index_pair* indices)
{
u_char data1[stack_size], data2[stack_size];

Expand All @@ -393,12 +405,12 @@ __device__ inline int32_t reprog_device::call_regexec(
relist list2(static_cast<int16_t>(_insts_count), data2);

reljunk jnk(&list1, &list2, stype, schar);
return regexec(dstr, jnk, begin, end, group_id);
return regexec(dstr, jnk, begin, end, indices);
}

template <>
__device__ inline int32_t reprog_device::call_regexec<RX_STACK_ANY>(
int32_t idx, string_view const& dstr, int32_t& begin, int32_t& end, int32_t group_id)
int32_t idx, string_view const& dstr, int32_t& begin, int32_t& end, string_index_pair* indices)
{
auto const stype = get_inst(_startinst_id)->type;
auto const schar = get_inst(_startinst_id)->u1.c;
Expand All @@ -407,12 +419,11 @@ __device__ inline int32_t reprog_device::call_regexec<RX_STACK_ANY>(
u_char* listmem = reinterpret_cast<u_char*>(_relists_mem); // beginning of relist buffer;
listmem += (idx * relists_size * 2); // two relist ptrs in reljunk:

// run ctor on assigned memory buffer
relist* list1 = new (listmem) relist(static_cast<int16_t>(_insts_count));
relist* list2 = new (listmem + relists_size) relist(static_cast<int16_t>(_insts_count));

reljunk jnk(list1, list2, stype, schar);
return regexec(dstr, jnk, begin, end, group_id);
return regexec(dstr, jnk, begin, end, indices);
}

} // namespace detail
Expand Down
Loading

0 comments on commit d183d50

Please sign in to comment.