Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create a dispatcher for invoking regex kernel functions #10349

Merged
merged 9 commits into from
Mar 2, 2022
200 changes: 77 additions & 123 deletions cpp/src/strings/contains.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -14,6 +14,10 @@
* limitations under the License.
*/

#include <strings/regex/dispatcher.hpp>
#include <strings/regex/regex.cuh>
#include <strings/utilities.hpp>

#include <cudf/column/column.hpp>
#include <cudf/column/column_device_view.cuh>
#include <cudf/column/column_factories.hpp>
Expand All @@ -23,123 +27,90 @@
#include <cudf/strings/detail/utilities.hpp>
#include <cudf/strings/string_view.cuh>
#include <cudf/strings/strings_column_view.hpp>
#include <strings/regex/regex.cuh>
#include <strings/utilities.hpp>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/exec_policy.hpp>

#include <thrust/transform.h>

namespace cudf {
namespace strings {
namespace detail {

namespace {
/**
* @brief This functor handles both contains_re and matches_re to minimize the number
* of inlined regex find() calls, which greatly reduces compile time.
*
* The stack is used to keep progress on evaluating the regex instructions on each string.
* So the size of the stack is in proportion to the number of instructions in the given regex
* pattern.
*
* There are three call types based on the number of regex instructions in the given pattern.
* Small to medium instruction counts can use the stack effectively, though smaller ones execute faster.
* Longer patterns require global memory.
cwharris marked this conversation as resolved.
Show resolved Hide resolved
*/
template <int stack_size>
struct contains_fn {
reprog_device prog;
column_device_view d_strings;
bool bmatch{false}; // do not make this a template parameter to keep compile times down
column_device_view const d_strings;
bool const bmatch; // do not make this a template parameter to keep compile times down
davidwendt marked this conversation as resolved.
Show resolved Hide resolved

__device__ bool operator()(size_type idx)
{
if (d_strings.is_null(idx)) return false;
string_view d_str = d_strings.element<string_view>(idx);
int32_t begin = 0;
int32_t end = bmatch ? 1 // match only the beginning of the string;
: -1; // this handles empty strings too
auto const d_str = d_strings.element<string_view>(idx);
int32_t begin = 0;
int32_t end = bmatch ? 1 // match only the beginning of the string;
: -1; // this handles empty strings too
bdice marked this conversation as resolved.
Show resolved Hide resolved
return static_cast<bool>(prog.find<stack_size>(idx, d_str, begin, end));
}
};

//
std::unique_ptr<column> contains_util(
strings_column_view const& strings,
std::string const& pattern,
regex_flags const flags,
bool beginning_only = false,
rmm::cuda_stream_view stream = rmm::cuda_stream_default,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource())
{
auto strings_count = strings.size();
auto strings_column = column_device_view::create(strings.parent(), stream);
auto d_column = *strings_column;

// compile regex into device object
auto prog =
reprog_device::create(pattern, flags, get_character_flags_table(), strings_count, stream);
auto d_prog = *prog;

// create the output column
auto results = make_numeric_column(data_type{type_id::BOOL8},
strings_count,
cudf::detail::copy_bitmask(strings.parent(), stream, mr),
strings.null_count(),
stream,
mr);
auto d_results = results->mutable_view().data<bool>();
struct contains_dispatch_fn {
reprog_device d_prog;
bool const beginning_only{false};

// fill the output column
int regex_insts = d_prog.insts_counts();
if (regex_insts <= RX_SMALL_INSTS)
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
d_results,
contains_fn<RX_STACK_SMALL>{d_prog, d_column, beginning_only});
else if (regex_insts <= RX_MEDIUM_INSTS)
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
d_results,
contains_fn<RX_STACK_MEDIUM>{d_prog, d_column, beginning_only});
else if (regex_insts <= RX_LARGE_INSTS)
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
d_results,
contains_fn<RX_STACK_LARGE>{d_prog, d_column, beginning_only});
else
template <int stack_size>
std::unique_ptr<column> operator()(strings_column_view const& input,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
auto results = make_numeric_column(data_type{type_id::BOOL8},
input.size(),
cudf::detail::copy_bitmask(input.parent(), stream, mr),
input.null_count(),
stream,
mr);

auto const d_strings = column_device_view::create(input.parent(), stream);
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
d_results,
contains_fn<RX_STACK_ANY>{d_prog, d_column, beginning_only});

results->set_null_count(strings.null_count());
return results;
}
thrust::make_counting_iterator<size_type>(input.size()),
results->mutable_view().data<bool>(),
contains_fn<stack_size>{d_prog, *d_strings, beginning_only});
return results;
}
};

} // namespace

