From 1217f24d97c3559e15293040ebda7914f00cb25e Mon Sep 17 00:00:00 2001 From: David Wendt <45795991+davidwendt@users.noreply.github.com> Date: Wed, 2 Mar 2022 10:52:16 -0500 Subject: [PATCH] Create a dispatcher for invoking regex kernel functions (#10349) Closes #10138 Refactor the various regex function calls to use a dispatcher instead of if-else clauses. Each regex call currently requires different stack sizes (and later launch parameters). Changes to these parameters are sometimes difficult to coordinate since they usually need to be duplicated across about 10 APIs that are currently using regex calls. The new `regex_dispatcher` makes calling these much cleaner and easier to maintain. This will be helpful when experimenting with possibly using different launch parameters. No functions have changed. Mostly this is a refactoring and cleanup effort. The `findall.cu` was also recoded to use the new `count_matches` utility. Authors: - David Wendt (https://github.com/davidwendt) Approvers: - Vyas Ramasubramani (https://github.com/vyasr) - Christopher Harris (https://github.com/cwharris) - Bradley Dice (https://github.com/bdice) - Ram (Ramakrishna Prabhu) (https://github.com/rgsl888prabhu) URL: https://github.com/rapidsai/cudf/pull/10349 --- cpp/src/strings/contains.cu | 200 +++++++++-------------- cpp/src/strings/count_matches.cu | 48 +++--- cpp/src/strings/extract/extract.cu | 62 +++---- cpp/src/strings/extract/extract_all.cu | 54 +++--- cpp/src/strings/regex/dispatcher.hpp | 59 +++++++ cpp/src/strings/replace/backref_re.cu | 87 +++++----- cpp/src/strings/replace/multi_re.cu | 114 ++++++------- cpp/src/strings/replace/replace_re.cu | 83 ++++------ cpp/src/strings/search/findall.cu | 155 ++++++------------ cpp/src/strings/search/findall_record.cu | 54 +++--- cpp/src/strings/split/split_re.cu | 43 +++-- 11 files changed, 452 insertions(+), 507 deletions(-) create mode 100644 cpp/src/strings/regex/dispatcher.hpp diff --git a/cpp/src/strings/contains.cu b/cpp/src/strings/contains.cu index efdee65c1f6..23bc5cf2dfe 100644 --- a/cpp/src/strings/contains.cu +++ b/cpp/src/strings/contains.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,6 +14,10 @@ * limitations under the License. */ +#include +#include +#include + #include #include #include @@ -23,123 +27,90 @@ #include #include #include -#include -#include #include #include +#include + namespace cudf { namespace strings { namespace detail { + namespace { /** * @brief This functor handles both contains_re and match_re to minimize the number * of regex calls to find() to be inlined greatly reducing compile time. - * - * The stack is used to keep progress on evaluating the regex instructions on each string. - * So the size of the stack is in proportion to the number of instructions in the given regex - * pattern. - * - * There are three call types based on the number of regex instructions in the given pattern. - * Small to medium instruction lengths can use the stack effectively though smaller executes faster. - * Longer patterns require global memory. */ template struct contains_fn { reprog_device prog; - column_device_view d_strings; - bool bmatch{false}; // do not make this a template parameter to keep compile times down + column_device_view const d_strings; + bool const beginning_only; // do not make this a template parameter to keep compile times down __device__ bool operator()(size_type idx) { if (d_strings.is_null(idx)) return false; - string_view d_str = d_strings.element(idx); - int32_t begin = 0; - int32_t end = bmatch ? 1 // match only the beginning of the string; - : -1; // this handles empty strings too + auto const d_str = d_strings.element(idx); + int32_t begin = 0; + int32_t end = beginning_only ? 1 // match only the beginning of the string; + : -1; // match anywhere in the string return static_cast(prog.find(idx, d_str, begin, end)); } }; -// -std::unique_ptr contains_util( - strings_column_view const& strings, - std::string const& pattern, - regex_flags const flags, - bool beginning_only = false, - rmm::cuda_stream_view stream = rmm::cuda_stream_default, - rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) -{ - auto strings_count = strings.size(); - auto strings_column = column_device_view::create(strings.parent(), stream); - auto d_column = *strings_column; - - // compile regex into device object - auto prog = - reprog_device::create(pattern, flags, get_character_flags_table(), strings_count, stream); - auto d_prog = *prog; - - // create the output column - auto results = make_numeric_column(data_type{type_id::BOOL8}, - strings_count, - cudf::detail::copy_bitmask(strings.parent(), stream, mr), - strings.null_count(), - stream, - mr); - auto d_results = results->mutable_view().data(); +struct contains_dispatch_fn { + reprog_device d_prog; + bool const beginning_only; - // fill the output column - int regex_insts = d_prog.insts_counts(); - if (regex_insts <= RX_SMALL_INSTS) - thrust::transform(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(strings_count), - d_results, - contains_fn{d_prog, d_column, beginning_only}); - else if (regex_insts <= RX_MEDIUM_INSTS) - thrust::transform(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(strings_count), - d_results, - contains_fn{d_prog, d_column, beginning_only}); - else if (regex_insts <= RX_LARGE_INSTS) - thrust::transform(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(strings_count), - d_results, - contains_fn{d_prog, d_column, beginning_only}); - else + template + std::unique_ptr operator()(strings_column_view const& input, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) + { + auto results = make_numeric_column(data_type{type_id::BOOL8}, + input.size(), + cudf::detail::copy_bitmask(input.parent(), stream, mr), + input.null_count(), + stream, + mr); + + auto const d_strings = column_device_view::create(input.parent(), stream); thrust::transform(rmm::exec_policy(stream), thrust::make_counting_iterator(0), - thrust::make_counting_iterator(strings_count), - d_results, - contains_fn{d_prog, d_column, beginning_only}); - - results->set_null_count(strings.null_count()); - return results; -} + thrust::make_counting_iterator(input.size()), + results->mutable_view().data(), + contains_fn{d_prog, *d_strings, beginning_only}); + return results; + } +}; } // namespace std::unique_ptr contains_re( - strings_column_view const& strings, + strings_column_view const& input, std::string const& pattern, regex_flags const flags, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) { - return contains_util(strings, pattern, flags, false, stream, mr); + auto d_prog = + reprog_device::create(pattern, flags, get_character_flags_table(), input.size(), stream); + + return regex_dispatcher(*d_prog, contains_dispatch_fn{*d_prog, false}, input, stream, mr); } std::unique_ptr matches_re( - strings_column_view const& strings, + strings_column_view const& input, std::string const& pattern, regex_flags const flags, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) { - return contains_util(strings, pattern, flags, true, stream, mr); + auto d_prog = + reprog_device::create(pattern, flags, get_character_flags_table(), input.size(), stream); + + return regex_dispatcher(*d_prog, contains_dispatch_fn{*d_prog, true}, input, stream, mr); } } // namespace detail @@ -172,12 +143,12 @@ namespace { template struct count_fn { reprog_device prog; - column_device_view d_strings; + column_device_view const d_strings; __device__ int32_t operator()(unsigned int idx) { if (d_strings.is_null(idx)) return 0; - string_view d_str = d_strings.element(idx); + auto const d_str = d_strings.element(idx); auto const nchars = d_str.length(); int32_t find_count = 0; int32_t begin = 0; @@ -191,62 +162,45 @@ struct count_fn { } }; +struct count_dispatch_fn { + reprog_device d_prog; + + template + std::unique_ptr operator()(strings_column_view const& input, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) + { + auto results = make_numeric_column(data_type{type_id::INT32}, + input.size(), + cudf::detail::copy_bitmask(input.parent(), stream, mr), + input.null_count(), + stream, + mr); + + auto const d_strings = column_device_view::create(input.parent(), stream); + thrust::transform(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(input.size()), + results->mutable_view().data(), + count_fn{d_prog, *d_strings}); + return results; + } +}; + } // namespace std::unique_ptr count_re( - strings_column_view const& strings, + strings_column_view const& input, std::string const& pattern, regex_flags const flags, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) { - auto strings_count = strings.size(); - auto strings_column = column_device_view::create(strings.parent(), stream); - auto d_column = *strings_column; - // compile regex into device object - auto prog = - reprog_device::create(pattern, flags, get_character_flags_table(), strings_count, stream); - auto d_prog = *prog; - - // create the output column - auto results = make_numeric_column(data_type{type_id::INT32}, - strings_count, - cudf::detail::copy_bitmask(strings.parent(), stream, mr), - strings.null_count(), - stream, - mr); - auto d_results = results->mutable_view().data(); - - // fill the output column - int regex_insts = d_prog.insts_counts(); - if (regex_insts <= RX_SMALL_INSTS) - thrust::transform(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(strings_count), - d_results, - count_fn{d_prog, d_column}); - else if (regex_insts <= RX_MEDIUM_INSTS) - thrust::transform(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(strings_count), - d_results, - count_fn{d_prog, d_column}); - else if (regex_insts <= RX_LARGE_INSTS) - thrust::transform(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(strings_count), - d_results, - count_fn{d_prog, d_column}); - else - thrust::transform(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(strings_count), - d_results, - count_fn{d_prog, d_column}); + auto d_prog = + reprog_device::create(pattern, flags, get_character_flags_table(), input.size(), stream); - results->set_null_count(strings.null_count()); - return results; + return regex_dispatcher(*d_prog, count_dispatch_fn{*d_prog}, input, stream, mr); } } // namespace detail diff --git a/cpp/src/strings/count_matches.cu b/cpp/src/strings/count_matches.cu index d0a6825666b..ae996cafd2c 100644 --- a/cpp/src/strings/count_matches.cu +++ b/cpp/src/strings/count_matches.cu @@ -15,6 +15,7 @@ */ #include +#include #include #include @@ -54,6 +55,27 @@ struct count_matches_fn { return count; } }; + +struct count_dispatch_fn { + reprog_device d_prog; + + template + std::unique_ptr operator()(column_device_view const& d_strings, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) + { + auto results = make_numeric_column( + data_type{type_id::INT32}, d_strings.size() + 1, mask_state::UNALLOCATED, stream, mr); + + thrust::transform(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(d_strings.size()), + results->mutable_view().data(), + count_matches_fn{d_strings, d_prog}); + return results; + } +}; + } // namespace /** @@ -71,31 +93,7 @@ std::unique_ptr count_matches(column_device_view const& d_strings, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { - // Create output column - auto counts = make_numeric_column( - data_type{type_id::INT32}, d_strings.size() + 1, mask_state::UNALLOCATED, stream, mr); - auto d_counts = counts->mutable_view().data(); - - auto begin = thrust::make_counting_iterator(0); - auto end = thrust::make_counting_iterator(d_strings.size()); - - // Count matches - auto const regex_insts = d_prog.insts_counts(); - if (regex_insts <= RX_SMALL_INSTS) { - count_matches_fn fn{d_strings, d_prog}; - thrust::transform(rmm::exec_policy(stream), begin, end, d_counts, fn); - } else if (regex_insts <= RX_MEDIUM_INSTS) { - count_matches_fn fn{d_strings, d_prog}; - thrust::transform(rmm::exec_policy(stream), begin, end, d_counts, fn); - } else if (regex_insts <= RX_LARGE_INSTS) { - count_matches_fn fn{d_strings, d_prog}; - thrust::transform(rmm::exec_policy(stream), begin, end, d_counts, fn); - } else { - count_matches_fn fn{d_strings, d_prog}; - thrust::transform(rmm::exec_policy(stream), begin, end, d_counts, fn); - } - - return counts; + return regex_dispatcher(d_prog, count_dispatch_fn{d_prog}, d_strings, stream, mr); } } // namespace detail diff --git a/cpp/src/strings/extract/extract.cu b/cpp/src/strings/extract/extract.cu index a67af9442f0..7394cdac6bb 100644 --- a/cpp/src/strings/extract/extract.cu +++ b/cpp/src/strings/extract/extract.cu @@ -14,6 +14,7 @@ * limitations under the License. */ +#include #include #include @@ -77,53 +78,44 @@ struct extract_fn { thrust::fill(thrust::seq, d_output.begin(), d_output.end(), string_index_pair{nullptr, 0}); } }; + +struct extract_dispatch_fn { + reprog_device d_prog; + + template + void operator()(column_device_view const& d_strings, + cudf::detail::device_2dspan& d_indices, + rmm::cuda_stream_view stream) + { + thrust::for_each_n(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + d_strings.size(), + extract_fn{d_prog, d_strings, d_indices}); + } +}; } // namespace // std::unique_ptr extract( - strings_column_view const& strings, + strings_column_view const& input, std::string const& pattern, regex_flags const flags, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) { - auto const strings_count = strings.size(); - auto const strings_column = column_device_view::create(strings.parent(), stream); - auto const d_strings = *strings_column; - // compile regex into device object - auto prog = - reprog_device::create(pattern, flags, get_character_flags_table(), strings_count, stream); - auto d_prog = *prog; - // extract should include groups - auto const groups = d_prog.group_counts(); + auto d_prog = + reprog_device::create(pattern, flags, get_character_flags_table(), input.size(), stream); + + auto const groups = d_prog->group_counts(); CUDF_EXPECTS(groups > 0, "Group indicators not found in regex pattern"); - rmm::device_uvector indices(strings_count * groups, stream); - cudf::detail::device_2dspan d_indices(indices.data(), strings_count, groups); + auto indices = rmm::device_uvector(input.size() * groups, stream); + auto d_indices = + cudf::detail::device_2dspan(indices.data(), input.size(), groups); - auto const regex_insts = d_prog.insts_counts(); - if (regex_insts <= RX_SMALL_INSTS) { - thrust::for_each_n(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - strings_count, - extract_fn{d_prog, d_strings, d_indices}); - } else if (regex_insts <= RX_MEDIUM_INSTS) { - thrust::for_each_n(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - strings_count, - extract_fn{d_prog, d_strings, d_indices}); - } else if (regex_insts <= RX_LARGE_INSTS) { - thrust::for_each_n(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - strings_count, - extract_fn{d_prog, d_strings, d_indices}); - } else { - thrust::for_each_n(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - strings_count, - extract_fn{d_prog, d_strings, d_indices}); - } + auto const d_strings = column_device_view::create(input.parent(), stream); + regex_dispatcher(*d_prog, extract_dispatch_fn{*d_prog}, *d_strings, d_indices, stream); // build a result column for each group std::vector> results(groups); @@ -135,7 +127,7 @@ std::unique_ptr
extract( 0, [column_index, groups] __device__(size_type idx) { return (idx * groups) + column_index; })); - return make_strings_column(indices_itr, indices_itr + strings_count, stream, mr); + return make_strings_column(indices_itr, indices_itr + input.size(), stream, mr); }; std::transform(thrust::make_counting_iterator(0), diff --git a/cpp/src/strings/extract/extract_all.cu b/cpp/src/strings/extract/extract_all.cu index e27dccb9338..1f1474c777b 100644 --- a/cpp/src/strings/extract/extract_all.cu +++ b/cpp/src/strings/extract/extract_all.cu @@ -15,6 +15,7 @@ */ #include +#include #include #include @@ -86,6 +87,28 @@ struct extract_fn { } } }; + +struct extract_dispatch_fn { + reprog_device d_prog; + + template + std::unique_ptr operator()(column_device_view const& d_strings, + size_type total_groups, + offset_type const* d_offsets, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) + { + rmm::device_uvector indices(total_groups, stream); + + thrust::for_each_n(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + d_strings.size(), + extract_fn{d_strings, d_prog, d_offsets, indices.data()}); + + return make_strings_column(indices.begin(), indices.end(), stream, mr); + } +}; + } // namespace /** @@ -94,14 +117,14 @@ struct extract_fn { * @param stream CUDA stream used for device memory operations and kernel launches. */ std::unique_ptr extract_all_record( - strings_column_view const& strings, + strings_column_view const& input, std::string const& pattern, regex_flags const flags, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) { - auto const strings_count = strings.size(); - auto const d_strings = column_device_view::create(strings.parent(), stream); + auto const strings_count = input.size(); + auto const d_strings = column_device_view::create(input.parent(), stream); // Compile regex into device object. auto d_prog = @@ -143,29 +166,8 @@ std::unique_ptr extract_all_record( auto const total_groups = cudf::detail::get_value(offsets->view(), strings_count, stream); - // Create an indices vector with the total number of groups that will be extracted. - rmm::device_uvector indices(total_groups, stream); - auto d_indices = indices.data(); - auto begin = thrust::make_counting_iterator(0); - - // Call the extract functor to fill in the indices vector. - auto const regex_insts = d_prog->insts_counts(); - if (regex_insts <= RX_SMALL_INSTS) { - extract_fn fn{*d_strings, *d_prog, d_offsets, d_indices}; - thrust::for_each_n(rmm::exec_policy(stream), begin, strings_count, fn); - } else if (regex_insts <= RX_MEDIUM_INSTS) { - extract_fn fn{*d_strings, *d_prog, d_offsets, d_indices}; - thrust::for_each_n(rmm::exec_policy(stream), begin, strings_count, fn); - } else if (regex_insts <= RX_LARGE_INSTS) { - extract_fn fn{*d_strings, *d_prog, d_offsets, d_indices}; - thrust::for_each_n(rmm::exec_policy(stream), begin, strings_count, fn); - } else { - extract_fn fn{*d_strings, *d_prog, d_offsets, d_indices}; - thrust::for_each_n(rmm::exec_policy(stream), begin, strings_count, fn); - } - - // Build the child strings column from the indices. - auto strings_output = make_strings_column(indices.begin(), indices.end(), stream, mr); + auto strings_output = regex_dispatcher( + *d_prog, extract_dispatch_fn{*d_prog}, *d_strings, total_groups, d_offsets, stream, mr); // Build the lists column from the offsets and the strings. return make_lists_column(strings_count, diff --git a/cpp/src/strings/regex/dispatcher.hpp b/cpp/src/strings/regex/dispatcher.hpp new file mode 100644 index 00000000000..9ff51d1c979 --- /dev/null +++ b/cpp/src/strings/regex/dispatcher.hpp @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace cudf { +namespace strings { +namespace detail { + +/** + * The stack is used to keep progress (state) on evaluating the regex instructions on each string. + * So the size of the stack is in proportion to the number of instructions in the given regex + * pattern. + * + * There are four call types based on the number of regex instructions in the given pattern. + * Small, medium, and large instruction counts can use the stack effectively. + * Smaller stack sizes execute faster. + * + * Patterns with instruction counts bigger than large use global memory rather than the stack + * for managing the evaluation state data. + * + * @tparam Functor The functor to invoke with stack size templated value. + * @tparam Ts Parameter types for the functor call. + */ +template +constexpr decltype(auto) regex_dispatcher(reprog_device d_prog, Functor f, Ts&&... args) +{ + auto const num_regex_insts = d_prog.insts_counts(); + if (num_regex_insts <= RX_SMALL_INSTS) { + return f.template operator()(std::forward(args)...); + } + if (num_regex_insts <= RX_MEDIUM_INSTS) { + return f.template operator()(std::forward(args)...); + } + if (num_regex_insts <= RX_LARGE_INSTS) { + return f.template operator()(std::forward(args)...); + } + + return f.template operator()(std::forward(args)...); +} + +} // namespace detail +} // namespace strings +} // namespace cudf diff --git a/cpp/src/strings/replace/backref_re.cu b/cpp/src/strings/replace/backref_re.cu index ff86d7aa552..27e0bd4fac9 100644 --- a/cpp/src/strings/replace/backref_re.cu +++ b/cpp/src/strings/replace/backref_re.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ #include "backref_re.cuh" +#include #include #include @@ -95,27 +96,54 @@ std::pair> parse_backrefs(std::string con return {rtn, backrefs}; } +template +struct replace_dispatch_fn { + reprog_device d_prog; + + template + std::unique_ptr operator()(strings_column_view const& input, + string_view const& d_repl_template, + Iterator backrefs_begin, + Iterator backrefs_end, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) + { + auto const d_strings = column_device_view::create(input.parent(), stream); + + auto children = make_strings_children( + backrefs_fn{ + *d_strings, d_prog, d_repl_template, backrefs_begin, backrefs_end}, + input.size(), + stream, + mr); + + return make_strings_column(input.size(), + std::move(children.first), + std::move(children.second), + input.null_count(), + cudf::detail::copy_bitmask(input.parent(), stream, mr)); + } +}; + } // namespace // std::unique_ptr replace_with_backrefs( - strings_column_view const& strings, + strings_column_view const& input, std::string const& pattern, std::string const& replacement, regex_flags const flags, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) { - if (strings.is_empty()) return make_empty_column(type_id::STRING); + if (input.is_empty()) return make_empty_column(type_id::STRING); CUDF_EXPECTS(!pattern.empty(), "Parameter pattern must not be empty"); CUDF_EXPECTS(!replacement.empty(), "Parameter replacement must not be empty"); - auto d_strings = column_device_view::create(strings.parent(), stream); // compile regex into device object auto d_prog = - reprog_device::create(pattern, flags, get_character_flags_table(), strings.size(), stream); - auto const regex_insts = d_prog->insts_counts(); + reprog_device::create(pattern, flags, get_character_flags_table(), input.size(), stream); // parse the repl string for back-ref indicators auto const parse_result = parse_backrefs(replacement); @@ -125,45 +153,14 @@ std::unique_ptr replace_with_backrefs( string_view const d_repl_template = repl_scalar.value(); using BackRefIterator = decltype(backrefs.begin()); - - // create child columns - auto [offsets, chars] = [&] { - if (regex_insts <= RX_SMALL_INSTS) { - return make_strings_children( - backrefs_fn{ - *d_strings, *d_prog, d_repl_template, backrefs.begin(), backrefs.end()}, - strings.size(), - stream, - mr); - } else if (regex_insts <= RX_MEDIUM_INSTS) { - return make_strings_children( - backrefs_fn{ - *d_strings, *d_prog, d_repl_template, backrefs.begin(), backrefs.end()}, - strings.size(), - stream, - mr); - } else if (regex_insts <= RX_LARGE_INSTS) { - return make_strings_children( - backrefs_fn{ - *d_strings, *d_prog, d_repl_template, backrefs.begin(), backrefs.end()}, - strings.size(), - stream, - mr); - } else { - return make_strings_children( - backrefs_fn{ - *d_strings, *d_prog, d_repl_template, backrefs.begin(), backrefs.end()}, - strings.size(), - stream, - mr); - } - }(); - - return make_strings_column(strings.size(), - std::move(offsets), - std::move(chars), - strings.null_count(), - cudf::detail::copy_bitmask(strings.parent(), stream, mr)); + return regex_dispatcher(*d_prog, + replace_dispatch_fn{*d_prog}, + input, + d_repl_template, + backrefs.begin(), + backrefs.end(), + stream, + mr); } } // namespace detail diff --git a/cpp/src/strings/replace/multi_re.cu b/cpp/src/strings/replace/multi_re.cu index 2b5380b76dd..22f6d2cba39 100644 --- a/cpp/src/strings/replace/multi_re.cu +++ b/cpp/src/strings/replace/multi_re.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,6 +14,7 @@ * limitations under the License. */ +#include #include #include @@ -30,6 +31,8 @@ #include +#include + namespace cudf { namespace strings { namespace detail { @@ -40,16 +43,6 @@ using found_range = thrust::pair; /** * @brief This functor handles replacing strings by applying the compiled regex patterns * and inserting the corresponding new string within the matched range of characters. - * - * The logic includes computing the size of each string and also writing the output. - * - * The stack is used to keep progress on evaluating the regex instructions on each string. - * So the size of the stack is in proportion to the number of instructions in the given regex - * pattern. - * - * There are three call types based on the number of regex instructions in the given pattern. - * Small to medium instruction lengths can use the stack effectively though smaller executes faster. - * Longer patterns require global memory. Shorter patterns are common in data cleaning. */ template struct replace_multi_regex_fn { @@ -127,69 +120,76 @@ struct replace_multi_regex_fn { } }; +struct replace_dispatch_fn { + template + std::unique_ptr operator()(strings_column_view const& input, + device_span d_progs, + strings_column_view const& replacements, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) + { + auto const d_strings = column_device_view::create(input.parent(), stream); + auto const d_repls = column_device_view::create(replacements.parent(), stream); + + auto found_ranges = rmm::device_uvector(d_progs.size() * input.size(), stream); + + auto children = make_strings_children( + replace_multi_regex_fn{*d_strings, d_progs, found_ranges.data(), *d_repls}, + input.size(), + stream, + mr); + + return make_strings_column(input.size(), + std::move(children.first), + std::move(children.second), + input.null_count(), + cudf::detail::copy_bitmask(input.parent(), stream, mr)); + } +}; + } // namespace std::unique_ptr replace_re( - strings_column_view const& strings, + strings_column_view const& input, std::vector const& patterns, strings_column_view const& replacements, regex_flags const flags, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) { - auto strings_count = strings.size(); - if (strings_count == 0) return make_empty_column(type_id::STRING); - if (patterns.empty()) // no patterns; just return a copy - return std::make_unique(strings.parent(), stream, mr); + if (input.is_empty()) { return make_empty_column(type_id::STRING); } + if (patterns.empty()) { // if no patterns; just return a copy + return std::make_unique(input.parent(), stream, mr); + } CUDF_EXPECTS(!replacements.has_nulls(), "Parameter replacements must not have any nulls"); - auto d_strings = column_device_view::create(strings.parent(), stream); - auto d_repls = column_device_view::create(replacements.parent(), stream); - auto d_char_table = get_character_flags_table(); - // compile regexes into device objects - size_type regex_insts = 0; - std::vector>> h_progs; - std::vector progs; - for (auto itr = patterns.begin(); itr != patterns.end(); ++itr) { - auto prog = reprog_device::create(*itr, flags, d_char_table, strings_count, stream); - regex_insts = std::max(regex_insts, prog->insts_counts()); - progs.push_back(*prog); - h_progs.emplace_back(std::move(prog)); - } + auto const d_char_table = get_character_flags_table(); + auto h_progs = std::vector>>( + patterns.size()); + std::transform(patterns.begin(), + patterns.end(), + h_progs.begin(), + [flags, d_char_table, input, stream](auto const& ptn) { + return reprog_device::create(ptn, flags, d_char_table, input.size(), stream); + }); + + // get the longest regex for the dispatcher + auto const max_prog = + std::max_element(h_progs.begin(), h_progs.end(), [](auto const& lhs, auto const& rhs) { + return lhs->insts_counts() < rhs->insts_counts(); + }); // copy all the reprog_device instances to a device memory array + std::vector progs; + std::transform(h_progs.begin(), h_progs.end(), std::back_inserter(progs), [](auto const& d_prog) { + return *d_prog; + }); auto d_progs = cudf::detail::make_device_uvector_async(progs, stream); - // create working buffer for ranges pairs - rmm::device_uvector found_ranges(patterns.size() * strings_count, stream); - auto d_found_ranges = found_ranges.data(); - - // create child columns - auto children = [&] { - // Each invocation is predicated on the stack size which is dependent on the number of regex - // instructions - if (regex_insts <= RX_SMALL_INSTS) { - replace_multi_regex_fn fn{*d_strings, d_progs, d_found_ranges, *d_repls}; - return make_strings_children(fn, strings_count, stream, mr); - } else if (regex_insts <= RX_MEDIUM_INSTS) { - replace_multi_regex_fn fn{*d_strings, d_progs, d_found_ranges, *d_repls}; - return make_strings_children(fn, strings_count, stream, mr); - } else if (regex_insts <= RX_LARGE_INSTS) { - replace_multi_regex_fn fn{*d_strings, d_progs, d_found_ranges, *d_repls}; - return make_strings_children(fn, strings_count, stream, mr); - } else { - replace_multi_regex_fn fn{*d_strings, d_progs, d_found_ranges, *d_repls}; - return make_strings_children(fn, strings_count, stream, mr); - } - }(); - - return make_strings_column(strings_count, - std::move(children.first), - std::move(children.second), - strings.null_count(), - cudf::detail::copy_bitmask(strings.parent(), stream, mr)); + return regex_dispatcher( + **max_prog, replace_dispatch_fn{}, input, d_progs, replacements, stream, mr); } } // namespace detail diff --git a/cpp/src/strings/replace/replace_re.cu b/cpp/src/strings/replace/replace_re.cu index 2c594bb86a8..d42359deeac 100644 --- a/cpp/src/strings/replace/replace_re.cu +++ b/cpp/src/strings/replace/replace_re.cu @@ -14,6 +14,7 @@ * limitations under the License. */ +#include #include #include @@ -36,16 +37,6 @@ namespace { /** * @brief This functor handles replacing strings by applying the compiled regex pattern * and inserting the new string within the matched range of characters. - * - * The logic includes computing the size of each string and also writing the output. - * - * The stack is used to keep progress on evaluating the regex instructions on each string. - * So the size of the stack is in proportion to the number of instructions in the given regex - * pattern. - * - * There are three call types based on the number of regex instructions in the given pattern. - * Small to medium instruction lengths can use the stack effectively though smaller executes faster. - * Longer patterns require global memory. Shorter patterns are common in data cleaning. */ template struct replace_regex_fn { @@ -108,11 +99,37 @@ struct replace_regex_fn { } }; +struct replace_dispatch_fn { + reprog_device d_prog; + + template + std::unique_ptr operator()(strings_column_view const& input, + string_view const& d_replacement, + size_type max_replace_count, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) + { + auto const d_strings = column_device_view::create(input.parent(), stream); + + auto children = make_strings_children( + replace_regex_fn{*d_strings, d_prog, d_replacement, max_replace_count}, + input.size(), + stream, + mr); + + return make_strings_column(input.size(), + std::move(children.first), + std::move(children.second), + input.null_count(), + cudf::detail::copy_bitmask(input.parent(), stream, mr)); + } +}; + } // namespace // std::unique_ptr replace_re( - strings_column_view const& strings, + strings_column_view const& input, std::string const& pattern, string_scalar const& replacement, std::optional max_replace_count, @@ -120,49 +137,19 @@ std::unique_ptr replace_re( rmm::cuda_stream_view stream = rmm::cuda_stream_default, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) { - auto strings_count = strings.size(); - if (strings_count == 0) return make_empty_column(type_id::STRING); + if (input.is_empty()) return make_empty_column(type_id::STRING); CUDF_EXPECTS(replacement.is_valid(stream), "Parameter replacement must be valid"); string_view d_repl(replacement.data(), replacement.size()); - auto strings_column = column_device_view::create(strings.parent(), stream); - auto d_strings = *strings_column; // compile regex into device object - auto prog = - reprog_device::create(pattern, flags, get_character_flags_table(), strings_count, stream); - auto d_prog = *prog; - auto const regex_insts = d_prog.insts_counts(); - - // copy null mask - auto null_mask = cudf::detail::copy_bitmask(strings.parent(), stream, mr); - auto const null_count = strings.null_count(); - auto const maxrepl = max_replace_count.value_or(-1); - - // create child columns - auto children = [&] { - // Each invocation is predicated on the stack size which is dependent on the number of regex - // instructions - if (regex_insts <= RX_SMALL_INSTS) { - replace_regex_fn fn{d_strings, d_prog, d_repl, maxrepl}; - return make_strings_children(fn, strings_count, stream, mr); - } else if (regex_insts <= RX_MEDIUM_INSTS) { - replace_regex_fn fn{d_strings, d_prog, d_repl, maxrepl}; - return make_strings_children(fn, strings_count, stream, mr); - } else if (regex_insts <= RX_LARGE_INSTS) { - replace_regex_fn fn{d_strings, d_prog, d_repl, maxrepl}; - return make_strings_children(fn, strings_count, stream, mr); - } else { - replace_regex_fn fn{d_strings, d_prog, d_repl, maxrepl}; - return make_strings_children(fn, strings_count, stream, mr); - } - }(); + auto d_prog = + reprog_device::create(pattern, flags, get_character_flags_table(), input.size(), stream); + + auto const maxrepl = max_replace_count.value_or(-1); - return make_strings_column(strings_count, - std::move(children.first), - std::move(children.second), - null_count, - std::move(null_mask)); + return regex_dispatcher( + *d_prog, replace_dispatch_fn{*d_prog}, input, d_repl, maxrepl, stream, mr); } } // namespace detail diff --git a/cpp/src/strings/search/findall.cu b/cpp/src/strings/search/findall.cu index 810e44cc27d..201556033ad 100644 --- a/cpp/src/strings/search/findall.cu +++ b/cpp/src/strings/search/findall.cu @@ -14,6 +14,11 @@ * limitations under the License. */ +#include +#include +#include +#include + #include #include #include @@ -24,19 +29,16 @@ #include #include -#include -#include - #include #include -#include +#include +#include namespace cudf { namespace strings { namespace detail { using string_index_pair = thrust::pair; -using findall_result = thrust::pair; namespace { /** @@ -47,27 +49,20 @@ template struct findall_fn { column_device_view const d_strings; reprog_device prog; - size_type column_index; + size_type const column_index; size_type const* d_counts; - findall_fn(column_device_view const& d_strings, - reprog_device& prog, - size_type column_index = -1, - size_type const* d_counts = nullptr) - : d_strings(d_strings), prog(prog), column_index(column_index), d_counts(d_counts) + __device__ string_index_pair operator()(size_type idx) { - } + if (d_strings.is_null(idx) || (column_index >= d_counts[idx])) { + return string_index_pair{nullptr, 0}; + } + + auto const d_str = d_strings.element(idx); + auto const nchars = d_str.length(); + int32_t spos = 0; + auto epos = static_cast(nchars); - // this will count columns as well as locate a specific string for a column - __device__ findall_result findall(size_type idx) - { - string_index_pair result{nullptr, 0}; - if (d_strings.is_null(idx) || (d_counts && (column_index >= d_counts[idx]))) - return findall_result{0, result}; - string_view d_str = d_strings.element(idx); - auto const nchars = d_str.length(); - int32_t spos = 0; - auto epos = static_cast(nchars); size_type column_count = 0; while (spos <= nchars) { if (prog.find(idx, d_str, spos, epos) <= 0) break; // no more matches found @@ -76,36 +71,40 @@ struct findall_fn { epos = static_cast(nchars); ++column_count; } - if (spos <= epos) { - spos = d_str.byte_offset(spos); // convert - epos = d_str.byte_offset(epos); // to bytes - result = string_index_pair{d_str.data() + spos, (epos - spos)}; - } - // return the strings location and the column count - return findall_result{column_count, result}; - } - __device__ string_index_pair operator()(size_type idx) - { - // this one only cares about the string - return findall(idx).second; + auto const result = [&] { + if (spos > epos) { return string_index_pair{nullptr, 0}; } + // convert character positions to byte positions + spos = d_str.byte_offset(spos); + epos = d_str.byte_offset(epos); + return string_index_pair{d_str.data() + spos, (epos - spos)}; + }(); + + return result; } }; -template -struct findall_count_fn : public findall_fn { - findall_count_fn(column_device_view const& strings, reprog_device& prog) - : findall_fn{strings, prog} - { - } +struct findall_dispatch_fn { + reprog_device d_prog; - __device__ size_type operator()(size_type idx) + template + std::unique_ptr operator()(column_device_view const& d_strings, + size_type column_index, + size_type const* d_find_counts, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) { - // this one only cares about the column count - return findall_fn::findall(idx).first; + rmm::device_uvector indices(d_strings.size(), stream); + + thrust::transform(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(d_strings.size()), + indices.begin(), + findall_fn{d_strings, d_prog, column_index, d_find_counts}); + + return make_strings_column(indices.begin(), indices.end(), stream, mr); } }; - } // namespace // @@ -124,38 +123,15 @@ std::unique_ptr
findall( reprog_device::create(pattern, flags, get_character_flags_table(), strings_count, stream); auto const regex_insts = d_prog->insts_counts(); - rmm::device_uvector find_counts(strings_count, stream); - auto d_find_counts = find_counts.data(); - - if (regex_insts <= RX_SMALL_INSTS) - thrust::transform(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(strings_count), - d_find_counts, - findall_count_fn{*d_strings, *d_prog}); - else if (regex_insts <= RX_MEDIUM_INSTS) - thrust::transform(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(strings_count), - d_find_counts, - findall_count_fn{*d_strings, *d_prog}); - else if (regex_insts <= RX_LARGE_INSTS) - thrust::transform(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(strings_count), - d_find_counts, - findall_count_fn{*d_strings, *d_prog}); - else - thrust::transform(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(strings_count), - d_find_counts, - findall_count_fn{*d_strings, *d_prog}); + auto find_counts = + count_matches(*d_strings, *d_prog, stream, rmm::mr::get_current_device_resource()); + auto d_find_counts = find_counts->mutable_view().data(); std::vector> results; size_type const columns = thrust::reduce( - rmm::exec_policy(stream), find_counts.begin(), find_counts.end(), 0, thrust::maximum{}); + rmm::exec_policy(stream), d_find_counts, d_find_counts + strings_count, 0, thrust::maximum{}); + // boundary case: if no columns, return all nulls column (issue #119) if (columns == 0) results.emplace_back(std::make_unique( @@ -166,39 +142,10 @@ std::unique_ptr
findall( strings_count)); for (int32_t column_index = 0; column_index < columns; ++column_index) { - rmm::device_uvector indices(strings_count, stream); - - if (regex_insts <= RX_SMALL_INSTS) - thrust::transform( - rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(strings_count), - indices.begin(), - findall_fn{*d_strings, *d_prog, column_index, d_find_counts}); - else if (regex_insts <= RX_MEDIUM_INSTS) - thrust::transform( - rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(strings_count), - indices.begin(), - findall_fn{*d_strings, *d_prog, column_index, d_find_counts}); - else if (regex_insts <= RX_LARGE_INSTS) - thrust::transform( - rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(strings_count), - indices.begin(), - findall_fn{*d_strings, *d_prog, column_index, d_find_counts}); - else - thrust::transform(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(strings_count), - indices.begin(), - findall_fn{*d_strings, *d_prog, column_index, d_find_counts}); - - // - results.emplace_back(make_strings_column(indices.begin(), indices.end(), stream, mr)); + results.emplace_back(regex_dispatcher( + *d_prog, findall_dispatch_fn{*d_prog}, *d_strings, column_index, d_find_counts, stream, mr)); } + return std::make_unique
(std::move(results)); } diff --git a/cpp/src/strings/search/findall_record.cu b/cpp/src/strings/search/findall_record.cu index c93eb0c17db..95e347a7c35 100644 --- a/cpp/src/strings/search/findall_record.cu +++ b/cpp/src/strings/search/findall_record.cu @@ -15,6 +15,9 @@ */ #include +#include +#include +#include #include #include @@ -26,9 +29,6 @@ #include #include -#include -#include - #include #include @@ -75,6 +75,27 @@ struct findall_fn { } }; +struct findall_dispatch_fn { + reprog_device d_prog; + + template + std::unique_ptr operator()(column_device_view const& d_strings, + size_type total_matches, + offset_type const* d_offsets, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) + { + rmm::device_uvector indices(total_matches, stream); + + thrust::for_each_n(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + d_strings.size(), + findall_fn{d_strings, d_prog, d_offsets, indices.data()}); + + return make_strings_column(indices.begin(), indices.end(), stream, mr); + } +}; + } // namespace // @@ -121,30 +142,11 @@ std::unique_ptr findall_record( rmm::exec_policy(stream), d_offsets, d_offsets + strings_count + 1, d_offsets); // Create indices vector with the total number of groups that will be extracted - auto total_matches = cudf::detail::get_value(offsets->view(), strings_count, stream); - - rmm::device_uvector indices(total_matches, stream); - auto d_indices = indices.data(); - auto begin = thrust::make_counting_iterator(0); - - // Build the string indices - auto const regex_insts = d_prog->insts_counts(); - if (regex_insts <= RX_SMALL_INSTS) { - findall_fn fn{*d_strings, *d_prog, d_offsets, d_indices}; - thrust::for_each_n(rmm::exec_policy(stream), begin, strings_count, fn); - } else if (regex_insts <= RX_MEDIUM_INSTS) { - findall_fn fn{*d_strings, *d_prog, d_offsets, d_indices}; - thrust::for_each_n(rmm::exec_policy(stream), begin, strings_count, fn); - } else if (regex_insts <= RX_LARGE_INSTS) { - findall_fn fn{*d_strings, *d_prog, d_offsets, d_indices}; - thrust::for_each_n(rmm::exec_policy(stream), begin, strings_count, fn); - } else { - findall_fn fn{*d_strings, *d_prog, d_offsets, d_indices}; - thrust::for_each_n(rmm::exec_policy(stream), begin, strings_count, fn); - } + auto const total_matches = + cudf::detail::get_value(offsets->view(), strings_count, stream); - // Build the child strings column from the resulting indices - auto strings_output = make_strings_column(indices.begin(), indices.end(), stream, mr); + auto strings_output = regex_dispatcher( + *d_prog, findall_dispatch_fn{*d_prog}, *d_strings, total_matches, d_offsets, stream, mr); // Build the lists column from the offsets and the strings return make_lists_column(strings_count, diff --git a/cpp/src/strings/split/split_re.cu b/cpp/src/strings/split/split_re.cu index d80148f2fe6..a8a2467dd76 100644 --- a/cpp/src/strings/split/split_re.cu +++ b/cpp/src/strings/split/split_re.cu @@ -15,6 +15,7 @@ */ #include +#include #include #include @@ -110,6 +111,28 @@ struct token_reader_fn { } }; +struct generate_dispatch_fn { + reprog_device d_prog; + + template + rmm::device_uvector operator()(column_device_view const& d_strings, + size_type total_tokens, + split_direction direction, + offset_type const* d_offsets, + rmm::cuda_stream_view stream) + { + rmm::device_uvector tokens(total_tokens, stream); + + thrust::for_each_n( + rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + d_strings.size(), + token_reader_fn{d_strings, d_prog, direction, d_offsets, tokens.data()}); + + return tokens; + } +}; + /** * @brief Call regex to split each input string into tokens. * @@ -148,24 +171,8 @@ rmm::device_uvector generate_tokens(column_device_view const& // the last offset entry is the total number of tokens to be generated auto const total_tokens = cudf::detail::get_value(offsets, strings_count, stream); - // generate tokens for each string - rmm::device_uvector tokens(total_tokens, stream); - auto const regex_insts = d_prog.insts_counts(); - if (regex_insts <= RX_SMALL_INSTS) { - token_reader_fn reader{d_strings, d_prog, direction, d_offsets, tokens.data()}; - thrust::for_each_n(rmm::exec_policy(stream), begin, strings_count, reader); - } else if (regex_insts <= RX_MEDIUM_INSTS) { - token_reader_fn reader{d_strings, d_prog, direction, d_offsets, tokens.data()}; - thrust::for_each_n(rmm::exec_policy(stream), begin, strings_count, reader); - } else if (regex_insts <= RX_LARGE_INSTS) { - token_reader_fn reader{d_strings, d_prog, direction, d_offsets, tokens.data()}; - thrust::for_each_n(rmm::exec_policy(stream), begin, strings_count, reader); - } else { - token_reader_fn reader{d_strings, d_prog, direction, d_offsets, tokens.data()}; - thrust::for_each_n(rmm::exec_policy(stream), begin, strings_count, reader); - } - - return tokens; + return regex_dispatcher( + d_prog, generate_dispatch_fn{d_prog}, d_strings, total_tokens, direction, d_offsets, stream); } /**