Skip to content

Commit

Permalink
Use device_span and device_2dspan instead of raw pointers for the extract interface
Browse files Browse the repository at this point in the history
  • Loading branch information
davidwendt committed Jun 16, 2021
1 parent 930bbe4 commit 085b110
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 27 deletions.
29 changes: 17 additions & 12 deletions cpp/src/strings/extract.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,12 @@
#include <cudf/strings/extract.hpp>
#include <cudf/strings/string_view.cuh>
#include <cudf/strings/strings_column_view.hpp>
#include <cudf/utilities/span.hpp>

#include <rmm/cuda_stream_view.hpp>

#include <thrust/fill.h>

namespace cudf {
namespace strings {
namespace detail {
Expand All @@ -46,12 +49,12 @@ template <int stack_size>
struct extract_fn {
reprog_device prog;
column_device_view d_strings;
string_index_pair* d_results;
cudf::detail::device_2dspan<string_index_pair> d_indices;

__device__ void operator()(size_type idx)
{
auto groups = prog.group_counts();
auto d_output = d_results + (idx * groups);
auto d_output = d_indices[idx];
if (d_strings.is_valid(idx)) {
string_view d_str = d_strings.element<string_view>(idx);
int32_t begin = 0;
Expand All @@ -62,7 +65,7 @@ struct extract_fn {
}
}
// fill output with null entries
for (int i = 0; i < groups; ++i) { d_output[i] = string_index_pair{nullptr, 0}; }
thrust::fill(thrust::seq, d_output.begin(), d_output.end(), string_index_pair{nullptr, 0});
}
};

Expand All @@ -83,35 +86,37 @@ std::unique_ptr<table> extract(
auto prog = reprog_device::create(pattern, get_character_flags_table(), strings_count, stream);
auto d_prog = *prog;
// extract should include groups
int groups = d_prog.group_counts();
auto const groups = d_prog.group_counts();
CUDF_EXPECTS(groups > 0, "Group indicators not found in regex pattern");

// build a result column for each group
std::vector<std::unique_ptr<column>> results;
auto regex_insts = d_prog.insts_counts();

rmm::device_uvector<string_index_pair> indices(strings_count * groups, stream);
cudf::detail::device_2dspan<string_index_pair> d_indices(indices.data(), strings_count, groups);

if (regex_insts <= RX_SMALL_INSTS)
if (regex_insts <= RX_SMALL_INSTS) {
thrust::for_each(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
extract_fn<RX_STACK_SMALL>{d_prog, d_strings, indices.data()});
else if (regex_insts <= RX_MEDIUM_INSTS)
extract_fn<RX_STACK_SMALL>{d_prog, d_strings, d_indices});
} else if (regex_insts <= RX_MEDIUM_INSTS) {
thrust::for_each(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
extract_fn<RX_STACK_MEDIUM>{d_prog, d_strings, indices.data()});
else if (regex_insts <= RX_LARGE_INSTS)
extract_fn<RX_STACK_MEDIUM>{d_prog, d_strings, d_indices});
} else if (regex_insts <= RX_LARGE_INSTS) {
thrust::for_each(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
extract_fn<RX_STACK_LARGE>{d_prog, d_strings, indices.data()});
else // supports any number of instructions
extract_fn<RX_STACK_LARGE>{d_prog, d_strings, d_indices});
} else { // supports any number of instructions
thrust::for_each(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
extract_fn<RX_STACK_ANY>{d_prog, d_strings, indices.data()});
extract_fn<RX_STACK_ANY>{d_prog, d_strings, d_indices});
}

