diff --git a/cpp/src/strings/contains.cu b/cpp/src/strings/contains.cu index efdee65c1f6..23bc5cf2dfe 100644 --- a/cpp/src/strings/contains.cu +++ b/cpp/src/strings/contains.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,6 +14,10 @@ * limitations under the License. */ +#include +#include +#include + #include #include #include @@ -23,123 +27,90 @@ #include #include #include -#include -#include #include #include +#include + namespace cudf { namespace strings { namespace detail { + namespace { /** * @brief This functor handles both contains_re and match_re to minimize the number * of regex calls to find() to be inlined greatly reducing compile time. - * - * The stack is used to keep progress on evaluating the regex instructions on each string. - * So the size of the stack is in proportion to the number of instructions in the given regex - * pattern. - * - * There are three call types based on the number of regex instructions in the given pattern. - * Small to medium instruction lengths can use the stack effectively though smaller executes faster. - * Longer patterns require global memory. */ template struct contains_fn { reprog_device prog; - column_device_view d_strings; - bool bmatch{false}; // do not make this a template parameter to keep compile times down + column_device_view const d_strings; + bool const beginning_only; // do not make this a template parameter to keep compile times down __device__ bool operator()(size_type idx) { if (d_strings.is_null(idx)) return false; - string_view d_str = d_strings.element(idx); - int32_t begin = 0; - int32_t end = bmatch ? 1 // match only the beginning of the string; - : -1; // this handles empty strings too + auto const d_str = d_strings.element(idx); + int32_t begin = 0; + int32_t end = beginning_only ? 1 // match only the beginning of the string; + : -1; // match anywhere in the string return static_cast(prog.find(idx, d_str, begin, end)); } }; -// -std::unique_ptr contains_util( - strings_column_view const& strings, - std::string const& pattern, - regex_flags const flags, - bool beginning_only = false, - rmm::cuda_stream_view stream = rmm::cuda_stream_default, - rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) -{ - auto strings_count = strings.size(); - auto strings_column = column_device_view::create(strings.parent(), stream); - auto d_column = *strings_column; - - // compile regex into device object - auto prog = - reprog_device::create(pattern, flags, get_character_flags_table(), strings_count, stream); - auto d_prog = *prog; - - // create the output column - auto results = make_numeric_column(data_type{type_id::BOOL8}, - strings_count, - cudf::detail::copy_bitmask(strings.parent(), stream, mr), - strings.null_count(), - stream, - mr); - auto d_results = results->mutable_view().data(); +struct contains_dispatch_fn { + reprog_device d_prog; + bool const beginning_only; - // fill the output column - int regex_insts = d_prog.insts_counts(); - if (regex_insts <= RX_SMALL_INSTS) - thrust::transform(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(strings_count), - d_results, - contains_fn{d_prog, d_column, beginning_only}); - else if (regex_insts <= RX_MEDIUM_INSTS) - thrust::transform(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(strings_count), - d_results, - contains_fn{d_prog, d_column, beginning_only}); - else if (regex_insts <= RX_LARGE_INSTS) - thrust::transform(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(strings_count), - d_results, - contains_fn{d_prog, d_column, beginning_only}); - else + template + std::unique_ptr operator()(strings_column_view const& input, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) + { + auto results = make_numeric_column(data_type{type_id::BOOL8}, + input.size(), + cudf::detail::copy_bitmask(input.parent(), stream, mr), + input.null_count(), + stream, + mr); + + auto const d_strings = column_device_view::create(input.parent(), stream); thrust::transform(rmm::exec_policy(stream), thrust::make_counting_iterator(0), - thrust::make_counting_iterator(strings_count), - d_results, - contains_fn{d_prog, d_column, beginning_only}); - - results->set_null_count(strings.null_count()); - return results; -} + thrust::make_counting_iterator(input.size()), + results->mutable_view().data(), + contains_fn{d_prog, *d_strings, beginning_only}); + return results; + } +}; } // namespace std::unique_ptr contains_re( - strings_column_view const& strings, + strings_column_view const& input, std::string const& pattern, regex_flags const flags, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) { - return contains_util(strings, pattern, flags, false, stream, mr); + auto d_prog = + reprog_device::create(pattern, flags, get_character_flags_table(), input.size(), stream); + + return regex_dispatcher(*d_prog, contains_dispatch_fn{*d_prog, false}, input, stream, mr); } std::unique_ptr matches_re( - strings_column_view const& strings, + strings_column_view const& input, std::string const& pattern, regex_flags const flags, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) { - return contains_util(strings, pattern, flags, true, stream, mr); + auto d_prog = + reprog_device::create(pattern, flags, get_character_flags_table(), input.size(), stream); + + return regex_dispatcher(*d_prog, contains_dispatch_fn{*d_prog, true}, input, stream, mr); } } // namespace detail @@ -172,12 +143,12 @@ namespace { template struct count_fn { reprog_device prog; - column_device_view d_strings; + column_device_view const d_strings; __device__ int32_t operator()(unsigned int idx) { if (d_strings.is_null(idx)) return 0; - string_view d_str = d_strings.element(idx); + auto const d_str = d_strings.element(idx); auto const nchars = d_str.length(); int32_t find_count = 0; int32_t begin = 0; @@ -191,62 +162,45 @@ struct count_fn { } }; +struct count_dispatch_fn { + reprog_device d_prog; + + template + std::unique_ptr operator()(strings_column_view const& input, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) + { + auto results = make_numeric_column(data_type{type_id::INT32}, + input.size(), + cudf::detail::copy_bitmask(input.parent(), stream, mr), + input.null_count(), + stream, + mr); + + auto const d_strings = column_device_view::create(input.parent(), stream); + thrust::transform(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(input.size()), + results->mutable_view().data(), + count_fn{d_prog, *d_strings}); + return results; + } +}; + } // namespace std::unique_ptr count_re( - strings_column_view const& strings, + strings_column_view const& input, std::string const& pattern, regex_flags const flags, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) { - auto strings_count = strings.size(); - auto strings_column = column_device_view::create(strings.parent(), stream); - auto d_column = *strings_column; - // compile regex into device object - auto prog = - reprog_device::create(pattern, flags, get_character_flags_table(), strings_count, stream); - auto d_prog = *prog; - - // create the output column - auto results = make_numeric_column(data_type{type_id::INT32}, - strings_count, - cudf::detail::copy_bitmask(strings.parent(), stream, mr), - strings.null_count(), - stream, - mr); - auto d_results = results->mutable_view().data(); - - // fill the output column - int regex_insts = d_prog.insts_counts(); - if (regex_insts <= RX_SMALL_INSTS) - thrust::transform(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(strings_count), - d_results, - count_fn{d_prog, d_column}); - else if (regex_insts <= RX_MEDIUM_INSTS) - thrust::transform(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(strings_count), - d_results, - count_fn{d_prog, d_column}); - else if (regex_insts <= RX_LARGE_INSTS) - thrust::transform(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(strings_count), - d_results, - count_fn{d_prog, d_column}); - else - thrust::transform(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(strings_count), - d_results, - count_fn{d_prog, d_column}); + auto d_prog = + reprog_device::create(pattern, flags, get_character_flags_table(), input.size(), stream); - results->set_null_count(strings.null_count()); - return results; + return regex_dispatcher(*d_prog, count_dispatch_fn{*d_prog}, input, stream, mr); } } // namespace detail diff --git a/cpp/src/strings/count_matches.cu b/cpp/src/strings/count_matches.cu index d0a6825666b..ae996cafd2c 100644 --- a/cpp/src/strings/count_matches.cu +++ b/cpp/src/strings/count_matches.cu @@ -15,6 +15,7 @@ */ #include +#include #include #include @@ -54,6 +55,27 @@ struct count_matches_fn { return count; } }; + +struct count_dispatch_fn { + reprog_device d_prog; + + template + std::unique_ptr operator()(column_device_view const& d_strings, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) + { + auto results = make_numeric_column( + data_type{type_id::INT32}, d_strings.size() + 1, mask_state::UNALLOCATED, stream, mr); + + thrust::transform(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(d_strings.size()), + results->mutable_view().data(), + count_matches_fn{d_strings, d_prog}); + return results; + } +}; + } // namespace /** @@ -71,31 +93,7 @@ std::unique_ptr count_matches(column_device_view const& d_strings, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { - // Create output column - auto counts = make_numeric_column( - data_type{type_id::INT32}, d_strings.size() + 1, mask_state::UNALLOCATED, stream, mr); - auto d_counts = counts->mutable_view().data(); - - auto begin = thrust::make_counting_iterator(0); - auto end = thrust::make_counting_iterator(d_strings.size()); - - // Count matches - auto const regex_insts = d_prog.insts_counts(); - if (regex_insts <= RX_SMALL_INSTS) { - count_matches_fn fn{d_strings, d_prog}; - thrust::transform(rmm::exec_policy(stream), begin, end, d_counts, fn); - } else if (regex_insts <= RX_MEDIUM_INSTS) { - count_matches_fn fn{d_strings, d_prog}; - thrust::transform(rmm::exec_policy(stream), begin, end, d_counts, fn); - } else if (regex_insts <= RX_LARGE_INSTS) { - count_matches_fn fn{d_strings, d_prog}; - thrust::transform(rmm::exec_policy(stream), begin, end, d_counts, fn); - } else { - count_matches_fn fn{d_strings, d_prog}; - thrust::transform(rmm::exec_policy(stream), begin, end, d_counts, fn); - } - - return counts; + return regex_dispatcher(d_prog, count_dispatch_fn{d_prog}, d_strings, stream, mr); } } // namespace detail diff --git a/cpp/src/strings/extract/extract.cu b/cpp/src/strings/extract/extract.cu index a67af9442f0..7394cdac6bb 100644 --- a/cpp/src/strings/extract/extract.cu +++ b/cpp/src/strings/extract/extract.cu @@ -14,6 +14,7 @@ * limitations under the License. */ +#include #include #include @@ -77,53 +78,44 @@ struct extract_fn { thrust::fill(thrust::seq, d_output.begin(), d_output.end(), string_index_pair{nullptr, 0}); } }; + +struct extract_dispatch_fn { + reprog_device d_prog; + + template + void operator()(column_device_view const& d_strings, + cudf::detail::device_2dspan& d_indices, + rmm::cuda_stream_view stream) + { + thrust::for_each_n(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + d_strings.size(), + extract_fn{d_prog, d_strings, d_indices}); + } +}; } // namespace // std::unique_ptr extract( - strings_column_view const& strings, + strings_column_view const& input, std::string const& pattern, regex_flags const flags, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) { - auto const strings_count = strings.size(); - auto const strings_column = column_device_view::create(strings.parent(), stream); - auto const d_strings = *strings_column; - // compile regex into device object - auto prog = - reprog_device::create(pattern, flags, get_character_flags_table(), strings_count, stream); - auto d_prog = *prog; - // extract should include groups - auto const groups = d_prog.group_counts(); + auto d_prog = + reprog_device::create(pattern, flags, get_character_flags_table(), input.size(), stream); + + auto const groups = d_prog->group_counts(); CUDF_EXPECTS(groups > 0, "Group indicators not found in regex pattern"); - rmm::device_uvector indices(strings_count * groups, stream); - cudf::detail::device_2dspan d_indices(indices.data(), strings_count, groups); + auto indices = rmm::device_uvector(input.size() * groups, stream); + auto d_indices = + cudf::detail::device_2dspan(indices.data(), input.size(), groups); - auto const regex_insts = d_prog.insts_counts(); - if (regex_insts <= RX_SMALL_INSTS) { - thrust::for_each_n(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - strings_count, - extract_fn{d_prog, d_strings, d_indices}); - } else if (regex_insts <= RX_MEDIUM_INSTS) { - thrust::for_each_n(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - strings_count, - extract_fn{d_prog, d_strings, d_indices}); - } else if (regex_insts <= RX_LARGE_INSTS) { - thrust::for_each_n(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - strings_count, - extract_fn{d_prog, d_strings, d_indices}); - } else { - thrust::for_each_n(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - strings_count, - extract_fn{d_prog, d_strings, d_indices}); - } + auto const d_strings = column_device_view::create(input.parent(), stream); + regex_dispatcher(*d_prog, extract_dispatch_fn{*d_prog}, *d_strings, d_indices, stream); // build a result column for each group std::vector> results(groups); @@ -135,7 +127,7 @@ std::unique_ptr
extract( 0, [column_index, groups] __device__(size_type idx) { return (idx * groups) + column_index; })); - return make_strings_column(indices_itr, indices_itr + strings_count, stream, mr); + return make_strings_column(indices_itr, indices_itr + input.size(), stream, mr); }; std::transform(thrust::make_counting_iterator(0), diff --git a/cpp/src/strings/extract/extract_all.cu b/cpp/src/strings/extract/extract_all.cu index e27dccb9338..1f1474c777b 100644 --- a/cpp/src/strings/extract/extract_all.cu +++ b/cpp/src/strings/extract/extract_all.cu @@ -15,6 +15,7 @@ */ #include +#include #include #include @@ -86,6 +87,28 @@ struct extract_fn { } } }; + +struct extract_dispatch_fn { + reprog_device d_prog; + + template + std::unique_ptr operator()(column_device_view const& d_strings, + size_type total_groups, + offset_type const* d_offsets, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) + { + rmm::device_uvector indices(total_groups, stream); + + thrust::for_each_n(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + d_strings.size(), + extract_fn{d_strings, d_prog, d_offsets, indices.data()}); + + return make_strings_column(indices.begin(), indices.end(), stream, mr); + } +}; + } // namespace /** @@ -94,14 +117,14 @@ struct extract_fn { * @param stream CUDA stream used for device memory operations and kernel launches. */ std::unique_ptr extract_all_record( - strings_column_view const& strings, + strings_column_view const& input, std::string const& pattern, regex_flags const flags, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) { - auto const strings_count = strings.size(); - auto const d_strings = column_device_view::create(strings.parent(), stream); + auto const strings_count = input.size(); + auto const d_strings = column_device_view::create(input.parent(), stream); // Compile regex into device object. auto d_prog = @@ -143,29 +166,8 @@ std::unique_ptr extract_all_record( auto const total_groups = cudf::detail::get_value(offsets->view(), strings_count, stream); - // Create an indices vector with the total number of groups that will be extracted. - rmm::device_uvector indices(total_groups, stream); - auto d_indices = indices.data(); - auto begin = thrust::make_counting_iterator(0); - - // Call the extract functor to fill in the indices vector. - auto const regex_insts = d_prog->insts_counts(); - if (regex_insts <= RX_SMALL_INSTS) { - extract_fn fn{*d_strings, *d_prog, d_offsets, d_indices}; - thrust::for_each_n(rmm::exec_policy(stream), begin, strings_count, fn); - } else if (regex_insts <= RX_MEDIUM_INSTS) { - extract_fn fn{*d_strings, *d_prog, d_offsets, d_indices}; - thrust::for_each_n(rmm::exec_policy(stream), begin, strings_count, fn); - } else if (regex_insts <= RX_LARGE_INSTS) { - extract_fn fn{*d_strings, *d_prog, d_offsets, d_indices}; - thrust::for_each_n(rmm::exec_policy(stream), begin, strings_count, fn); - } else { - extract_fn fn{*d_strings, *d_prog, d_offsets, d_indices}; - thrust::for_each_n(rmm::exec_policy(stream), begin, strings_count, fn); - } - - // Build the child strings column from the indices. - auto strings_output = make_strings_column(indices.begin(), indices.end(), stream, mr); + auto strings_output = regex_dispatcher( + *d_prog, extract_dispatch_fn{*d_prog}, *d_strings, total_groups, d_offsets, stream, mr); // Build the lists column from the offsets and the strings. return make_lists_column(strings_count, diff --git a/cpp/src/strings/regex/dispatcher.hpp b/cpp/src/strings/regex/dispatcher.hpp new file mode 100644 index 00000000000..9ff51d1c979 --- /dev/null +++ b/cpp/src/strings/regex/dispatcher.hpp @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace cudf { +namespace strings { +namespace detail { + +/** + * The stack is used to keep progress (state) on evaluating the regex instructions on each string. + * So the size of the stack is in proportion to the number of instructions in the given regex + * pattern. + * + * There are four call types based on the number of regex instructions in the given pattern. + * Small, medium, and large instruction counts can use the stack effectively. + * Smaller stack sizes execute faster. + * + * Patterns with instruction counts bigger than large use global memory rather than the stack + * for managing the evaluation state data. + * + * @tparam Functor The functor to invoke with stack size templated value. + * @tparam Ts Parameter types for the functor call. + */ +template +constexpr decltype(auto) regex_dispatcher(reprog_device d_prog, Functor f, Ts&&... args) +{ + auto const num_regex_insts = d_prog.insts_counts(); + if (num_regex_insts <= RX_SMALL_INSTS) { + return f.template operator()(std::forward(args)...); + } + if (num_regex_insts <= RX_MEDIUM_INSTS) { + return f.template operator()(std::forward(args)...); + } + if (num_regex_insts <= RX_LARGE_INSTS) { + return f.template operator()(std::forward(args)...); + } + + return f.template operator()(std::forward(args)...); +} + +} // namespace detail +} // namespace strings +} // namespace cudf diff --git a/cpp/src/strings/replace/backref_re.cu b/cpp/src/strings/replace/backref_re.cu index ff86d7aa552..27e0bd4fac9 100644 --- a/cpp/src/strings/replace/backref_re.cu +++ b/cpp/src/strings/replace/backref_re.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ #include "backref_re.cuh" +#include #include #include @@ -95,27 +96,54 @@ std::pair> parse_backrefs(std::string con return {rtn, backrefs}; } +template +struct replace_dispatch_fn { + reprog_device d_prog; + + template + std::unique_ptr operator()(strings_column_view const& input, + string_view const& d_repl_template, + Iterator backrefs_begin, + Iterator backrefs_end, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) + { + auto const d_strings = column_device_view::create(input.parent(), stream); + + auto children = make_strings_children( + backrefs_fn{ + *d_strings, d_prog, d_repl_template, backrefs_begin, backrefs_end}, + input.size(), + stream, + mr); + + return make_strings_column(input.size(), + std::move(children.first), + std::move(children.second), + input.null_count(), + cudf::detail::copy_bitmask(input.parent(), stream, mr)); + } +}; + } // namespace // std::unique_ptr replace_with_backrefs( - strings_column_view const& strings, + strings_column_view const& input, std::string const& pattern, std::string const& replacement, regex_flags const flags, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) { - if (strings.is_empty()) return make_empty_column(type_id::STRING); + if (input.is_empty()) return make_empty_column(type_id::STRING); CUDF_EXPECTS(!pattern.empty(), "Parameter pattern must not be empty"); CUDF_EXPECTS(!replacement.empty(), "Parameter replacement must not be empty"); - auto d_strings = column_device_view::create(strings.parent(), stream); // compile regex into device object auto d_prog = - reprog_device::create(pattern, flags, get_character_flags_table(), strings.size(), stream); - auto const regex_insts = d_prog->insts_counts(); + reprog_device::create(pattern, flags, get_character_flags_table(), input.size(), stream); // parse the repl string for back-ref indicators auto const parse_result = parse_backrefs(replacement); @@ -125,45 +153,14 @@ std::unique_ptr replace_with_backrefs( string_view const d_repl_template = repl_scalar.value(); using BackRefIterator = decltype(backrefs.begin()); - - // create child columns - auto [offsets, chars] = [&] { - if (regex_insts <= RX_SMALL_INSTS) { - return make_strings_children( - backrefs_fn{ - *d_strings, *d_prog, d_repl_template, backrefs.begin(), backrefs.end()}, - strings.size(), - stream, - mr); - } else if (regex_insts <= RX_MEDIUM_INSTS) { - return make_strings_children( - backrefs_fn{ - *d_strings, *d_prog, d_repl_template, backrefs.begin(), backrefs.end()}, - strings.size(), - stream, - mr); - } else if (regex_insts <= RX_LARGE_INSTS) { - return make_strings_children( - backrefs_fn{ - *d_strings, *d_prog, d_repl_template, backrefs.begin(), backrefs.end()}, - strings.size(), - stream, - mr); - } else { - return make_strings_children( - backrefs_fn{ - *d_strings, *d_prog, d_repl_template, backrefs.begin(), backrefs.end()}, - strings.size(), - stream, - mr); - } - }(); - - return make_strings_column(strings.size(), - std::move(offsets), - std::move(chars), - strings.null_count(), - cudf::detail::copy_bitmask(strings.parent(), stream, mr)); + return regex_dispatcher(*d_prog, + replace_dispatch_fn{*d_prog}, + input, + d_repl_template, + backrefs.begin(), + backrefs.end(), + stream, + mr); } } // namespace detail diff --git a/cpp/src/strings/replace/multi_re.cu b/cpp/src/strings/replace/multi_re.cu index 2b5380b76dd..22f6d2cba39 100644 --- a/cpp/src/strings/replace/multi_re.cu +++ b/cpp/src/strings/replace/multi_re.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,6 +14,7 @@ * limitations under the License. */ +#include #include #include @@ -30,6 +31,8 @@ #include +#include + namespace cudf { namespace strings { namespace detail { @@ -40,16 +43,6 @@ using found_range = thrust::pair; /** * @brief This functor handles replacing strings by applying the compiled regex patterns * and inserting the corresponding new string within the matched range of characters. - * - * The logic includes computing the size of each string and also writing the output. - * - * The stack is used to keep progress on evaluating the regex instructions on each string. - * So the size of the stack is in proportion to the number of instructions in the given regex - * pattern. - * - * There are three call types based on the number of regex instructions in the given pattern. - * Small to medium instruction lengths can use the stack effectively though smaller executes faster. - * Longer patterns require global memory. Shorter patterns are common in data cleaning. */ template struct replace_multi_regex_fn { @@ -127,69 +120,76 @@ struct replace_multi_regex_fn { } }; +struct replace_dispatch_fn { + template + std::unique_ptr operator()(strings_column_view const& input, + device_span d_progs, + strings_column_view const& replacements, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) + { + auto const d_strings = column_device_view::create(input.parent(), stream); + auto const d_repls = column_device_view::create(replacements.parent(), stream); + + auto found_ranges = rmm::device_uvector(d_progs.size() * input.size(), stream); + + auto children = make_strings_children( + replace_multi_regex_fn{*d_strings, d_progs, found_ranges.data(), *d_repls}, + input.size(), + stream, + mr); + + return make_strings_column(input.size(), + std::move(children.first), + std::move(children.second), + input.null_count(), + cudf::detail::copy_bitmask(input.parent(), stream, mr)); + } +}; + } // namespace std::unique_ptr replace_re( - strings_column_view const& strings, + strings_column_view const& input, std::vector const& patterns, strings_column_view const& replacements, regex_flags const flags, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) { - auto strings_count = strings.size(); - if (strings_count == 0) return make_empty_column(type_id::STRING); - if (patterns.empty()) // no patterns; just return a copy - return std::make_unique(strings.parent(), stream, mr); + if (input.is_empty()) { return make_empty_column(type_id::STRING); } + if (patterns.empty()) { // if no patterns; just return a copy + return std::make_unique(input.parent(), stream, mr); + } CUDF_EXPECTS(!replacements.has_nulls(), "Parameter replacements must not have any nulls"); - auto d_strings = column_device_view::create(strings.parent(), stream); - auto d_repls = column_device_view::create(replacements.parent(), stream); - auto d_char_table = get_character_flags_table(); - // compile regexes into device objects - size_type regex_insts = 0; - std::vector>> h_progs; - std::vector progs; - for (auto itr = patterns.begin(); itr != patterns.end(); ++itr) { - auto prog = reprog_device::create(*itr, flags, d_char_table, strings_count, stream); - regex_insts = std::max(regex_insts, prog->insts_counts()); - progs.push_back(*prog); - h_progs.emplace_back(std::move(prog)); - } + auto const d_char_table = get_character_flags_table(); + auto h_progs = std::vector>>( + patterns.size()); + std::transform(patterns.begin(), + patterns.end(), + h_progs.begin(), + [flags, d_char_table, input, stream](auto const& ptn) { + return reprog_device::create(ptn, flags, d_char_table, input.size(), stream); + }); + + // get the longest regex for the dispatcher + auto const max_prog = + std::max_element(h_progs.begin(), h_progs.end(), [](auto const& lhs, auto const& rhs) { + return lhs->insts_counts() < rhs->insts_counts(); + }); // copy all the reprog_device instances to a device memory array + std::vector progs; + std::transform(h_progs.begin(), h_progs.end(), std::back_inserter(progs), [](auto const& d_prog) { + return *d_prog; + }); auto d_progs = cudf::detail::make_device_uvector_async(progs, stream); - // create working buffer for ranges pairs - rmm::device_uvector found_ranges(patterns.size() * strings_count, stream); - auto d_found_ranges = found_ranges.data(); - - // create child columns - auto children = [&] { - // Each invocation is predicated on the stack size which is dependent on the number of regex - // instructions - if (regex_insts <= RX_SMALL_INSTS) { - replace_multi_regex_fn fn{*d_strings, d_progs, d_found_ranges, *d_repls}; - return make_strings_children(fn, strings_count, stream, mr); - } else if (regex_insts <= RX_MEDIUM_INSTS) { - replace_multi_regex_fn fn{*d_strings, d_progs, d_found_ranges, *d_repls}; - return make_strings_children(fn, strings_count, stream, mr); - } else if (regex_insts <= RX_LARGE_INSTS) { - replace_multi_regex_fn fn{*d_strings, d_progs, d_found_ranges, *d_repls}; - return make_strings_children(fn, strings_count, stream, mr); - } else { - replace_multi_regex_fn fn{*d_strings, d_progs, d_found_ranges, *d_repls}; - return make_strings_children(fn, strings_count, stream, mr); - } - }(); - - return make_strings_column(strings_count, - std::move(children.first), - std::move(children.second), - strings.null_count(), - cudf::detail::copy_bitmask(strings.parent(), stream, mr)); + return regex_dispatcher( + **max_prog, replace_dispatch_fn{}, input, d_progs, replacements, stream, mr); } } // namespace detail diff --git a/cpp/src/strings/replace/replace_re.cu b/cpp/src/strings/replace/replace_re.cu index 2c594bb86a8..d42359deeac 100644 --- a/cpp/src/strings/replace/replace_re.cu +++ b/cpp/src/strings/replace/replace_re.cu @@ -14,6 +14,7 @@ * limitations under the License. */ +#include #include #include @@ -36,16 +37,6 @@ namespace { /** * @brief This functor handles replacing strings by applying the compiled regex pattern * and inserting the new string within the matched range of characters. - * - * The logic includes computing the size of each string and also writing the output. - * - * The stack is used to keep progress on evaluating the regex instructions on each string. - * So the size of the stack is in proportion to the number of instructions in the given regex - * pattern. - * - * There are three call types based on the number of regex instructions in the given pattern. - * Small to medium instruction lengths can use the stack effectively though smaller executes faster. - * Longer patterns require global memory. Shorter patterns are common in data cleaning. */ template struct replace_regex_fn { @@ -108,11 +99,37 @@ struct replace_regex_fn { } }; +struct replace_dispatch_fn { + reprog_device d_prog; + + template + std::unique_ptr operator()(strings_column_view const& input, + string_view const& d_replacement, + size_type max_replace_count, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) + { + auto const d_strings = column_device_view::create(input.parent(), stream); + + auto children = make_strings_children( + replace_regex_fn{*d_strings, d_prog, d_replacement, max_replace_count}, + input.size(), + stream, + mr); + + return make_strings_column(input.size(), + std::move(children.first), + std::move(children.second), + input.null_count(), + cudf::detail::copy_bitmask(input.parent(), stream, mr)); + } +}; + } // namespace // std::unique_ptr replace_re( - strings_column_view const& strings, + strings_column_view const& input, std::string const& pattern, string_scalar const& replacement, std::optional max_replace_count, @@ -120,49 +137,19 @@ std::unique_ptr replace_re( rmm::cuda_stream_view stream = rmm::cuda_stream_default, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) { - auto strings_count = strings.size(); - if (strings_count == 0) return make_empty_column(type_id::STRING); + if (input.is_empty()) return make_empty_column(type_id::STRING); CUDF_EXPECTS(replacement.is_valid(stream), "Parameter replacement must be valid"); string_view d_repl(replacement.data(), replacement.size()); - auto strings_column = column_device_view::create(strings.parent(), stream); - auto d_strings = *strings_column; // compile regex into device object - auto prog = - reprog_device::create(pattern, flags, get_character_flags_table(), strings_count, stream); - auto d_prog = *prog; - auto const regex_insts = d_prog.insts_counts(); - - // copy null mask - auto null_mask = cudf::detail::copy_bitmask(strings.parent(), stream, mr); - auto const null_count = strings.null_count(); - auto const maxrepl = max_replace_count.value_or(-1); - - // create child columns - auto children = [&] { - // Each invocation is predicated on the stack size which is dependent on the number of regex - // instructions - if (regex_insts <= RX_SMALL_INSTS) { - replace_regex_fn fn{d_strings, d_prog, d_repl, maxrepl}; - return make_strings_children(fn, strings_count, stream, mr); - } else if (regex_insts <= RX_MEDIUM_INSTS) { - replace_regex_fn fn{d_strings, d_prog, d_repl, maxrepl}; - return make_strings_children(fn, strings_count, stream, mr); - } else if (regex_insts <= RX_LARGE_INSTS) { - replace_regex_fn fn{d_strings, d_prog, d_repl, maxrepl}; - return make_strings_children(fn, strings_count, stream, mr); - } else { - replace_regex_fn fn{d_strings, d_prog, d_repl, maxrepl}; - return make_strings_children(fn, strings_count, stream, mr); - } - }(); + auto d_prog = + reprog_device::create(pattern, flags, get_character_flags_table(), input.size(), stream); + + auto const maxrepl = max_replace_count.value_or(-1); - return make_strings_column(strings_count, - std::move(children.first), - std::move(children.second), - null_count, - std::move(null_mask)); + return regex_dispatcher( + *d_prog, replace_dispatch_fn{*d_prog}, input, d_repl, maxrepl, stream, mr); } } // namespace detail diff --git a/cpp/src/strings/search/findall.cu b/cpp/src/strings/search/findall.cu index 810e44cc27d..201556033ad 100644 --- a/cpp/src/strings/search/findall.cu +++ b/cpp/src/strings/search/findall.cu @@ -14,6 +14,11 @@ * limitations under the License. */ +#include +#include +#include +#include + #include #include #include @@ -24,19 +29,16 @@ #include #include -#include -#include - #include #include -#include +#include +#include namespace cudf { namespace strings { namespace detail { using string_index_pair = thrust::pair; -using findall_result = thrust::pair; namespace { /** @@ -47,27 +49,20 @@ template struct findall_fn { column_device_view const d_strings; reprog_device prog; - size_type column_index; + size_type const column_index; size_type const* d_counts; - findall_fn(column_device_view const& d_strings, - reprog_device& prog, - size_type column_index = -1, - size_type const* d_counts = nullptr) - : d_strings(d_strings), prog(prog), column_index(column_index), d_counts(d_counts) + __device__ string_index_pair operator()(size_type idx) { - } + if (d_strings.is_null(idx) || (column_index >= d_counts[idx])) { + return string_index_pair{nullptr, 0}; + } + + auto const d_str = d_strings.element(idx); + auto const nchars = d_str.length(); + int32_t spos = 0; + auto epos = static_cast(nchars); - // this will count columns as well as locate a specific string for a column - __device__ findall_result findall(size_type idx) - { - string_index_pair result{nullptr, 0}; - if (d_strings.is_null(idx) || (d_counts && (column_index >= d_counts[idx]))) - return findall_result{0, result}; - string_view d_str = d_strings.element(idx); - auto const nchars = d_str.length(); - int32_t spos = 0; - auto epos = static_cast(nchars); size_type column_count = 0; while (spos <= nchars) { if (prog.find(idx, d_str, spos, epos) <= 0) break; // no more matches found @@ -76,36 +71,40 @@ struct findall_fn { epos = static_cast(nchars); ++column_count; } - if (spos <= epos) { - spos = d_str.byte_offset(spos); // convert - epos = d_str.byte_offset(epos); // to bytes - result = string_index_pair{d_str.data() + spos, (epos - spos)}; - } - // return the strings location and the column count - return findall_result{column_count, result}; - } - __device__ string_index_pair operator()(size_type idx) - { - // this one only cares about the string - return findall(idx).second; + auto const result = [&] { + if (spos > epos) { return string_index_pair{nullptr, 0}; } + // convert character positions to byte positions + spos = d_str.byte_offset(spos); + epos = d_str.byte_offset(epos); + return string_index_pair{d_str.data() + spos, (epos - spos)}; + }(); + + return result; } }; -template -struct findall_count_fn : public findall_fn { - findall_count_fn(column_device_view const& strings, reprog_device& prog) - : findall_fn{strings, prog} - { - } +struct findall_dispatch_fn { + reprog_device d_prog; - __device__ size_type operator()(size_type idx) + template + std::unique_ptr operator()(column_device_view const& d_strings, + size_type column_index, + size_type const* d_find_counts, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) { - // this one only cares about the column count - return findall_fn::findall(idx).first; + rmm::device_uvector indices(d_strings.size(), stream); + + thrust::transform(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(d_strings.size()), + indices.begin(), + findall_fn{d_strings, d_prog, column_index, d_find_counts}); + + return make_strings_column(indices.begin(), indices.end(), stream, mr); } }; - } // namespace // @@ -124,38 +123,15 @@ std::unique_ptr
findall( reprog_device::create(pattern, flags, get_character_flags_table(), strings_count, stream); auto const regex_insts = d_prog->insts_counts(); - rmm::device_uvector find_counts(strings_count, stream); - auto d_find_counts = find_counts.data(); - - if (regex_insts <= RX_SMALL_INSTS) - thrust::transform(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(strings_count), - d_find_counts, - findall_count_fn{*d_strings, *d_prog}); - else if (regex_insts <= RX_MEDIUM_INSTS) - thrust::transform(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(strings_count), - d_find_counts, - findall_count_fn{*d_strings, *d_prog}); - else if (regex_insts <= RX_LARGE_INSTS) - thrust::transform(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(strings_count), - d_find_counts, - findall_count_fn{*d_strings, *d_prog}); - else - thrust::transform(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(strings_count), - d_find_counts, - findall_count_fn{*d_strings, *d_prog}); + auto find_counts = + count_matches(*d_strings, *d_prog, stream, rmm::mr::get_current_device_resource()); + auto d_find_counts = find_counts->mutable_view().data(); std::vector> results; size_type const columns = thrust::reduce( - rmm::exec_policy(stream), find_counts.begin(), find_counts.end(), 0, thrust::maximum{}); + rmm::exec_policy(stream), d_find_counts, d_find_counts + strings_count, 0, thrust::maximum{}); + // boundary case: if no columns, return all nulls column (issue #119) if (columns == 0) results.emplace_back(std::make_unique( @@ -166,39 +142,10 @@ std::unique_ptr
findall( strings_count)); for (int32_t column_index = 0; column_index < columns; ++column_index) { - rmm::device_uvector indices(strings_count, stream); - - if (regex_insts <= RX_SMALL_INSTS) - thrust::transform( - rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(strings_count), - indices.begin(), - findall_fn{*d_strings, *d_prog, column_index, d_find_counts}); - else if (regex_insts <= RX_MEDIUM_INSTS) - thrust::transform( - rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(strings_count), - indices.begin(), - findall_fn{*d_strings, *d_prog, column_index, d_find_counts}); - else if (regex_insts <= RX_LARGE_INSTS) - thrust::transform( - rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(strings_count), - indices.begin(), - findall_fn{*d_strings, *d_prog, column_index, d_find_counts}); - else - thrust::transform(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(strings_count), - indices.begin(), - findall_fn{*d_strings, *d_prog, column_index, d_find_counts}); - - // - results.emplace_back(make_strings_column(indices.begin(), indices.end(), stream, mr)); + results.emplace_back(regex_dispatcher( + *d_prog, findall_dispatch_fn{*d_prog}, *d_strings, column_index, d_find_counts, stream, mr)); } + return std::make_unique
(std::move(results)); } diff --git a/cpp/src/strings/search/findall_record.cu b/cpp/src/strings/search/findall_record.cu index c93eb0c17db..95e347a7c35 100644 --- a/cpp/src/strings/search/findall_record.cu +++ b/cpp/src/strings/search/findall_record.cu @@ -15,6 +15,9 @@ */ #include +#include +#include +#include #include #include @@ -26,9 +29,6 @@ #include #include -#include -#include - #include #include @@ -75,6 +75,27 @@ struct findall_fn { } }; +struct findall_dispatch_fn { + reprog_device d_prog; + + template + std::unique_ptr operator()(column_device_view const& d_strings, + size_type total_matches, + offset_type const* d_offsets, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) + { + rmm::device_uvector indices(total_matches, stream); + + thrust::for_each_n(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + d_strings.size(), + findall_fn{d_strings, d_prog, d_offsets, indices.data()}); + + return make_strings_column(indices.begin(), indices.end(), stream, mr); + } +}; + } // namespace // @@ -121,30 +142,11 @@ std::unique_ptr findall_record( rmm::exec_policy(stream), d_offsets, d_offsets + strings_count + 1, d_offsets); // Create indices vector with the total number of groups that will be extracted - auto total_matches = cudf::detail::get_value(offsets->view(), strings_count, stream); - - rmm::device_uvector indices(total_matches, stream); - auto d_indices = indices.data(); - auto begin = thrust::make_counting_iterator(0); - - // Build the string indices - auto const regex_insts = d_prog->insts_counts(); - if (regex_insts <= RX_SMALL_INSTS) { - findall_fn fn{*d_strings, *d_prog, d_offsets, d_indices}; - thrust::for_each_n(rmm::exec_policy(stream), begin, strings_count, fn); - } else if (regex_insts <= RX_MEDIUM_INSTS) { - findall_fn fn{*d_strings, *d_prog, d_offsets, d_indices}; - thrust::for_each_n(rmm::exec_policy(stream), begin, strings_count, fn); - } else if (regex_insts <= RX_LARGE_INSTS) { - findall_fn fn{*d_strings, *d_prog, d_offsets, d_indices}; - thrust::for_each_n(rmm::exec_policy(stream), begin, strings_count, fn); - } else { - findall_fn fn{*d_strings, *d_prog, d_offsets, d_indices}; - thrust::for_each_n(rmm::exec_policy(stream), begin, strings_count, fn); - } + auto const total_matches = + cudf::detail::get_value(offsets->view(), strings_count, stream); - // Build the child strings column from the resulting indices - auto strings_output = make_strings_column(indices.begin(), indices.end(), stream, mr); + auto strings_output = regex_dispatcher( + *d_prog, findall_dispatch_fn{*d_prog}, *d_strings, total_matches, d_offsets, stream, mr); // Build the lists column from the offsets and the strings return make_lists_column(strings_count, diff --git a/cpp/src/strings/split/split_re.cu b/cpp/src/strings/split/split_re.cu index d80148f2fe6..a8a2467dd76 100644 --- a/cpp/src/strings/split/split_re.cu +++ b/cpp/src/strings/split/split_re.cu @@ -15,6 +15,7 @@ */ #include +#include #include #include @@ -110,6 +111,28 @@ struct token_reader_fn { } }; +struct generate_dispatch_fn { + reprog_device d_prog; + + template + rmm::device_uvector operator()(column_device_view const& d_strings, + size_type total_tokens, + split_direction direction, + offset_type const* d_offsets, + rmm::cuda_stream_view stream) + { + rmm::device_uvector tokens(total_tokens, stream); + + thrust::for_each_n( + rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + d_strings.size(), + token_reader_fn{d_strings, d_prog, direction, d_offsets, tokens.data()}); + + return tokens; + } +}; + /** * @brief Call regex to split each input string into tokens. * @@ -148,24 +171,8 @@ rmm::device_uvector generate_tokens(column_device_view const& // the last offset entry is the total number of tokens to be generated auto const total_tokens = cudf::detail::get_value(offsets, strings_count, stream); - // generate tokens for each string - rmm::device_uvector tokens(total_tokens, stream); - auto const regex_insts = d_prog.insts_counts(); - if (regex_insts <= RX_SMALL_INSTS) { - token_reader_fn reader{d_strings, d_prog, direction, d_offsets, tokens.data()}; - thrust::for_each_n(rmm::exec_policy(stream), begin, strings_count, reader); - } else if (regex_insts <= RX_MEDIUM_INSTS) { - token_reader_fn reader{d_strings, d_prog, direction, d_offsets, tokens.data()}; - thrust::for_each_n(rmm::exec_policy(stream), begin, strings_count, reader); - } else if (regex_insts <= RX_LARGE_INSTS) { - token_reader_fn reader{d_strings, d_prog, direction, d_offsets, tokens.data()}; - thrust::for_each_n(rmm::exec_policy(stream), begin, strings_count, reader); - } else { - token_reader_fn reader{d_strings, d_prog, direction, d_offsets, tokens.data()}; - thrust::for_each_n(rmm::exec_policy(stream), begin, strings_count, reader); - } - - return tokens; + return regex_dispatcher( + d_prog, generate_dispatch_fn{d_prog}, d_strings, total_tokens, direction, d_offsets, stream); } /**