Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify reprog_device::extract to return groups in a single pass #8460

Merged
merged 11 commits into from
Jun 18, 2021
2 changes: 0 additions & 2 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -362,8 +362,6 @@ add_library(cudf
src/strings/regex/regexec.cu
src/strings/repeat_strings.cu
src/strings/replace/backref_re.cu
src/strings/replace/backref_re_large.cu
src/strings/replace/backref_re_medium.cu
src/strings/replace/multi_re.cu
src/strings/replace/replace.cu
src/strings/replace/replace_re.cu
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/strings/contains.cu
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ namespace {
* Small to medium instruction lengths can use the stack effectively though smaller executes faster.
* Longer patterns require global memory.
*/
template <size_t stack_size>
template <int stack_size>
struct contains_fn {
reprog_device prog;
column_device_view d_strings;
Expand Down Expand Up @@ -163,7 +163,7 @@ namespace {
/**
* @brief This counts the number of times the regex pattern matches in each string.
*/
template <size_t stack_size>
template <int stack_size>
struct count_fn {
reprog_device prog;
column_device_view d_strings;
Expand Down
91 changes: 48 additions & 43 deletions cpp/src/strings/extract.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
#include <cudf/column/column_device_view.cuh>
#include <cudf/column/column_factories.hpp>
#include <cudf/detail/nvtx/ranges.hpp>
#include <cudf/detail/utilities/cuda.cuh>
#include <cudf/null_mask.hpp>
#include <cudf/strings/detail/strings_column_factories.cuh>
#include <cudf/strings/detail/utilities.hpp>
#include <cudf/strings/extract.hpp>
#include <cudf/strings/string_view.cuh>
Expand All @@ -32,7 +34,6 @@
namespace cudf {
namespace strings {
namespace detail {
using string_index_pair = thrust::pair<const char*, size_type>;

namespace {
/**
Expand All @@ -42,26 +43,27 @@ namespace {
* @tparam stack_size Correlates to the regex instructions state to maintain for each string.
* Each instruction requires a fixed amount of overhead data.
*/
template <size_t stack_size>
template <int stack_size>
struct extract_fn {
reprog_device prog;
column_device_view d_strings;
size_type column_index;
string_index_pair* d_results;

__device__ string_index_pair operator()(size_type idx)
__device__ void operator()(size_type idx)
{
if (d_strings.is_null(idx)) return string_index_pair{nullptr, 0};
string_view d_str = d_strings.element<string_view>(idx);
string_index_pair result{nullptr, 0};
int32_t begin = 0;
int32_t end = -1; // handles empty strings automatically
if ((prog.find<stack_size>(idx, d_str, begin, end) > 0) &&
(prog.extract<stack_size>(idx, d_str, begin, end, column_index) > 0)) {
auto offset = d_str.byte_offset(begin);
// build index-pair
result = string_index_pair{d_str.data() + offset, d_str.byte_offset(end) - offset};
auto groups = prog.group_counts();
auto d_output = d_results + (idx * groups);
if (d_strings.is_valid(idx)) {
string_view d_str = d_strings.element<string_view>(idx);
int32_t begin = 0;
int32_t end = -1; // handles empty strings automatically
if ((prog.find<stack_size>(idx, d_str, begin, end) > 0) &&
prog.extract<stack_size>(idx, d_str, begin, end, d_output)) {
return;
}
}
return result;
// fill output with null entries
for (int i = 0; i < groups; ++i) { d_output[i] = string_index_pair{nullptr, 0}; }
davidwendt marked this conversation as resolved.
Show resolved Hide resolved
}
};

Expand Down Expand Up @@ -89,36 +91,39 @@ std::unique_ptr<table> extract(
std::vector<std::unique_ptr<column>> results;
auto regex_insts = d_prog.insts_counts();

rmm::device_uvector<string_index_pair> indices(strings_count * groups, stream);

if (regex_insts <= RX_SMALL_INSTS)
thrust::for_each(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
extract_fn<RX_STACK_SMALL>{d_prog, d_strings, indices.data()});
else if (regex_insts <= RX_MEDIUM_INSTS)
thrust::for_each(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
extract_fn<RX_STACK_MEDIUM>{d_prog, d_strings, indices.data()});
else if (regex_insts <= RX_LARGE_INSTS)
thrust::for_each(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
extract_fn<RX_STACK_LARGE>{d_prog, d_strings, indices.data()});
else // supports any number of instructions
thrust::for_each(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
extract_fn<RX_STACK_ANY>{d_prog, d_strings, indices.data()});

for (int32_t column_index = 0; column_index < groups; ++column_index) {
rmm::device_uvector<string_index_pair> indices(strings_count, stream);

if (regex_insts <= RX_SMALL_INSTS)
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
indices.begin(),
extract_fn<RX_STACK_SMALL>{d_prog, d_strings, column_index});
else if (regex_insts <= RX_MEDIUM_INSTS)
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
indices.begin(),
extract_fn<RX_STACK_MEDIUM>{d_prog, d_strings, column_index});
else if (regex_insts <= RX_LARGE_INSTS)
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
indices.begin(),
extract_fn<RX_STACK_LARGE>{d_prog, d_strings, column_index});
else
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
indices.begin(),
extract_fn<RX_STACK_ANY>{d_prog, d_strings, column_index});

results.emplace_back(make_strings_column(indices, stream, mr));
auto indices_itr = thrust::make_permutation_iterator(
indices.begin(),
thrust::make_transform_iterator(thrust::make_counting_iterator<size_type>(0),
[column_index, groups] __device__(size_type idx) {
return (idx * groups) + column_index;
}));
results.emplace_back(make_strings_column(indices_itr, indices_itr + strings_count, stream, mr));
}

return std::make_unique<table>(std::move(results));
}

Expand Down
35 changes: 25 additions & 10 deletions cpp/src/strings/regex/regex.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,12 @@

#include <strings/regex/regcomp.h>

#include <cudf/types.hpp>

#include <rmm/cuda_stream_view.hpp>

#include <thrust/pair.h>

#include <functional>
#include <memory>

Expand All @@ -33,6 +37,8 @@ struct reljunk;
struct reinst;
class reprog;

using string_index_pair = thrust::pair<const char*, cudf::size_type>;

constexpr int32_t RX_STACK_SMALL = 112; ///< fastest stack size
constexpr int32_t RX_STACK_MEDIUM = 1104; ///< faster stack size
constexpr int32_t RX_STACK_LARGE = 10128; ///< fast stack size
Expand Down Expand Up @@ -99,6 +105,7 @@ class reprog_device {
const uint8_t* cp_flags,
int32_t strings_count,
rmm::cuda_stream_view stream);

/**
* @brief Called automatically by the unique_ptr returned from create().
*/
Expand Down Expand Up @@ -157,21 +164,23 @@ class reprog_device {
* @brief Does an extract evaluation using the compiled expression on the given string.
*
* This will find a specific match within the string when more than match occurs.
* The find() function should be called first to locate the begin/end bounds of the
* the matched section.
*
* @tparam stack_size One of the `RX_STACK_` values based on the `insts_count`.
* @param idx The string index used for mapping the state memory for this string in global memory
* (if necessary).
* @param d_str The string to search.
* @param[in,out] begin Position index to begin the search. If found, returns the position found
* @param begin Position index to begin the search. If found, returns the position found
* in the string.
* @param[in,out] end Position index to end the search. If found, returns the last position
* @param end Position index to end the search. If found, returns the last position
* matching in the string.
* @param group_id The specific instance to return if more than one match is found.
* @return Returns 0 if no match is found.
* @param indices All extracted groups
* @return Returns true if successful.
*/
template <int stack_size>
__device__ inline int32_t extract(
int32_t idx, string_view const& d_str, int32_t& begin, int32_t& end, int32_t group_id);
__device__ inline bool extract(
int32_t idx, string_view const& d_str, int32_t begin, int32_t end, string_index_pair* indices);
davidwendt marked this conversation as resolved.
Show resolved Hide resolved

private:
int32_t _startinst_id, _num_capturing_groups;
Expand All @@ -185,15 +194,21 @@ class reprog_device {
/**
* @brief Executes the regex pattern on the given string.
*/
__device__ inline int32_t regexec(
string_view const& d_str, reljunk& jnk, int32_t& begin, int32_t& end, int32_t groupid = 0);
__device__ inline int32_t regexec(string_view const& d_str,
reljunk& jnk,
int32_t& begin,
int32_t& end,
string_index_pair* indices = nullptr);

/**
* @brief Utility wrapper to setup state memory structures for calling regexec
*/
template <int stack_size>
__device__ inline int32_t call_regexec(
int32_t idx, string_view const& d_str, int32_t& begin, int32_t& end, int32_t groupid = 0);
__device__ inline int32_t call_regexec(int32_t idx,
string_view const& d_str,
int32_t& begin,
int32_t& end,
string_index_pair* indices = nullptr);

reprog_device(reprog&); // must use create()
};
Expand Down
44 changes: 26 additions & 18 deletions cpp/src/strings/regex/regex.inl
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ __device__ inline int32_t* reprog_device::startinst_ids() const { return _starti
* @return >0 if match found
*/
__device__ inline int32_t reprog_device::regexec(
string_view const& dstr, reljunk& jnk, int32_t& begin, int32_t& end, int32_t group_id)
string_view const& dstr, reljunk& jnk, int32_t& begin, int32_t& end, string_index_pair* indices)
{
int32_t match = 0;
auto checkstart = jnk.starttype;
Expand Down Expand Up @@ -229,10 +229,9 @@ __device__ inline int32_t reprog_device::regexec(
}

if (((eos < 0) || (pos < eos)) && match == 0) {
// jnk.list1->activate(startinst_id, pos, 0);
int32_t i = 0;
auto ids = startinst_ids();
while (ids[i] >= 0) jnk.list1->activate(ids[i++], (group_id == 0 ? pos : -1), -1);
while (ids[i] >= 0) jnk.list1->activate(ids[i++], (indices == nullptr ? pos : -1), -1);
}

c = static_cast<char32_t>(pos >= txtlen ? 0 : *itr);
Expand All @@ -257,14 +256,20 @@ __device__ inline int32_t reprog_device::regexec(
case NCCLASS:
case END: id_activate = inst_id; break;
case LBRA:
if (inst->u1.subid == group_id) range.x = pos;
if (indices && inst->u1.subid == _num_capturing_groups) range.x = pos;
id_activate = inst->u2.next_id;
expanded = true;
if (indices) { indices[inst->u1.subid - 1].first = dstr.data() + itr.byte_offset(); }
break;
case RBRA:
if (inst->u1.subid == group_id) range.y = pos;
if (indices && inst->u1.subid == _num_capturing_groups) range.y = pos;
id_activate = inst->u2.next_id;
expanded = true;
if (indices) {
auto const ptr_offset = indices[inst->u1.subid - 1].first - dstr.data();
indices[inst->u1.subid - 1].second =
itr.byte_offset() - static_cast<cudf::size_type>(ptr_offset);
}
break;
case BOL:
if ((pos == 0) ||
Expand Down Expand Up @@ -318,8 +323,9 @@ __device__ inline int32_t reprog_device::regexec(
} while (expanded);

// execute
bool continue_execute = true;
jnk.list2->reset();
for (int16_t i = 0; i < jnk.list1->size; i++) {
for (int16_t i = 0; continue_execute && i < jnk.list1->size; i++) {
int32_t inst_id = static_cast<int32_t>(jnk.list1->inst_ids[i]);
int2& range = jnk.list1->ranges[i];
const reinst* inst = get_inst(inst_id);
Expand All @@ -346,18 +352,21 @@ __device__ inline int32_t reprog_device::regexec(
case END:
match = 1;
begin = range.x;
end = group_id == 0 ? pos : range.y;
goto BreakFor;
end = indices == nullptr ? pos : range.y;

continue_execute = false;
break;
}
if (id_activate >= 0) jnk.list2->activate(id_activate, range.x, range.y);
if (continue_execute && (id_activate >= 0))
jnk.list2->activate(id_activate, range.x, range.y);
}

BreakFor:
++pos;
++itr;
swaplist(jnk.list1, jnk.list2);
checkstart = jnk.list1->size > 0 ? 0 : 1;
} while (c && (jnk.list1->size > 0 || match == 0));

return match;
}

Expand All @@ -373,16 +382,16 @@ __device__ inline int32_t reprog_device::find(int32_t idx,
}

template <int stack_size>
__device__ inline int32_t reprog_device::extract(
int32_t idx, string_view const& dstr, int32_t& begin, int32_t& end, int32_t group_id)
__device__ inline bool reprog_device::extract(
int32_t idx, string_view const& dstr, int32_t begin, int32_t end, string_index_pair* indices)
{
end = begin + 1;
return call_regexec<stack_size>(idx, dstr, begin, end, group_id + 1);
return call_regexec<stack_size>(idx, dstr, begin, end, indices) > 0;
}

template <int stack_size>
__device__ inline int32_t reprog_device::call_regexec(
int32_t idx, string_view const& dstr, int32_t& begin, int32_t& end, int32_t group_id)
int32_t idx, string_view const& dstr, int32_t& begin, int32_t& end, string_index_pair* indices)
{
u_char data1[stack_size], data2[stack_size];

Expand All @@ -393,12 +402,12 @@ __device__ inline int32_t reprog_device::call_regexec(
relist list2(static_cast<int16_t>(_insts_count), data2);

reljunk jnk(&list1, &list2, stype, schar);
return regexec(dstr, jnk, begin, end, group_id);
return regexec(dstr, jnk, begin, end, indices);
}

template <>
__device__ inline int32_t reprog_device::call_regexec<RX_STACK_ANY>(
int32_t idx, string_view const& dstr, int32_t& begin, int32_t& end, int32_t group_id)
int32_t idx, string_view const& dstr, int32_t& begin, int32_t& end, string_index_pair* indices)
{
auto const stype = get_inst(_startinst_id)->type;
auto const schar = get_inst(_startinst_id)->u1.c;
Expand All @@ -407,12 +416,11 @@ __device__ inline int32_t reprog_device::call_regexec<RX_STACK_ANY>(
u_char* listmem = reinterpret_cast<u_char*>(_relists_mem); // beginning of relist buffer;
listmem += (idx * relists_size * 2); // two relist ptrs in reljunk:

// run ctor on assigned memory buffer
relist* list1 = new (listmem) relist(static_cast<int16_t>(_insts_count));
relist* list2 = new (listmem + relists_size) relist(static_cast<int16_t>(_insts_count));

reljunk jnk(list1, list2, stype, schar);
return regexec(dstr, jnk, begin, end, group_id);
return regexec(dstr, jnk, begin, end, indices);
}

} // namespace detail
Expand Down
Loading