for (int32_t column_index = 0; column_index < groups; ++column_index) {
auto indices_itr = thrust::make_permutation_iterator(
Expand Down
8 changes: 6 additions & 2 deletions cpp/src/strings/regex/regex.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <strings/regex/regcomp.h>

#include <cudf/types.hpp>
#include <cudf/utilities/span.hpp>

#include <rmm/cuda_stream_view.hpp>

Expand Down Expand Up @@ -179,8 +180,11 @@ class reprog_device {
* @return Returns true if successful.
*/
template <int stack_size>
__device__ inline bool extract(
int32_t idx, string_view const& d_str, int32_t begin, int32_t end, string_index_pair* indices);
__device__ inline bool extract(int32_t idx,
string_view const& d_str,
int32_t begin,
int32_t end,
device_span<string_index_pair> indices);

private:
int32_t _startinst_id, _num_capturing_groups;
Expand Down
9 changes: 6 additions & 3 deletions cpp/src/strings/regex/regex.inl
Original file line number Diff line number Diff line change
Expand Up @@ -382,11 +382,14 @@ __device__ inline int32_t reprog_device::find(int32_t idx,
}

template <int stack_size>
__device__ inline bool reprog_device::extract(
int32_t idx, string_view const& dstr, int32_t begin, int32_t end, string_index_pair* indices)
__device__ inline bool reprog_device::extract(int32_t idx,
string_view const& dstr,
int32_t begin,
int32_t end,
device_span<string_index_pair> indices)
{
end = begin + 1;
return call_regexec<stack_size>(idx, dstr, begin, end, indices) > 0;
return call_regexec<stack_size>(idx, dstr, begin, end, indices.data()) > 0;
}

template <int stack_size>
Expand Down
14 changes: 6 additions & 8 deletions cpp/src/strings/replace/backref_re.cu
Original file line number Diff line number Diff line change
Expand Up @@ -112,36 +112,34 @@ std::unique_ptr<column> replace_with_backrefs(
// create child columns
auto [offsets, chars] = [&] {
rmm::device_uvector<string_index_pair> indices(strings.size() * d_prog->group_counts(), stream);
cudf::detail::device_2dspan<string_index_pair> d_indices(
indices.data(), strings.size(), d_prog->group_counts());

if (regex_insts <= RX_SMALL_INSTS) {
return make_strings_children(
backrefs_fn<BackRefIterator, RX_STACK_SMALL>{
*d_strings, *d_prog, d_repl_template, backrefs.begin(), backrefs.end(), indices.data()},
*d_strings, *d_prog, d_repl_template, backrefs.begin(), backrefs.end(), d_indices},
strings.size(),
stream,
mr);
} else if (regex_insts <= RX_MEDIUM_INSTS) {
// return replace_with_backrefs_medium(
// *d_strings, *d_prog, d_repl_template, backrefs, stream, mr);
return make_strings_children(
backrefs_fn<BackRefIterator, RX_STACK_MEDIUM>{
*d_strings, *d_prog, d_repl_template, backrefs.begin(), backrefs.end(), indices.data()},
*d_strings, *d_prog, d_repl_template, backrefs.begin(), backrefs.end(), d_indices},
strings.size(),
stream,
mr);
} else if (regex_insts <= RX_LARGE_INSTS) {
// return replace_with_backrefs_large(
// *d_strings, *d_prog, d_repl_template, backrefs, stream, mr);
return make_strings_children(
backrefs_fn<BackRefIterator, RX_STACK_LARGE>{
*d_strings, *d_prog, d_repl_template, backrefs.begin(), backrefs.end(), indices.data()},
*d_strings, *d_prog, d_repl_template, backrefs.begin(), backrefs.end(), d_indices},
strings.size(),
stream,
mr);
} else {
return make_strings_children(
backrefs_fn<BackRefIterator, RX_STACK_ANY>{
*d_strings, *d_prog, d_repl_template, backrefs.begin(), backrefs.end(), indices.data()},
*d_strings, *d_prog, d_repl_template, backrefs.begin(), backrefs.end(), d_indices},
strings.size(),
stream,
mr);
Expand Down
6 changes: 4 additions & 2 deletions cpp/src/strings/replace/backref_re.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
#include <cudf/column/column_device_view.cuh>
#include <cudf/strings/detail/utilities.cuh>
#include <cudf/strings/string_view.cuh>
#include <cudf/utilities/span.hpp>

#include <strings/regex/regex.cuh>

#include <rmm/cuda_stream_view.hpp>
Expand All @@ -41,7 +43,7 @@ struct backrefs_fn {
string_view const d_repl; // string replacement template
Iterator backrefs_begin;
Iterator backrefs_end;
string_index_pair* d_indices;
cudf::detail::device_2dspan<string_index_pair> d_indices;
int32_t* d_offsets{};
char* d_chars{};

Expand All @@ -61,7 +63,7 @@ struct backrefs_fn {
size_type end = nchars; // last character position (exclusive)

// working memory for extract on this string
auto d_extracts = d_indices + idx * prog.group_counts();
auto d_extracts = d_indices[idx];

// copy input to output replacing strings as we go
while (prog.find<stack_size>(idx, d_str, begin, end) > 0) // inits the begin/end vars
Expand Down

0 comments on commit 085b110

Please sign in to comment.