std::unique_ptr<column> contains_re(
strings_column_view const& strings,
strings_column_view const& input,
std::string const& pattern,
regex_flags const flags,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource())
{
return contains_util(strings, pattern, flags, false, stream, mr);
auto d_prog =
reprog_device::create(pattern, flags, get_character_flags_table(), input.size(), stream);

return regex_dispatcher(*d_prog, contains_dispatch_fn{*d_prog, false}, input, stream, mr);
}

std::unique_ptr<column> matches_re(
strings_column_view const& strings,
strings_column_view const& input,
std::string const& pattern,
regex_flags const flags,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource())
{
return contains_util(strings, pattern, flags, true, stream, mr);
auto d_prog =
reprog_device::create(pattern, flags, get_character_flags_table(), input.size(), stream);

return regex_dispatcher(*d_prog, contains_dispatch_fn{*d_prog, true}, input, stream, mr);
}

} // namespace detail
Expand Down Expand Up @@ -172,12 +143,12 @@ namespace {
template <int stack_size>
struct count_fn {
reprog_device prog;
column_device_view d_strings;
column_device_view const d_strings;

__device__ int32_t operator()(unsigned int idx)
{
if (d_strings.is_null(idx)) return 0;
string_view d_str = d_strings.element<string_view>(idx);
auto const d_str = d_strings.element<string_view>(idx);
auto const nchars = d_str.length();
int32_t find_count = 0;
int32_t begin = 0;
Expand All @@ -191,62 +162,45 @@ struct count_fn {
}
};

/**
 * @brief Host-side functor that builds the per-row match-count column for count_re.
 *
 * It is handed to regex_dispatcher, which is expected to pick the `stack_size`
 * template argument (dispatcher implementation is not visible here -- confirm).
 */
struct count_dispatch_fn {
  reprog_device d_prog;  ///< regex program forwarded to each count_fn instance

  /**
   * @brief Creates an INT32 column holding one count_fn result per input row.
   *
   * @tparam stack_size Forwarded unchanged to count_fn
   * @param input Strings column to evaluate
   * @param stream CUDA stream used for the allocations and the transform
   * @param mr Device memory resource used for the returned column
   * @return INT32 column whose null mask and null count are copied from `input`
   */
  template <int stack_size>
  std::unique_ptr<column> operator()(strings_column_view const& input,
                                     rmm::cuda_stream_view stream,
                                     rmm::mr::device_memory_resource* mr)
  {
    // The output mirrors the input's size and nulls.
    auto output = make_numeric_column(data_type{type_id::INT32},
                                      input.size(),
                                      cudf::detail::copy_bitmask(input.parent(), stream, mr),
                                      input.null_count(),
                                      stream,
                                      mr);

    auto const d_input = column_device_view::create(input.parent(), stream);

    // One functor invocation per row index in [0, input.size()).
    auto const first_row = thrust::make_counting_iterator<size_type>(0);
    auto const last_row  = thrust::make_counting_iterator<size_type>(input.size());
    auto const d_counts  = output->mutable_view().data<int32_t>();
    thrust::transform(
      rmm::exec_policy(stream), first_row, last_row, d_counts, count_fn<stack_size>{d_prog, *d_input});

    return output;
  }
};

} // namespace

std::unique_ptr<column> count_re(
strings_column_view const& strings,
strings_column_view const& input,
std::string const& pattern,
regex_flags const flags,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource())
{
auto strings_count = strings.size();
auto strings_column = column_device_view::create(strings.parent(), stream);
auto d_column = *strings_column;

// compile regex into device object
auto prog =
reprog_device::create(pattern, flags, get_character_flags_table(), strings_count, stream);
auto d_prog = *prog;

// create the output column
auto results = make_numeric_column(data_type{type_id::INT32},
strings_count,
cudf::detail::copy_bitmask(strings.parent(), stream, mr),
strings.null_count(),
stream,
mr);
auto d_results = results->mutable_view().data<int32_t>();

// fill the output column
int regex_insts = d_prog.insts_counts();
if (regex_insts <= RX_SMALL_INSTS)
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
d_results,
count_fn<RX_STACK_SMALL>{d_prog, d_column});
else if (regex_insts <= RX_MEDIUM_INSTS)
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
d_results,
count_fn<RX_STACK_MEDIUM>{d_prog, d_column});
else if (regex_insts <= RX_LARGE_INSTS)
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
d_results,
count_fn<RX_STACK_LARGE>{d_prog, d_column});
else
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
d_results,
count_fn<RX_STACK_ANY>{d_prog, d_column});
auto d_prog =
reprog_device::create(pattern, flags, get_character_flags_table(), input.size(), stream);

