Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug in replace_with_backrefs when group has greedy quantifier #8575

Merged
merged 4 commits into from
Jun 25, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 47 additions & 50 deletions cpp/src/strings/extract.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,21 @@
#include <cudf/column/column_factories.hpp>
#include <cudf/detail/nvtx/ranges.hpp>
#include <cudf/null_mask.hpp>
#include <cudf/strings/detail/strings_column_factories.cuh>
#include <cudf/strings/detail/utilities.hpp>
#include <cudf/strings/extract.hpp>
#include <cudf/strings/string_view.cuh>
#include <cudf/strings/strings_column_view.hpp>
#include <cudf/utilities/span.hpp>

#include <rmm/cuda_stream_view.hpp>

#include <thrust/fill.h>

namespace cudf {
namespace strings {
namespace detail {

namespace {

using string_index_pair = thrust::pair<const char*, size_type>;

/**
* @brief This functor handles extracting strings by applying the compiled regex pattern
* and creating string_index_pairs for all the substrings.
Expand All @@ -49,23 +48,25 @@ template <int stack_size>
struct extract_fn {
reprog_device prog;
column_device_view d_strings;
cudf::detail::device_2dspan<string_index_pair> d_indices;
size_type column_index;

__device__ void operator()(size_type idx)
__device__ string_index_pair operator()(size_type idx)
{
auto groups = prog.group_counts();
auto d_output = d_indices[idx];
if (d_strings.is_valid(idx)) {
string_view d_str = d_strings.element<string_view>(idx);
int32_t begin = 0;
int32_t end = -1; // handles empty strings automatically
if ((prog.find<stack_size>(idx, d_str, begin, end) > 0) &&
prog.extract<stack_size>(idx, d_str, begin, end, d_output)) {
return;
if (d_strings.is_null(idx)) return string_index_pair{nullptr, 0};
string_view d_str = d_strings.element<string_view>(idx);
string_index_pair result{nullptr, 0};
int32_t begin = 0;
int32_t end = -1; // handles empty strings automatically
if (prog.find<stack_size>(idx, d_str, begin, end) > 0) {
auto extracted = prog.extract<stack_size>(idx, d_str, begin, end, column_index);
if (extracted) {
auto const offset = d_str.byte_offset(extracted.value().first);
// build index-pair
result = string_index_pair{d_str.data() + offset,
d_str.byte_offset(extracted.value().second) - offset};
}
}
// fill output with null entries
thrust::fill(thrust::seq, d_output.begin(), d_output.end(), string_index_pair{nullptr, 0});
return result;
}
};

Expand Down Expand Up @@ -93,41 +94,37 @@ std::unique_ptr<table> extract(
std::vector<std::unique_ptr<column>> results;
auto regex_insts = d_prog.insts_counts();

rmm::device_uvector<string_index_pair> indices(strings_count * groups, stream);
cudf::detail::device_2dspan<string_index_pair> d_indices(indices.data(), strings_count, groups);

if (regex_insts <= RX_SMALL_INSTS) {
thrust::for_each(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
extract_fn<RX_STACK_SMALL>{d_prog, d_strings, d_indices});
} else if (regex_insts <= RX_MEDIUM_INSTS) {
thrust::for_each(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
extract_fn<RX_STACK_MEDIUM>{d_prog, d_strings, d_indices});
} else if (regex_insts <= RX_LARGE_INSTS) {
thrust::for_each(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
extract_fn<RX_STACK_LARGE>{d_prog, d_strings, d_indices});
} else { // supports any number of instructions
thrust::for_each(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
extract_fn<RX_STACK_ANY>{d_prog, d_strings, d_indices});
}

for (int32_t column_index = 0; column_index < groups; ++column_index) {
auto indices_itr = thrust::make_permutation_iterator(
indices.begin(),
thrust::make_transform_iterator(thrust::make_counting_iterator<size_type>(0),
[column_index, groups] __device__(size_type idx) {
return (idx * groups) + column_index;
}));
results.emplace_back(make_strings_column(indices_itr, indices_itr + strings_count, stream, mr));
}
rmm::device_uvector<string_index_pair> indices(strings_count, stream);

if (regex_insts <= RX_SMALL_INSTS) {
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
indices.begin(),
extract_fn<RX_STACK_SMALL>{d_prog, d_strings, column_index});
} else if (regex_insts <= RX_MEDIUM_INSTS) {
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
indices.begin(),
extract_fn<RX_STACK_MEDIUM>{d_prog, d_strings, column_index});
} else if (regex_insts <= RX_LARGE_INSTS) {
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
indices.begin(),
extract_fn<RX_STACK_LARGE>{d_prog, d_strings, column_index});
} else {
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
indices.begin(),
extract_fn<RX_STACK_ANY>{d_prog, d_strings, column_index});
}

results.emplace_back(make_strings_column(indices, stream, mr));
}
return std::make_unique<table>(std::move(results));
}

Expand Down
33 changes: 14 additions & 19 deletions cpp/src/strings/regex/regex.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
#include <strings/regex/regcomp.h>

#include <cudf/types.hpp>
#include <cudf/utilities/span.hpp>

#include <rmm/cuda_stream_view.hpp>

#include <thrust/optional.h>
#include <thrust/pair.h>

#include <functional>
Expand All @@ -38,7 +38,8 @@ struct reljunk;
struct reinst;
class reprog;

using string_index_pair = thrust::pair<const char*, cudf::size_type>;
using match_pair = thrust::pair<cudf::size_type, cudf::size_type>;
using match_result = thrust::optional<match_pair>;

constexpr int32_t RX_STACK_SMALL = 112; ///< fastest stack size
constexpr int32_t RX_STACK_MEDIUM = 1104; ///< faster stack size
Expand Down Expand Up @@ -176,15 +177,15 @@ class reprog_device {
* in the string.
* @param end Position index to end the search. If found, returns the last position
* matching in the string.
* @param indices All extracted groups
* @return Returns true if successful.
* @param group_id Index of the capture group whose matching positions are returned.
* @return If the group matched, the begin and end character positions of that group within
*         the given string; otherwise, an empty result.
*/
template <int stack_size>
__device__ inline bool extract(int32_t idx,
string_view const& d_str,
int32_t begin,
int32_t end,
device_span<string_index_pair> indices);
__device__ inline match_result extract(cudf::size_type idx,
string_view const& d_str,
cudf::size_type begin,
cudf::size_type end,
cudf::size_type group_id);

private:
int32_t _startinst_id, _num_capturing_groups;
Expand All @@ -198,21 +199,15 @@ class reprog_device {
/**
* @brief Executes the regex pattern on the given string.
*/
__device__ inline int32_t regexec(string_view const& d_str,
reljunk& jnk,
int32_t& begin,
int32_t& end,
string_index_pair* indices = nullptr);
__device__ inline int32_t regexec(
string_view const& d_str, reljunk& jnk, int32_t& begin, int32_t& end, int32_t group_id = 0);

/**
* @brief Utility wrapper to setup state memory structures for calling regexec
*/
template <int stack_size>
__device__ inline int32_t call_regexec(int32_t idx,
string_view const& d_str,
int32_t& begin,
int32_t& end,
string_index_pair* indices = nullptr);
__device__ inline int32_t call_regexec(
int32_t idx, string_view const& d_str, int32_t& begin, int32_t& end, int32_t group_id = 0);

reprog_device(reprog&); // must use create()
};
Expand Down
38 changes: 17 additions & 21 deletions cpp/src/strings/regex/regex.inl
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ __device__ inline int32_t* reprog_device::startinst_ids() const { return _starti
* @return >0 if match found
*/
__device__ inline int32_t reprog_device::regexec(
string_view const& dstr, reljunk& jnk, int32_t& begin, int32_t& end, string_index_pair* indices)
string_view const& dstr, reljunk& jnk, int32_t& begin, int32_t& end, int32_t group_id)
{
int32_t match = 0;
auto checkstart = jnk.starttype;
Expand Down Expand Up @@ -231,7 +231,7 @@ __device__ inline int32_t reprog_device::regexec(
if (((eos < 0) || (pos < eos)) && match == 0) {
int32_t i = 0;
auto ids = startinst_ids();
while (ids[i] >= 0) jnk.list1->activate(ids[i++], (indices == nullptr ? pos : -1), -1);
while (ids[i] >= 0) jnk.list1->activate(ids[i++], (group_id == 0 ? pos : -1), -1);
}

c = static_cast<char32_t>(pos >= txtlen ? 0 : *itr);
Expand All @@ -256,20 +256,14 @@ __device__ inline int32_t reprog_device::regexec(
case NCCLASS:
case END: id_activate = inst_id; break;
case LBRA:
if (indices && inst->u1.subid == _num_capturing_groups) range.x = pos;
if (inst->u1.subid == group_id) range.x = pos;
id_activate = inst->u2.next_id;
expanded = true;
if (indices) { indices[inst->u1.subid - 1].first = dstr.data() + itr.byte_offset(); }
break;
case RBRA:
if (indices && inst->u1.subid == _num_capturing_groups) range.y = pos;
if (inst->u1.subid == group_id) range.y = pos;
id_activate = inst->u2.next_id;
expanded = true;
if (indices) {
auto const ptr_offset = indices[inst->u1.subid - 1].first - dstr.data();
indices[inst->u1.subid - 1].second =
itr.byte_offset() - static_cast<cudf::size_type>(ptr_offset);
}
break;
case BOL:
if ((pos == 0) ||
Expand Down Expand Up @@ -352,7 +346,7 @@ __device__ inline int32_t reprog_device::regexec(
case END:
match = 1;
begin = range.x;
end = indices == nullptr ? pos : range.y;
end = group_id == 0 ? pos : range.y;

continue_execute = false;
break;
Expand Down Expand Up @@ -382,19 +376,21 @@ __device__ inline int32_t reprog_device::find(int32_t idx,
}

template <int stack_size>
__device__ inline bool reprog_device::extract(int32_t idx,
string_view const& dstr,
int32_t begin,
int32_t end,
device_span<string_index_pair> indices)
__device__ inline match_result reprog_device::extract(cudf::size_type idx,
string_view const& dstr,
cudf::size_type begin,
cudf::size_type end,
cudf::size_type group_id)
{
end = begin + 1;
return call_regexec<stack_size>(idx, dstr, begin, end, indices.data()) > 0;
return call_regexec<stack_size>(idx, dstr, begin, end, group_id + 1) > 0
? match_result({begin, end})
: thrust::nullopt;
}

template <int stack_size>
__device__ inline int32_t reprog_device::call_regexec(
int32_t idx, string_view const& dstr, int32_t& begin, int32_t& end, string_index_pair* indices)
int32_t idx, string_view const& dstr, int32_t& begin, int32_t& end, int32_t group_id)
{
u_char data1[stack_size], data2[stack_size];

Expand All @@ -405,12 +401,12 @@ __device__ inline int32_t reprog_device::call_regexec(
relist list2(static_cast<int16_t>(_insts_count), data2);

reljunk jnk(&list1, &list2, stype, schar);
return regexec(dstr, jnk, begin, end, indices);
return regexec(dstr, jnk, begin, end, group_id);
}

template <>
__device__ inline int32_t reprog_device::call_regexec<RX_STACK_ANY>(
int32_t idx, string_view const& dstr, int32_t& begin, int32_t& end, string_index_pair* indices)
int32_t idx, string_view const& dstr, int32_t& begin, int32_t& end, int32_t group_id)
{
auto const stype = get_inst(_startinst_id)->type;
auto const schar = get_inst(_startinst_id)->u1.c;
Expand All @@ -423,7 +419,7 @@ __device__ inline int32_t reprog_device::call_regexec<RX_STACK_ANY>(
relist* list2 = new (listmem + relists_size) relist(static_cast<int16_t>(_insts_count));

reljunk jnk(list1, list2, stype, schar);
return regexec(dstr, jnk, begin, end, indices);
return regexec(dstr, jnk, begin, end, group_id);
}

} // namespace detail
Expand Down
12 changes: 4 additions & 8 deletions cpp/src/strings/replace/backref_re.cu
Original file line number Diff line number Diff line change
Expand Up @@ -111,35 +111,31 @@ std::unique_ptr<column> replace_with_backrefs(

// create child columns
auto [offsets, chars] = [&] {
rmm::device_uvector<string_index_pair> indices(strings.size() * d_prog->group_counts(), stream);
cudf::detail::device_2dspan<string_index_pair> d_indices(
indices.data(), strings.size(), d_prog->group_counts());

if (regex_insts <= RX_SMALL_INSTS) {
return make_strings_children(
backrefs_fn<BackRefIterator, RX_STACK_SMALL>{
*d_strings, *d_prog, d_repl_template, backrefs.begin(), backrefs.end(), d_indices},
*d_strings, *d_prog, d_repl_template, backrefs.begin(), backrefs.end()},
strings.size(),
stream,
mr);
} else if (regex_insts <= RX_MEDIUM_INSTS) {
return make_strings_children(
backrefs_fn<BackRefIterator, RX_STACK_MEDIUM>{
*d_strings, *d_prog, d_repl_template, backrefs.begin(), backrefs.end(), d_indices},
*d_strings, *d_prog, d_repl_template, backrefs.begin(), backrefs.end()},
strings.size(),
stream,
mr);
} else if (regex_insts <= RX_LARGE_INSTS) {
return make_strings_children(
backrefs_fn<BackRefIterator, RX_STACK_LARGE>{
*d_strings, *d_prog, d_repl_template, backrefs.begin(), backrefs.end(), d_indices},
*d_strings, *d_prog, d_repl_template, backrefs.begin(), backrefs.end()},
strings.size(),
stream,
mr);
} else {
return make_strings_children(
backrefs_fn<BackRefIterator, RX_STACK_ANY>{
*d_strings, *d_prog, d_repl_template, backrefs.begin(), backrefs.end(), d_indices},
*d_strings, *d_prog, d_repl_template, backrefs.begin(), backrefs.end()},
strings.size(),
stream,
mr);
Expand Down
Loading