Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change stack-based regex state data to use global memory #10600

Merged
Merged
Show file tree
Hide file tree
Changes from 35 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
2f50daf
Cleanup libcudf strings regex classes
davidwendt Apr 1, 2022
996b369
fix merge conflicts
davidwendt Apr 1, 2022
0f257bf
Merge branch 'branch-22.06' into regex-classes-cleanup
davidwendt Apr 4, 2022
82b462b
change idx parameter to id
davidwendt Apr 4, 2022
25ab493
Merge branch 'branch-22.06' into regex-classes-cleanup
davidwendt Apr 4, 2022
37641f7
change std::string to std::string_view in recomp::create_from
davidwendt Apr 4, 2022
f7ef47c
Change stack-based regex state data to use global memory
davidwendt Apr 5, 2022
f1f598e
Merge branch 'branch-22.06' into regex-classes-cleanup
davidwendt Apr 5, 2022
04f70a8
fix flat reprog-device memory size calc
davidwendt Apr 5, 2022
9c2dc17
fix merge conflicts
davidwendt Apr 5, 2022
83c0e2c
add alignas(16) to reclass decl
davidwendt Apr 5, 2022
5d60d4b
stride merge conflict
davidwendt Apr 5, 2022
eba572e
Merge branch 'branch-22.06' into regex-classes-cleanup
davidwendt Apr 7, 2022
e842438
add restrict keyword to some pointers
davidwendt Apr 7, 2022
50d9f78
use div_round_up for mask size calc
davidwendt Apr 7, 2022
e9b0867
Merge branch 'branch-22.06' into regex-classes-cleanup
davidwendt Apr 7, 2022
7e2272b
fix merge conflict
davidwendt Apr 7, 2022
0652d30
Merge branch 'branch-22.06' into regex-classes-cleanup
davidwendt Apr 8, 2022
00ef10d
fix merge conflict
davidwendt Apr 8, 2022
5261da9
Merge branch 'branch-22.06' into regex-classes-cleanup
davidwendt Apr 8, 2022
8f7adff
Merge branch 'regex-classes-cleanup' into regex-global-memory-state
davidwendt Apr 8, 2022
3ee2f7b
Merge branch 'branch-22.06' into regex-classes-cleanup
davidwendt Apr 11, 2022
d271884
fix merge conflicts
davidwendt Apr 11, 2022
7958fde
fix merge conflicts
davidwendt Apr 14, 2022
73a65cc
re-fix merge conflicts
davidwendt Apr 14, 2022
af84a04
Merge branch 'branch-22.06' into regex-global-memory-state
davidwendt Apr 16, 2022
e9fb7fd
Merge branch 'branch-22.06' into regex-global-memory-state
davidwendt Apr 19, 2022
56a67b2
Merge branch 'branch-22.06' into regex-global-memory-state
davidwendt Apr 20, 2022
7bdaef1
Merge branch 'branch-22.06' into regex-global-memory-state
davidwendt Apr 21, 2022
0f4d275
Merge branch 'branch-22.06' into regex-global-memory-state
davidwendt Apr 21, 2022
3c8e535
Merge branch 'branch-22.06' into regex-global-memory-state
davidwendt Apr 25, 2022
d5187b8
clamp to min_rows in compute_strided_working_memory
davidwendt Apr 27, 2022
8f01ec0
Merge branch 'branch-22.06' into regex-global-memory-state
davidwendt Apr 27, 2022
4c61dc0
change max_size to requested_max_size
davidwendt Apr 27, 2022
367e264
name the block-size value
davidwendt Apr 28, 2022
6ddcaa1
Merge branch 'branch-22.06' into regex-global-memory-state
davidwendt Apr 28, 2022
7d4e18a
Merge branch 'branch-22.06' into regex-global-memory-state
davidwendt May 2, 2022
0d5720d
change contains_util to contains_impl
davidwendt May 2, 2022
25def14
Merge branch 'branch-22.06' into regex-global-memory-state
davidwendt May 3, 2022
25df987
Merge branch 'branch-22.06' into regex-global-memory-state
davidwendt May 5, 2022
c2d6b05
Merge branch 'branch-22.06' into regex-global-memory-state
davidwendt May 5, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 45 additions & 57 deletions cpp/src/strings/contains.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@
*/

