Skip to content

Commit

Permalink
Create a dispatcher for invoking regex kernel functions (#10349)
Browse files Browse the repository at this point in the history
Closes #10138 

Refactor the various regex function calls to use a dispatcher instead of if-else clauses. Each regex call currently requires different stack sizes (and later launch parameters). Changes to these parameters are sometimes difficult to coordinate since they usually need to be duplicated across about 10 APIs that are currently using regex calls. The new `regex_dispatcher` makes calling these much cleaner and easier to maintain. This will be helpful when experimenting with possibly using different launch parameters.

No functions have changed. Mostly this is a refactoring and cleanup effort. The `findall.cu` was also recoded to use the new `count_matches` utility.

Authors:
  - David Wendt (https://github.com/davidwendt)

Approvers:
  - Vyas Ramasubramani (https://github.com/vyasr)
  - Christopher Harris (https://github.com/cwharris)
  - Bradley Dice (https://github.com/bdice)
  - Ram (Ramakrishna Prabhu) (https://github.com/rgsl888prabhu)

URL: #10349
  • Loading branch information
davidwendt authored Mar 2, 2022
1 parent 78b316c commit 1217f24
Show file tree
Hide file tree
Showing 11 changed files with 452 additions and 507 deletions.
200 changes: 77 additions & 123 deletions cpp/src/strings/contains.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -14,6 +14,10 @@
* limitations under the License.
*/

#include <strings/regex/dispatcher.hpp>
#include <strings/regex/regex.cuh>
#include <strings/utilities.hpp>

#include <cudf/column/column.hpp>
#include <cudf/column/column_device_view.cuh>
#include <cudf/column/column_factories.hpp>
Expand All @@ -23,123 +27,90 @@
#include <cudf/strings/detail/utilities.hpp>
#include <cudf/strings/string_view.cuh>
#include <cudf/strings/strings_column_view.hpp>
#include <strings/regex/regex.cuh>
#include <strings/utilities.hpp>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/exec_policy.hpp>

#include <thrust/transform.h>

namespace cudf {
namespace strings {
namespace detail {

namespace {
/**
* @brief This functor handles both contains_re and match_re to minimize the number
* of regex calls to find() to be inlined greatly reducing compile time.
*
* The stack is used to keep progress on evaluating the regex instructions on each string.
* So the size of the stack is in proportion to the number of instructions in the given regex
* pattern.
*
* There are three call types based on the number of regex instructions in the given pattern.
* Small to medium instruction lengths can use the stack effectively though smaller executes faster.
* Longer patterns require global memory.
*/
template <int stack_size>
struct contains_fn {
reprog_device prog;
column_device_view d_strings;
bool bmatch{false}; // do not make this a template parameter to keep compile times down
column_device_view const d_strings;
bool const beginning_only; // do not make this a template parameter to keep compile times down

__device__ bool operator()(size_type idx)
{
if (d_strings.is_null(idx)) return false;
string_view d_str = d_strings.element<string_view>(idx);
int32_t begin = 0;
int32_t end = bmatch ? 1 // match only the beginning of the string;
: -1; // this handles empty strings too
auto const d_str = d_strings.element<string_view>(idx);
int32_t begin = 0;
int32_t end = beginning_only ? 1 // match only the beginning of the string;
: -1; // match anywhere in the string
return static_cast<bool>(prog.find<stack_size>(idx, d_str, begin, end));
}
};

//
std::unique_ptr<column> contains_util(
strings_column_view const& strings,
std::string const& pattern,
regex_flags const flags,
bool beginning_only = false,
rmm::cuda_stream_view stream = rmm::cuda_stream_default,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource())
{
auto strings_count = strings.size();
auto strings_column = column_device_view::create(strings.parent(), stream);
auto d_column = *strings_column;

// compile regex into device object
auto prog =
reprog_device::create(pattern, flags, get_character_flags_table(), strings_count, stream);
auto d_prog = *prog;

// create the output column
auto results = make_numeric_column(data_type{type_id::BOOL8},
strings_count,
cudf::detail::copy_bitmask(strings.parent(), stream, mr),
strings.null_count(),
stream,
mr);
auto d_results = results->mutable_view().data<bool>();
struct contains_dispatch_fn {
reprog_device d_prog;
bool const beginning_only;

// fill the output column
int regex_insts = d_prog.insts_counts();
if (regex_insts <= RX_SMALL_INSTS)
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
d_results,
contains_fn<RX_STACK_SMALL>{d_prog, d_column, beginning_only});
else if (regex_insts <= RX_MEDIUM_INSTS)
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
d_results,
contains_fn<RX_STACK_MEDIUM>{d_prog, d_column, beginning_only});
else if (regex_insts <= RX_LARGE_INSTS)
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
d_results,
contains_fn<RX_STACK_LARGE>{d_prog, d_column, beginning_only});
else
template <int stack_size>
std::unique_ptr<column> operator()(strings_column_view const& input,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
auto results = make_numeric_column(data_type{type_id::BOOL8},
input.size(),
cudf::detail::copy_bitmask(input.parent(), stream, mr),
input.null_count(),
stream,
mr);

auto const d_strings = column_device_view::create(input.parent(), stream);
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
d_results,
contains_fn<RX_STACK_ANY>{d_prog, d_column, beginning_only});

results->set_null_count(strings.null_count());
return results;
}
thrust::make_counting_iterator<size_type>(input.size()),
results->mutable_view().data<bool>(),
contains_fn<stack_size>{d_prog, *d_strings, beginning_only});
return results;
}
};

} // namespace

std::unique_ptr<column> contains_re(
strings_column_view const& strings,
strings_column_view const& input,
std::string const& pattern,
regex_flags const flags,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource())
{
return contains_util(strings, pattern, flags, false, stream, mr);
auto d_prog =
reprog_device::create(pattern, flags, get_character_flags_table(), input.size(), stream);

return regex_dispatcher(*d_prog, contains_dispatch_fn{*d_prog, false}, input, stream, mr);
}

std::unique_ptr<column> matches_re(
strings_column_view const& strings,
strings_column_view const& input,
std::string const& pattern,
regex_flags const flags,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource())
{
return contains_util(strings, pattern, flags, true, stream, mr);
auto d_prog =
reprog_device::create(pattern, flags, get_character_flags_table(), input.size(), stream);

return regex_dispatcher(*d_prog, contains_dispatch_fn{*d_prog, true}, input, stream, mr);
}

} // namespace detail
Expand Down Expand Up @@ -172,12 +143,12 @@ namespace {
template <int stack_size>
struct count_fn {
reprog_device prog;
column_device_view d_strings;
column_device_view const d_strings;

__device__ int32_t operator()(unsigned int idx)
{
if (d_strings.is_null(idx)) return 0;
string_view d_str = d_strings.element<string_view>(idx);
auto const d_str = d_strings.element<string_view>(idx);
auto const nchars = d_str.length();
int32_t find_count = 0;
int32_t begin = 0;
Expand All @@ -191,62 +162,45 @@ struct count_fn {
}
};

struct count_dispatch_fn {
reprog_device d_prog;

template <int stack_size>
std::unique_ptr<column> operator()(strings_column_view const& input,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
auto results = make_numeric_column(data_type{type_id::INT32},
input.size(),
cudf::detail::copy_bitmask(input.parent(), stream, mr),
input.null_count(),
stream,
mr);

auto const d_strings = column_device_view::create(input.parent(), stream);
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(input.size()),
results->mutable_view().data<int32_t>(),
count_fn<stack_size>{d_prog, *d_strings});
return results;
}
};

} // namespace

std::unique_ptr<column> count_re(
strings_column_view const& strings,
strings_column_view const& input,
std::string const& pattern,
regex_flags const flags,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource())
{
auto strings_count = strings.size();
auto strings_column = column_device_view::create(strings.parent(), stream);
auto d_column = *strings_column;

// compile regex into device object
auto prog =
reprog_device::create(pattern, flags, get_character_flags_table(), strings_count, stream);
auto d_prog = *prog;

// create the output column
auto results = make_numeric_column(data_type{type_id::INT32},
strings_count,
cudf::detail::copy_bitmask(strings.parent(), stream, mr),
strings.null_count(),
stream,
mr);
auto d_results = results->mutable_view().data<int32_t>();

// fill the output column
int regex_insts = d_prog.insts_counts();
if (regex_insts <= RX_SMALL_INSTS)
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
d_results,
count_fn<RX_STACK_SMALL>{d_prog, d_column});
else if (regex_insts <= RX_MEDIUM_INSTS)
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
d_results,
count_fn<RX_STACK_MEDIUM>{d_prog, d_column});
else if (regex_insts <= RX_LARGE_INSTS)
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
d_results,
count_fn<RX_STACK_LARGE>{d_prog, d_column});
else
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
d_results,
count_fn<RX_STACK_ANY>{d_prog, d_column});
auto d_prog =
reprog_device::create(pattern, flags, get_character_flags_table(), input.size(), stream);