results->set_null_count(strings.null_count());
return results;
return regex_dispatcher(*d_prog, count_dispatch_fn{*d_prog}, input, stream, mr);
}

} // namespace detail
Expand Down
48 changes: 23 additions & 25 deletions cpp/src/strings/count_matches.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

#include <strings/count_matches.hpp>
#include <strings/regex/dispatcher.hpp>
#include <strings/regex/regex.cuh>

#include <cudf/column/column_device_view.cuh>
Expand Down Expand Up @@ -54,6 +55,27 @@ struct count_matches_fn {
return count;
}
};

/**
 * @brief Host-side functor that builds the match-count column for count_matches.
 *
 * It is handed to regex_dispatcher, which is expected to pick the `stack_size`
 * template argument (dispatcher implementation is not visible here -- confirm).
 */
struct count_dispatch_fn {
  reprog_device d_prog;  ///< regex program forwarded to each count_matches_fn instance

  /**
   * @brief Creates an INT32 column and fills it with one count_matches_fn result per row.
   *
   * @tparam stack_size Forwarded unchanged to count_matches_fn
   * @param d_strings Device view of the strings to evaluate
   * @param stream CUDA stream used for the allocation and the transform
   * @param mr Device memory resource used for the returned column
   * @return INT32 column of `d_strings.size() + 1` elements with no null mask
   */
  template <int stack_size>
  std::unique_ptr<column> operator()(column_device_view const& d_strings,
                                     rmm::cuda_stream_view stream,
                                     rmm::mr::device_memory_resource* mr)
  {
    // One element more than the row count is allocated; only the first
    // d_strings.size() entries are written below.
    // NOTE(review): presumably the extra slot lets callers turn counts into offsets -- confirm.
    auto output = make_numeric_column(
      data_type{type_id::INT32}, d_strings.size() + 1, mask_state::UNALLOCATED, stream, mr);

    auto const first_row = thrust::make_counting_iterator<size_type>(0);
    auto const last_row  = thrust::make_counting_iterator<size_type>(d_strings.size());
    auto const d_counts  = output->mutable_view().data<int32_t>();
    thrust::transform(rmm::exec_policy(stream),
                      first_row,
                      last_row,
                      d_counts,
                      count_matches_fn<stack_size>{d_strings, d_prog});

    return output;
  }
};

} // namespace

/**
Expand All @@ -71,31 +93,7 @@ std::unique_ptr<column> count_matches(column_device_view const& d_strings,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
// Create output column
auto counts = make_numeric_column(
data_type{type_id::INT32}, d_strings.size() + 1, mask_state::UNALLOCATED, stream, mr);
auto d_counts = counts->mutable_view().data<offset_type>();

auto begin = thrust::make_counting_iterator<size_type>(0);
auto end = thrust::make_counting_iterator<size_type>(d_strings.size());

// Count matches
auto const regex_insts = d_prog.insts_counts();
if (regex_insts <= RX_SMALL_INSTS) {
count_matches_fn<RX_STACK_SMALL> fn{d_strings, d_prog};
thrust::transform(rmm::exec_policy(stream), begin, end, d_counts, fn);
} else if (regex_insts <= RX_MEDIUM_INSTS) {
count_matches_fn<RX_STACK_MEDIUM> fn{d_strings, d_prog};
thrust::transform(rmm::exec_policy(stream), begin, end, d_counts, fn);
} else if (regex_insts <= RX_LARGE_INSTS) {
count_matches_fn<RX_STACK_LARGE> fn{d_strings, d_prog};
thrust::transform(rmm::exec_policy(stream), begin, end, d_counts, fn);
} else {
count_matches_fn<RX_STACK_ANY> fn{d_strings, d_prog};
thrust::transform(rmm::exec_policy(stream), begin, end, d_counts, fn);
}

return counts;
return regex_dispatcher(d_prog, count_dispatch_fn{d_prog}, d_strings, stream, mr);
}

} // namespace detail
Expand Down
Loading