#include <strings/count_matches.hpp>
#include <strings/regex/dispatcher.hpp>
#include <strings/regex/regex.cuh>
#include <strings/utilities.hpp>
#include <strings/regex/utilities.cuh>

#include <cudf/column/column.hpp>
#include <cudf/column/column_device_view.cuh>
Expand All @@ -27,65 +25,61 @@
#include <cudf/strings/contains.hpp>
#include <cudf/strings/detail/utilities.hpp>
#include <cudf/strings/string_view.cuh>
#include <cudf/strings/strings_column_view.hpp>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/exec_policy.hpp>

#include <thrust/iterator/counting_iterator.h>
#include <thrust/transform.h>

namespace cudf {
namespace strings {
namespace detail {

namespace {
/**
* @brief This functor handles both contains_re and match_re to minimize the number
* of regex calls to find() to be inlined greatly reducing compile time.
* @brief This functor handles both contains_re and match_re to regex-match a pattern
* to each string in a column.
*/
template <int stack_size>
struct contains_fn {
reprog_device prog;
column_device_view const d_strings;
bool const beginning_only; // do not make this a template parameter to keep compile times down
bool const beginning_only;

__device__ bool operator()(size_type idx)
__device__ bool operator()(size_type const idx,
reprog_device const prog,
int32_t const thread_idx)
{
if (d_strings.is_null(idx)) return false;
auto const d_str = d_strings.element<string_view>(idx);
int32_t begin = 0;
int32_t end = beginning_only ? 1 // match only the beginning of the string;
: -1; // match anywhere in the string
return static_cast<bool>(prog.find<stack_size>(idx, d_str, begin, end));

size_type begin = 0;
size_type end = beginning_only ? 1 // match only the beginning of the string;
: -1; // match anywhere in the string
return static_cast<bool>(prog.find(thread_idx, d_str, begin, end));
}
};

struct contains_dispatch_fn {
reprog_device d_prog;
bool const beginning_only;
std::unique_ptr<column> contains_util(strings_column_view const& input,
davidwendt marked this conversation as resolved.
Show resolved Hide resolved
std::string const& pattern,
regex_flags const flags,
bool const beginning_only,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
auto results = make_numeric_column(data_type{type_id::BOOL8},
input.size(),
cudf::detail::copy_bitmask(input.parent(), stream, mr),
input.null_count(),
stream,
mr);
if (input.is_empty()) { return results; }

template <int stack_size>
std::unique_ptr<column> operator()(strings_column_view const& input,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
auto results = make_numeric_column(data_type{type_id::BOOL8},
input.size(),
cudf::detail::copy_bitmask(input.parent(), stream, mr),
input.null_count(),
stream,
mr);

auto const d_strings = column_device_view::create(input.parent(), stream);
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(input.size()),
results->mutable_view().data<bool>(),
contains_fn<stack_size>{d_prog, *d_strings, beginning_only});
return results;
}
};
auto d_prog = reprog_device::create(pattern, flags, stream);

auto d_results = results->mutable_view().data<bool>();
auto const d_strings = column_device_view::create(input.parent(), stream);

launch_transform_kernel(
contains_fn{*d_strings, beginning_only}, *d_prog, d_results, input.size(), stream);

return results;
}

} // namespace

Expand All @@ -96,10 +90,7 @@ std::unique_ptr<column> contains_re(
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource())
{
auto d_prog =
reprog_device::create(pattern, flags, get_character_flags_table(), input.size(), stream);

return regex_dispatcher(*d_prog, contains_dispatch_fn{*d_prog, false}, input, stream, mr);
return contains_util(input, pattern, flags, false, stream, mr);
}

std::unique_ptr<column> matches_re(
Expand All @@ -109,21 +100,18 @@ std::unique_ptr<column> matches_re(
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource())
{
auto d_prog =
reprog_device::create(pattern, flags, get_character_flags_table(), input.size(), stream);

return regex_dispatcher(*d_prog, contains_dispatch_fn{*d_prog, true}, input, stream, mr);
return contains_util(input, pattern, flags, true, stream, mr);
}

std::unique_ptr<column> count_re(strings_column_view const& input,
std::string const& pattern,
regex_flags const flags,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
std::unique_ptr<column> count_re(
strings_column_view const& input,
std::string const& pattern,
regex_flags const flags,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource())
{
// compile regex into device object
auto d_prog =
reprog_device::create(pattern, flags, get_character_flags_table(), input.size(), stream);
auto d_prog = reprog_device::create(pattern, flags, stream);

auto const d_strings = column_device_view::create(input.parent(), stream);

Expand Down
69 changes: 24 additions & 45 deletions cpp/src/strings/count_matches.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,41 +15,35 @@
*/

#include <strings/count_matches.hpp>
#include <strings/regex/dispatcher.hpp>
#include <strings/regex/regex.cuh>
#include <strings/regex/utilities.cuh>

#include <cudf/column/column_device_view.cuh>
#include <cudf/column/column_factories.hpp>
#include <cudf/strings/string_view.cuh>

#include <rmm/exec_policy.hpp>

#include <thrust/iterator/counting_iterator.h>
#include <thrust/transform.h>

namespace cudf {
namespace strings {
namespace detail {

namespace {
/**
* @brief Functor counts the total matches to the given regex in each string.
* @brief Kernel counts the total matches for the given regex in each string.
*/
template <int stack_size>
struct count_matches_fn {
struct count_fn {
column_device_view const d_strings;
reprog_device prog;

__device__ size_type operator()(size_type idx)
__device__ int32_t operator()(size_type const idx,
reprog_device const prog,
int32_t const thread_idx)
{
if (d_strings.is_null(idx)) { return 0; }
size_type count = 0;
if (d_strings.is_null(idx)) return 0;
auto const d_str = d_strings.element<string_view>(idx);
auto const nchars = d_str.length();
int32_t count = 0;

int32_t begin = 0;
int32_t end = nchars;
while ((begin < end) && (prog.find<stack_size>(idx, d_str, begin, end) > 0)) {
size_type begin = 0;
size_type end = nchars;
while ((begin < end) && (prog.find(thread_idx, d_str, begin, end) > 0)) {
++count;
begin = end + (begin == end);
end = nchars;
Expand All @@ -58,41 +52,26 @@ struct count_matches_fn {
}
};

struct count_dispatch_fn {
reprog_device d_prog;

template <int stack_size>
std::unique_ptr<column> operator()(column_device_view const& d_strings,
size_type output_size,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
assert(output_size >= d_strings.size() and "Unexpected output size");

auto results = make_numeric_column(
data_type{type_id::INT32}, output_size, mask_state::UNALLOCATED, stream, mr);

thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(d_strings.size()),
results->mutable_view().data<int32_t>(),
count_matches_fn<stack_size>{d_strings, d_prog});
return results;
}
};

} // namespace

/**
* @copydoc cudf::strings::detail::count_matches
*/
std::unique_ptr<column> count_matches(column_device_view const& d_strings,
reprog_device const& d_prog,
reprog_device& d_prog,
size_type output_size,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
return regex_dispatcher(d_prog, count_dispatch_fn{d_prog}, d_strings, output_size, stream, mr);
assert(output_size >= d_strings.size() and "Unexpected output size");

auto results = make_numeric_column(
data_type{type_id::INT32}, output_size, mask_state::UNALLOCATED, stream, mr);

if (d_strings.size() == 0) return results;

auto d_results = results->mutable_view().data<int32_t>();

launch_transform_kernel(count_fn{d_strings}, d_prog, d_results, d_strings.size(), stream);

return results;
}

} // namespace detail
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/strings/count_matches.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,11 @@ class reprog_device;
* @param output_size Number of rows for the output column.
* @param stream CUDA stream used for device memory operations and kernel launches.
* @param mr Device memory resource used to allocate the returned column's device memory.
* @return Integer column of match counts
*/
std::unique_ptr<column> count_matches(
column_device_view const& d_strings,
reprog_device const& d_prog,
reprog_device& d_prog,
size_type output_size,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());
Expand Down
57 changes: 19 additions & 38 deletions cpp/src/strings/extract/extract.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@
* limitations under the License.
*/

#include <strings/regex/dispatcher.hpp>
#include <strings/regex/regex.cuh>
#include <strings/utilities.hpp>
#include <strings/regex/utilities.cuh>

#include <cudf/column/column.hpp>
#include <cudf/column/column_device_view.cuh>
Expand All @@ -31,7 +29,7 @@
#include <rmm/cuda_stream_view.hpp>

#include <thrust/execution_policy.h>
#include <thrust/for_each.h>
#include <thrust/fill.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/permutation_iterator.h>
#include <thrust/pair.h>
Expand All @@ -47,28 +45,26 @@ using string_index_pair = thrust::pair<const char*, size_type>;
/**
* @brief This functor handles extracting strings by applying the compiled regex pattern
* and creating string_index_pairs for all the substrings.
*
* @tparam stack_size Correlates to the regex instructions state to maintain for each string.
* Each instruction requires a fixed amount of overhead data.
*/
template <int stack_size>
struct extract_fn {
reprog_device prog;
column_device_view const d_strings;
cudf::detail::device_2dspan<string_index_pair> d_indices;

__device__ void operator()(size_type idx)
__device__ void operator()(size_type const idx,
reprog_device const d_prog,
int32_t const prog_idx)
{
auto const groups = prog.group_counts();
auto const groups = d_prog.group_counts();
auto d_output = d_indices[idx];

if (d_strings.is_valid(idx)) {
auto const d_str = d_strings.element<string_view>(idx);
int32_t begin = 0;
int32_t end = -1; // handles empty strings automatically
if (prog.find<stack_size>(idx, d_str, begin, end) > 0) {

size_type begin = 0;
size_type end = -1; // handles empty strings automatically
if (d_prog.find(prog_idx, d_str, begin, end) > 0) {
for (auto col_idx = 0; col_idx < groups; ++col_idx) {
auto const extracted = prog.extract<stack_size>(idx, d_str, begin, end, col_idx);
auto const extracted = d_prog.extract(prog_idx, d_str, begin, end, col_idx);
d_output[col_idx] = [&] {
if (!extracted) return string_index_pair{nullptr, 0};
auto const offset = d_str.byte_offset((*extracted).first);
Expand All @@ -85,33 +81,17 @@ struct extract_fn {
}
};

struct extract_dispatch_fn {
reprog_device d_prog;

template <int stack_size>
void operator()(column_device_view const& d_strings,
cudf::detail::device_2dspan<string_index_pair>& d_indices,
rmm::cuda_stream_view stream)
{
thrust::for_each_n(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
d_strings.size(),
extract_fn<stack_size>{d_prog, d_strings, d_indices});
}
};
} // namespace

//
std::unique_ptr<table> extract(
strings_column_view const& input,
std::string const& pattern,
regex_flags const flags,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource())
std::unique_ptr<table> extract(strings_column_view const& input,
std::string const& pattern,
regex_flags const flags,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
// compile regex into device object
auto d_prog =
reprog_device::create(pattern, flags, get_character_flags_table(), input.size(), stream);
auto d_prog = reprog_device::create(pattern, flags, stream);

auto const groups = d_prog->group_counts();
CUDF_EXPECTS(groups > 0, "Group indicators not found in regex pattern");
Expand All @@ -121,7 +101,8 @@ std::unique_ptr<table> extract(
cudf::detail::device_2dspan<string_index_pair>(indices.data(), input.size(), groups);

auto const d_strings = column_device_view::create(input.parent(), stream);
regex_dispatcher(*d_prog, extract_dispatch_fn{*d_prog}, *d_strings, d_indices, stream);

launch_for_each_kernel(extract_fn{*d_strings, d_indices}, *d_prog, input.size(), stream);

// build a result column for each group
std::vector<std::unique_ptr<column>> results(groups);
Expand Down
Loading