results->set_null_count(strings.null_count());
return results;
return regex_dispatcher(*d_prog, count_dispatch_fn{*d_prog}, input, stream, mr);
}

} // namespace detail
Expand Down
48 changes: 23 additions & 25 deletions cpp/src/strings/count_matches.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

#include <strings/count_matches.hpp>
#include <strings/regex/dispatcher.hpp>
#include <strings/regex/regex.cuh>

#include <cudf/column/column_device_view.cuh>
Expand Down Expand Up @@ -54,6 +55,27 @@ struct count_matches_fn {
return count;
}
};

struct count_dispatch_fn {
reprog_device d_prog;

template <int stack_size>
std::unique_ptr<column> operator()(column_device_view const& d_strings,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
auto results = make_numeric_column(
data_type{type_id::INT32}, d_strings.size() + 1, mask_state::UNALLOCATED, stream, mr);

thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(d_strings.size()),
results->mutable_view().data<int32_t>(),
count_matches_fn<stack_size>{d_strings, d_prog});
return results;
}
};

} // namespace

/**
Expand All @@ -71,31 +93,7 @@ std::unique_ptr<column> count_matches(column_device_view const& d_strings,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
// Create output column
auto counts = make_numeric_column(
data_type{type_id::INT32}, d_strings.size() + 1, mask_state::UNALLOCATED, stream, mr);
auto d_counts = counts->mutable_view().data<offset_type>();

auto begin = thrust::make_counting_iterator<size_type>(0);
auto end = thrust::make_counting_iterator<size_type>(d_strings.size());

// Count matches
auto const regex_insts = d_prog.insts_counts();
if (regex_insts <= RX_SMALL_INSTS) {
count_matches_fn<RX_STACK_SMALL> fn{d_strings, d_prog};
thrust::transform(rmm::exec_policy(stream), begin, end, d_counts, fn);
} else if (regex_insts <= RX_MEDIUM_INSTS) {
count_matches_fn<RX_STACK_MEDIUM> fn{d_strings, d_prog};
thrust::transform(rmm::exec_policy(stream), begin, end, d_counts, fn);
} else if (regex_insts <= RX_LARGE_INSTS) {
count_matches_fn<RX_STACK_LARGE> fn{d_strings, d_prog};
thrust::transform(rmm::exec_policy(stream), begin, end, d_counts, fn);
} else {
count_matches_fn<RX_STACK_ANY> fn{d_strings, d_prog};
thrust::transform(rmm::exec_policy(stream), begin, end, d_counts, fn);
}

return counts;
return regex_dispatcher(d_prog, count_dispatch_fn{d_prog}, d_strings, stream, mr);
}

} // namespace detail
Expand Down
Loading

0 comments on commit 1217f24

Please sign in to comment.