Skip to content

Commit

Permalink
Refactor cudf::strings::count_re API to use count_matches utility (#1…
Browse files Browse the repository at this point in the history
…0580)

Refactors the `cudf::strings::detail::count_re` function to reuse the `cudf::strings::detail::count_matches` utility that was created for findall, extractall, and split. The kernel code was identical; the only real difference was the size of the output column. The output size was therefore added as a parameter to `count_matches`, and the callers were updated accordingly.

Authors:
  - David Wendt (https://github.com/davidwendt)

Approvers:
  - Bradley Dice (https://github.com/bdice)
  - Vyas Ramasubramani (https://github.com/vyasr)

URL: #10580
  • Loading branch information
davidwendt authored Apr 11, 2022
1 parent bf4ffc9 commit df6bd3c
Show file tree
Hide file tree
Showing 7 changed files with 40 additions and 91 deletions.
93 changes: 21 additions & 72 deletions cpp/src/strings/contains.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* limitations under the License.
*/

#include <strings/count_matches.hpp>
#include <strings/regex/dispatcher.hpp>
#include <strings/regex/regex.cuh>
#include <strings/utilities.hpp>
Expand Down Expand Up @@ -114,6 +115,26 @@ std::unique_ptr<column> matches_re(
return regex_dispatcher(*d_prog, contains_dispatch_fn{*d_prog, true}, input, stream, mr);
}

/**
 * @brief Counts the regex matches in each row of a strings column.
 *
 * The pattern is compiled into a device program, the per-row counting is
 * delegated to `count_matches`, and the input's null mask (if any) is copied
 * onto the result.
 *
 * @param input   Strings column to search.
 * @param pattern Regex pattern to compile.
 * @param flags   Regex flags forwarded to the pattern compiler.
 * @param stream  CUDA stream passed to every call made here.
 * @param mr      Memory resource passed to `count_matches` and to the bitmask copy.
 * @return Column with one row per input row, as produced by `count_matches`.
 */
std::unique_ptr<column> count_re(strings_column_view const& input,
                                 std::string const& pattern,
                                 regex_flags const flags,
                                 rmm::cuda_stream_view stream,
                                 rmm::mr::device_memory_resource* mr)
{
  auto const strings_count = input.size();

  // build the device-side regex program for this pattern
  auto d_prog =
    reprog_device::create(pattern, flags, get_character_flags_table(), strings_count, stream);

  auto const d_strings = column_device_view::create(input.parent(), stream);

  // one output row per input row
  auto counts = count_matches(*d_strings, *d_prog, strings_count, stream, mr);

  // nothing more to do when the input carries no nulls
  if (!input.has_nulls()) { return counts; }

  // otherwise the output inherits the input's null mask and null count
  counts->set_null_mask(cudf::detail::copy_bitmask(input.parent(), stream, mr),
                        input.null_count());
  return counts;
}

} // namespace detail

// external APIs
Expand All @@ -136,78 +157,6 @@ std::unique_ptr<column> matches_re(strings_column_view const& strings,
return detail::matches_re(strings, pattern, flags, rmm::cuda_stream_default, mr);
}

namespace detail {
namespace {
/**
 * @brief Device functor returning how many times the regex matches in one string.
 *
 * Null rows yield a count of zero. The search restarts after each match until
 * the search position reaches the string's length or no further match is found.
 *
 * @tparam stack_size Template argument forwarded to `reprog_device::find`.
 */
template <int stack_size>
struct count_fn {
  reprog_device prog;
  column_device_view const d_strings;

  __device__ int32_t operator()(unsigned int idx)
  {
    if (d_strings.is_null(idx)) return 0;

    auto const d_str  = d_strings.element<string_view>(idx);
    auto const nchars = d_str.length();

    int32_t matches = 0;
    int32_t from    = 0;
    while (from < nchars) {
      // NOTE(review): find() presumably narrows from/to to the matched span -- confirm in regex.cuh
      auto to = static_cast<int32_t>(nchars);
      if (prog.find<stack_size>(idx, d_str, from, to) <= 0) break;
      ++matches;
      // always advance by at least one position so the loop makes progress
      if (to > from) {
        from = to;
      } else {
        ++from;
      }
    }
    return matches;
  }
};

/**
 * @brief Functor handed to `regex_dispatcher` to build the match-count column.
 *
 * Allocates an INT32 column with one row per input string, carrying a copy of
 * the input's null mask, then fills it by running `count_fn` over every row index.
 */
struct count_dispatch_fn {
  reprog_device d_prog;

  template <int stack_size>
  std::unique_ptr<column> operator()(strings_column_view const& input,
                                     rmm::cuda_stream_view stream,
                                     rmm::mr::device_memory_resource* mr)
  {
    // output column: same row count and null mask as the input
    auto results = make_numeric_column(data_type{type_id::INT32},
                                       input.size(),
                                       cudf::detail::copy_bitmask(input.parent(), stream, mr),
                                       input.null_count(),
                                       stream,
                                       mr);

    auto const d_strings = column_device_view::create(input.parent(), stream);

    // row indices [0, input.size())
    auto const first_row = thrust::make_counting_iterator<size_type>(0);
    auto const last_row  = thrust::make_counting_iterator<size_type>(input.size());
    auto d_counts        = results->mutable_view().data<int32_t>();

    thrust::transform(
      rmm::exec_policy(stream), first_row, last_row, d_counts, count_fn<stack_size>{d_prog, *d_strings});
    return results;
  }
};

} // namespace

/**
 * @brief Counts regex matches per row by dispatching `count_dispatch_fn`.
 *
 * @param input   Strings column to search.
 * @param pattern Regex pattern to compile.
 * @param flags   Regex flags forwarded to the pattern compiler.
 * @param stream  CUDA stream passed to the compile and dispatch calls.
 * @param mr      Memory resource forwarded to the dispatched functor.
 * @return Whatever column the dispatched `count_dispatch_fn` produces.
 */
std::unique_ptr<column> count_re(
  strings_column_view const& input,
  std::string const& pattern,
  regex_flags const flags,
  rmm::cuda_stream_view stream,
  rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource())
{
  auto const strings_count = input.size();

  // build the device-side regex program for this pattern
  auto d_prog =
    reprog_device::create(pattern, flags, get_character_flags_table(), strings_count, stream);

  // the dispatcher invokes the functor with a stack size chosen for this program
  count_dispatch_fn counter{*d_prog};
  return regex_dispatcher(*d_prog, counter, input, stream, mr);
}

} // namespace detail

// external API

std::unique_ptr<column> count_re(strings_column_view const& strings,
std::string const& pattern,
regex_flags const flags,
Expand Down
26 changes: 12 additions & 14 deletions cpp/src/strings/count_matches.cu
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,16 @@ struct count_matches_fn {
__device__ size_type operator()(size_type idx)
{
if (d_strings.is_null(idx)) { return 0; }
size_type count = 0;
auto const d_str = d_strings.element<string_view>(idx);
size_type count = 0;
auto const d_str = d_strings.element<string_view>(idx);
auto const nchars = d_str.length();

int32_t begin = 0;
int32_t end = d_str.length();
int32_t end = nchars;
while ((begin < end) && (prog.find<stack_size>(idx, d_str, begin, end) > 0)) {
++count;
begin = end + (begin == end);
end = d_str.length();
end = nchars;
}
return count;
}
Expand All @@ -62,11 +63,14 @@ struct count_dispatch_fn {

template <int stack_size>
std::unique_ptr<column> operator()(column_device_view const& d_strings,
size_type output_size,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
assert(output_size >= d_strings.size() and "Unexpected output size");

auto results = make_numeric_column(
data_type{type_id::INT32}, d_strings.size() + 1, mask_state::UNALLOCATED, stream, mr);
data_type{type_id::INT32}, output_size, mask_state::UNALLOCATED, stream, mr);

thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
Expand All @@ -80,21 +84,15 @@ struct count_dispatch_fn {
} // namespace

/**
* @brief Returns a column of regex match counts for each string in the given column.
*
* A null entry will result in a zero count for that output row.
*
* @param d_strings Device view of the input strings column.
* @param d_prog Regex instance to evaluate on each string.
* @param stream CUDA stream used for device memory operations and kernel launches.
* @param mr Device memory resource used to allocate the returned column's device memory.
* @copydoc cudf::strings::detail::count_matches
*/
std::unique_ptr<column> count_matches(column_device_view const& d_strings,
reprog_device const& d_prog,
size_type output_size,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
return regex_dispatcher(d_prog, count_dispatch_fn{d_prog}, d_strings, stream, mr);
return regex_dispatcher(d_prog, count_dispatch_fn{d_prog}, d_strings, output_size, stream, mr);
}

} // namespace detail
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/strings/count_matches.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,14 @@ class reprog_device;
*
* @param d_strings Device view of the input strings column.
* @param d_prog Regex instance to evaluate on each string.
* @param output_size Number of rows for the output column.
* @param stream CUDA stream used for device memory operations and kernel launches.
* @param mr Device memory resource used to allocate the returned column's device memory.
*/
std::unique_ptr<column> count_matches(
column_device_view const& d_strings,
reprog_device const& d_prog,
size_type output_size,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

Expand Down
2 changes: 1 addition & 1 deletion cpp/src/strings/extract/extract_all.cu
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ std::unique_ptr<column> extract_all_record(

// Get the match counts for each string.
// This column will become the output lists child offsets column.
auto offsets = count_matches(*d_strings, *d_prog, stream, mr);
auto offsets = count_matches(*d_strings, *d_prog, strings_count + 1, stream, mr);
auto d_offsets = offsets->mutable_view().data<offset_type>();

// Compute null output rows
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/strings/search/findall.cu
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ std::unique_ptr<table> findall(strings_column_view const& input,
reprog_device::create(pattern, flags, get_character_flags_table(), strings_count, stream);

auto const d_strings = column_device_view::create(input.parent(), stream);
auto find_counts = count_matches(*d_strings, *d_prog, stream);
auto find_counts = count_matches(*d_strings, *d_prog, strings_count + 1, stream);
auto d_find_counts = find_counts->view().data<size_type>();

size_type const columns_count = thrust::reduce(
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/strings/search/findall_record.cu
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ std::unique_ptr<column> findall_record(
reprog_device::create(pattern, flags, get_character_flags_table(), strings_count, stream);

// Create lists offsets column
auto offsets = count_matches(*d_strings, *d_prog, stream, mr);
auto offsets = count_matches(*d_strings, *d_prog, strings_count + 1, stream, mr);
auto d_offsets = offsets->mutable_view().data<offset_type>();

// Convert counts into offsets
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/strings/split/split_re.cu
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ std::unique_ptr<table> split_re(strings_column_view const& input,
auto d_strings = column_device_view::create(input.parent(), stream);

// count the number of delimiters matched in each string
auto offsets = count_matches(*d_strings, *d_prog, stream, rmm::mr::get_current_device_resource());
auto offsets = count_matches(*d_strings, *d_prog, strings_count + 1, stream);
auto offsets_view = offsets->mutable_view();
auto d_offsets = offsets_view.data<offset_type>();

Expand Down Expand Up @@ -287,7 +287,7 @@ std::unique_ptr<column> split_record_re(strings_column_view const& input,
auto d_strings = column_device_view::create(input.parent(), stream);

// count the number of delimiters matched in each string
auto offsets = count_matches(*d_strings, *d_prog, stream, mr);
auto offsets = count_matches(*d_strings, *d_prog, strings_count + 1, stream, mr);
auto offsets_view = offsets->mutable_view();

// get the split tokens from the input column; this also converts the counts into offsets
Expand Down

0 comments on commit df6bd3c

Please sign in to comment.