From de0f7e0f97ba33df8bb15f2220f071ec6b262f33 Mon Sep 17 00:00:00 2001 From: David Wendt <45795991+davidwendt@users.noreply.github.com> Date: Fri, 6 May 2022 07:59:04 -0400 Subject: [PATCH] Change stack-based regex state data to use global memory (#10600) All libcudf strings regex calls will use global device memory for state data when evaluating regex on strings. Previously, separate templated kernels were used to store state data in fixed-size stack memory depending on the number of instructions resolved from the provided regex pattern. This required the CUDA driver to allocate a large amount of device memory when launching the kernel. This memory is managed by the launcher in the driver and so is not under the control of RMM. This has been changed to use memory-resource-allocated global device memory to hold and manage the state data per string per instruction. This is an internal change only and results in no behavior changes. Overall, the performance based on the current benchmarks has not changed, though much more memory may be required to execute any of the regex APIs depending on the number of instructions in the pattern and the total number of strings in the column. Every effort has been made to not reduce performance from the stack-based approach. Additional optimizations here include copying the `reprog_device` class data to shared-memory (when it fits). Further optimizations are expected in later PRs as well. Overall, the compile time of the files that use regex is also faster since only a single kernel is generated instead of 4 in the templated, stack-based implementation. This PR is dependent on PR #10573. 
Authors: - David Wendt (https://github.com/davidwendt) Approvers: - Vyas Ramasubramani (https://github.com/vyasr) - Mike Wilson (https://github.com/hyperbolic2346) - Jake Hemstad (https://github.com/jrhemstad) URL: https://github.com/rapidsai/cudf/pull/10600 --- cpp/src/strings/contains.cu | 102 +++++++--------- cpp/src/strings/count_matches.cu | 69 ++++------- cpp/src/strings/count_matches.hpp | 3 +- cpp/src/strings/extract/extract.cu | 57 +++------ cpp/src/strings/extract/extract_all.cu | 57 +++------ cpp/src/strings/regex/dispatcher.hpp | 59 --------- cpp/src/strings/regex/regex.cuh | 128 +++++++++++++------- cpp/src/strings/regex/regex.inl | 116 ++++++++++-------- cpp/src/strings/regex/regexec.cu | 71 +++++++---- cpp/src/strings/regex/utilities.cuh | 148 +++++++++++++++++++++++ cpp/src/strings/replace/backref_re.cu | 73 ++++------- cpp/src/strings/replace/backref_re.cuh | 13 +- cpp/src/strings/replace/multi_re.cu | 80 ++++++------ cpp/src/strings/replace/replace_re.cu | 55 +++------ cpp/src/strings/search/findall.cu | 37 ++---- cpp/src/strings/search/findall_record.cu | 49 +++----- cpp/src/strings/split/split_re.cu | 56 +++------ 17 files changed, 578 insertions(+), 595 deletions(-) delete mode 100644 cpp/src/strings/regex/dispatcher.hpp create mode 100644 cpp/src/strings/regex/utilities.cuh diff --git a/cpp/src/strings/contains.cu b/cpp/src/strings/contains.cu index c4ffa7f0fb1..987cd076fd0 100644 --- a/cpp/src/strings/contains.cu +++ b/cpp/src/strings/contains.cu @@ -15,9 +15,7 @@ */ #include -#include -#include -#include +#include #include #include @@ -27,13 +25,8 @@ #include #include #include -#include #include -#include - -#include -#include namespace cudf { namespace strings { @@ -41,51 +34,52 @@ namespace detail { namespace { /** - * @brief This functor handles both contains_re and match_re to minimize the number - * of regex calls to find() to be inlined greatly reducing compile time. 
+ * @brief This functor handles both contains_re and match_re to regex-match a pattern + * to each string in a column. */ -template struct contains_fn { - reprog_device prog; column_device_view const d_strings; - bool const beginning_only; // do not make this a template parameter to keep compile times down + bool const beginning_only; - __device__ bool operator()(size_type idx) + __device__ bool operator()(size_type const idx, + reprog_device const prog, + int32_t const thread_idx) { if (d_strings.is_null(idx)) return false; auto const d_str = d_strings.element(idx); - int32_t begin = 0; - int32_t end = beginning_only ? 1 // match only the beginning of the string; - : -1; // match anywhere in the string - return static_cast(prog.find(idx, d_str, begin, end)); + + size_type begin = 0; + size_type end = beginning_only ? 1 // match only the beginning of the string; + : -1; // match anywhere in the string + return static_cast(prog.find(thread_idx, d_str, begin, end)); } }; -struct contains_dispatch_fn { - reprog_device d_prog; - bool const beginning_only; +std::unique_ptr contains_impl(strings_column_view const& input, + std::string const& pattern, + regex_flags const flags, + bool const beginning_only, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + auto results = make_numeric_column(data_type{type_id::BOOL8}, + input.size(), + cudf::detail::copy_bitmask(input.parent(), stream, mr), + input.null_count(), + stream, + mr); + if (input.is_empty()) { return results; } - template - std::unique_ptr operator()(strings_column_view const& input, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) - { - auto results = make_numeric_column(data_type{type_id::BOOL8}, - input.size(), - cudf::detail::copy_bitmask(input.parent(), stream, mr), - input.null_count(), - stream, - mr); - - auto const d_strings = column_device_view::create(input.parent(), stream); - thrust::transform(rmm::exec_policy(stream), - 
thrust::make_counting_iterator(0), - thrust::make_counting_iterator(input.size()), - results->mutable_view().data(), - contains_fn{d_prog, *d_strings, beginning_only}); - return results; - } -}; + auto d_prog = reprog_device::create(pattern, flags, stream); + + auto d_results = results->mutable_view().data(); + auto const d_strings = column_device_view::create(input.parent(), stream); + + launch_transform_kernel( + contains_fn{*d_strings, beginning_only}, *d_prog, d_results, input.size(), stream); + + return results; +} } // namespace @@ -96,10 +90,7 @@ std::unique_ptr contains_re( rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) { - auto d_prog = - reprog_device::create(pattern, flags, get_character_flags_table(), input.size(), stream); - - return regex_dispatcher(*d_prog, contains_dispatch_fn{*d_prog, false}, input, stream, mr); + return contains_impl(input, pattern, flags, false, stream, mr); } std::unique_ptr matches_re( @@ -109,21 +100,18 @@ std::unique_ptr matches_re( rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) { - auto d_prog = - reprog_device::create(pattern, flags, get_character_flags_table(), input.size(), stream); - - return regex_dispatcher(*d_prog, contains_dispatch_fn{*d_prog, true}, input, stream, mr); + return contains_impl(input, pattern, flags, true, stream, mr); } -std::unique_ptr count_re(strings_column_view const& input, - std::string const& pattern, - regex_flags const flags, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) +std::unique_ptr count_re( + strings_column_view const& input, + std::string const& pattern, + regex_flags const flags, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) { // compile regex into device object - auto d_prog = - reprog_device::create(pattern, flags, get_character_flags_table(), input.size(), stream); + auto 
d_prog = reprog_device::create(pattern, flags, stream); auto const d_strings = column_device_view::create(input.parent(), stream); diff --git a/cpp/src/strings/count_matches.cu b/cpp/src/strings/count_matches.cu index a850315dfec..d807482a3a7 100644 --- a/cpp/src/strings/count_matches.cu +++ b/cpp/src/strings/count_matches.cu @@ -15,41 +15,35 @@ */ #include -#include -#include +#include #include #include #include -#include - -#include -#include - namespace cudf { namespace strings { namespace detail { namespace { /** - * @brief Functor counts the total matches to the given regex in each string. + * @brief Kernel counts the total matches for the given regex in each string. */ -template -struct count_matches_fn { +struct count_fn { column_device_view const d_strings; - reprog_device prog; - __device__ size_type operator()(size_type idx) + __device__ int32_t operator()(size_type const idx, + reprog_device const prog, + int32_t const thread_idx) { - if (d_strings.is_null(idx)) { return 0; } - size_type count = 0; + if (d_strings.is_null(idx)) return 0; auto const d_str = d_strings.element(idx); auto const nchars = d_str.length(); + int32_t count = 0; - int32_t begin = 0; - int32_t end = nchars; - while ((begin < end) && (prog.find(idx, d_str, begin, end) > 0)) { + size_type begin = 0; + size_type end = nchars; + while ((begin < end) && (prog.find(thread_idx, d_str, begin, end) > 0)) { ++count; begin = end + (begin == end); end = nchars; @@ -58,41 +52,26 @@ struct count_matches_fn { } }; -struct count_dispatch_fn { - reprog_device d_prog; - - template - std::unique_ptr operator()(column_device_view const& d_strings, - size_type output_size, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) - { - assert(output_size >= d_strings.size() and "Unexpected output size"); - - auto results = make_numeric_column( - data_type{type_id::INT32}, output_size, mask_state::UNALLOCATED, stream, mr); - - thrust::transform(rmm::exec_policy(stream), - 
thrust::make_counting_iterator(0), - thrust::make_counting_iterator(d_strings.size()), - results->mutable_view().data(), - count_matches_fn{d_strings, d_prog}); - return results; - } -}; - } // namespace -/** - * @copydoc cudf::strings::detail::count_matches - */ std::unique_ptr count_matches(column_device_view const& d_strings, - reprog_device const& d_prog, + reprog_device& d_prog, size_type output_size, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { - return regex_dispatcher(d_prog, count_dispatch_fn{d_prog}, d_strings, output_size, stream, mr); + assert(output_size >= d_strings.size() and "Unexpected output size"); + + auto results = make_numeric_column( + data_type{type_id::INT32}, output_size, mask_state::UNALLOCATED, stream, mr); + + if (d_strings.size() == 0) return results; + + auto d_results = results->mutable_view().data(); + + launch_transform_kernel(count_fn{d_strings}, d_prog, d_results, d_strings.size(), stream); + + return results; } } // namespace detail diff --git a/cpp/src/strings/count_matches.hpp b/cpp/src/strings/count_matches.hpp index efff3958c65..d4bcdaf4042 100644 --- a/cpp/src/strings/count_matches.hpp +++ b/cpp/src/strings/count_matches.hpp @@ -39,10 +39,11 @@ class reprog_device; * @param output_size Number of rows for the output column. * @param stream CUDA stream used for device memory operations and kernel launches. * @param mr Device memory resource used to allocate the returned column's device memory. 
+ * @return Integer column of match counts */ std::unique_ptr count_matches( column_device_view const& d_strings, - reprog_device const& d_prog, + reprog_device& d_prog, size_type output_size, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); diff --git a/cpp/src/strings/extract/extract.cu b/cpp/src/strings/extract/extract.cu index 9e987cf5879..59b90952d97 100644 --- a/cpp/src/strings/extract/extract.cu +++ b/cpp/src/strings/extract/extract.cu @@ -14,9 +14,7 @@ * limitations under the License. */ -#include -#include -#include +#include #include #include @@ -31,7 +29,7 @@ #include #include -#include +#include #include #include #include @@ -47,28 +45,26 @@ using string_index_pair = thrust::pair; /** * @brief This functor handles extracting strings by applying the compiled regex pattern * and creating string_index_pairs for all the substrings. - * - * @tparam stack_size Correlates to the regex instructions state to maintain for each string. - * Each instruction requires a fixed amount of overhead data. 
*/ -template struct extract_fn { - reprog_device prog; column_device_view const d_strings; cudf::detail::device_2dspan d_indices; - __device__ void operator()(size_type idx) + __device__ void operator()(size_type const idx, + reprog_device const d_prog, + int32_t const prog_idx) { - auto const groups = prog.group_counts(); + auto const groups = d_prog.group_counts(); auto d_output = d_indices[idx]; if (d_strings.is_valid(idx)) { auto const d_str = d_strings.element(idx); - int32_t begin = 0; - int32_t end = -1; // handles empty strings automatically - if (prog.find(idx, d_str, begin, end) > 0) { + + size_type begin = 0; + size_type end = -1; // handles empty strings automatically + if (d_prog.find(prog_idx, d_str, begin, end) > 0) { for (auto col_idx = 0; col_idx < groups; ++col_idx) { - auto const extracted = prog.extract(idx, d_str, begin, end, col_idx); + auto const extracted = d_prog.extract(prog_idx, d_str, begin, end, col_idx); d_output[col_idx] = [&] { if (!extracted) return string_index_pair{nullptr, 0}; auto const offset = d_str.byte_offset((*extracted).first); @@ -85,33 +81,17 @@ struct extract_fn { } }; -struct extract_dispatch_fn { - reprog_device d_prog; - - template - void operator()(column_device_view const& d_strings, - cudf::detail::device_2dspan& d_indices, - rmm::cuda_stream_view stream) - { - thrust::for_each_n(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - d_strings.size(), - extract_fn{d_prog, d_strings, d_indices}); - } -}; } // namespace // -std::unique_ptr extract( - strings_column_view const& input, - std::string const& pattern, - regex_flags const flags, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) +std::unique_ptr
extract(strings_column_view const& input, + std::string const& pattern, + regex_flags const flags, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) { // compile regex into device object - auto d_prog = - reprog_device::create(pattern, flags, get_character_flags_table(), input.size(), stream); + auto d_prog = reprog_device::create(pattern, flags, stream); auto const groups = d_prog->group_counts(); CUDF_EXPECTS(groups > 0, "Group indicators not found in regex pattern"); @@ -121,7 +101,8 @@ std::unique_ptr
extract( cudf::detail::device_2dspan(indices.data(), input.size(), groups); auto const d_strings = column_device_view::create(input.parent(), stream); - regex_dispatcher(*d_prog, extract_dispatch_fn{*d_prog}, *d_strings, d_indices, stream); + + launch_for_each_kernel(extract_fn{*d_strings, d_indices}, *d_prog, input.size(), stream); // build a result column for each group std::vector> results(groups); diff --git a/cpp/src/strings/extract/extract_all.cu b/cpp/src/strings/extract/extract_all.cu index 7dce369a24f..95b8a43a9d4 100644 --- a/cpp/src/strings/extract/extract_all.cu +++ b/cpp/src/strings/extract/extract_all.cu @@ -15,9 +15,7 @@ */ #include -#include -#include -#include +#include #include #include @@ -30,9 +28,7 @@ #include #include -#include #include -#include #include namespace cudf { @@ -49,14 +45,14 @@ namespace { * The `d_offsets` are pre-computed to identify the location of where each * string's output groups are to be written. */ -template struct extract_fn { column_device_view const d_strings; - reprog_device d_prog; offset_type const* d_offsets; string_index_pair* d_indices; - __device__ void operator()(size_type idx) + __device__ void operator()(size_type const idx, + reprog_device const d_prog, + int32_t const prog_idx) { if (d_strings.is_null(idx)) { return; } @@ -64,16 +60,17 @@ struct extract_fn { auto d_output = d_indices + d_offsets[idx]; size_type output_idx = 0; - auto const d_str = d_strings.element(idx); + auto const d_str = d_strings.element(idx); + auto const nchars = d_str.length(); - int32_t begin = 0; - int32_t end = d_str.length(); + size_type begin = 0; + size_type end = nchars; // match the regex - while ((begin < end) && d_prog.find(idx, d_str, begin, end) > 0) { + while ((begin < end) && d_prog.find(prog_idx, d_str, begin, end) > 0) { // extract each group into the output for (auto group_idx = 0; group_idx < groups; ++group_idx) { // result is an optional containing the bounds of the extracted string at group_idx - auto const 
extracted = d_prog.extract(idx, d_str, begin, end, group_idx); + auto const extracted = d_prog.extract(prog_idx, d_str, begin, end, group_idx); d_output[group_idx + output_idx] = [&] { if (!extracted) { return string_index_pair{nullptr, 0}; } @@ -84,33 +81,12 @@ struct extract_fn { } // continue to next match begin = end; - end = d_str.length(); + end = nchars; output_idx += groups; } } }; -struct extract_dispatch_fn { - reprog_device d_prog; - - template - std::unique_ptr operator()(column_device_view const& d_strings, - size_type total_groups, - offset_type const* d_offsets, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) - { - rmm::device_uvector indices(total_groups, stream); - - thrust::for_each_n(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - d_strings.size(), - extract_fn{d_strings, d_prog, d_offsets, indices.data()}); - - return make_strings_column(indices.begin(), indices.end(), stream, mr); - } -}; - } // namespace /** @@ -129,8 +105,7 @@ std::unique_ptr extract_all_record( auto const d_strings = column_device_view::create(input.parent(), stream); // Compile regex into device object. - auto d_prog = - reprog_device::create(pattern, flags, get_character_flags_table(), strings_count, stream); + auto d_prog = reprog_device::create(pattern, flags, stream); // The extract pattern should always include groups. 
auto const groups = d_prog->group_counts(); CUDF_EXPECTS(groups > 0, "extract_all requires group indicators in the regex pattern."); @@ -168,8 +143,12 @@ std::unique_ptr extract_all_record( auto const total_groups = cudf::detail::get_value(offsets->view(), strings_count, stream); - auto strings_output = regex_dispatcher( - *d_prog, extract_dispatch_fn{*d_prog}, *d_strings, total_groups, d_offsets, stream, mr); + rmm::device_uvector indices(total_groups, stream); + + launch_for_each_kernel( + extract_fn{*d_strings, d_offsets, indices.data()}, *d_prog, strings_count, stream); + + auto strings_output = make_strings_column(indices.begin(), indices.end(), stream, mr); // Build the lists column from the offsets and the strings. return make_lists_column(strings_count, diff --git a/cpp/src/strings/regex/dispatcher.hpp b/cpp/src/strings/regex/dispatcher.hpp deleted file mode 100644 index 9ff51d1c979..00000000000 --- a/cpp/src/strings/regex/dispatcher.hpp +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Copyright (c) 2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace cudf { -namespace strings { -namespace detail { - -/** - * The stack is used to keep progress (state) on evaluating the regex instructions on each string. - * So the size of the stack is in proportion to the number of instructions in the given regex - * pattern. 
- * - * There are four call types based on the number of regex instructions in the given pattern. - * Small, medium, and large instruction counts can use the stack effectively. - * Smaller stack sizes execute faster. - * - * Patterns with instruction counts bigger than large use global memory rather than the stack - * for managing the evaluation state data. - * - * @tparam Functor The functor to invoke with stack size templated value. - * @tparam Ts Parameter types for the functor call. - */ -template -constexpr decltype(auto) regex_dispatcher(reprog_device d_prog, Functor f, Ts&&... args) -{ - auto const num_regex_insts = d_prog.insts_counts(); - if (num_regex_insts <= RX_SMALL_INSTS) { - return f.template operator()(std::forward(args)...); - } - if (num_regex_insts <= RX_MEDIUM_INSTS) { - return f.template operator()(std::forward(args)...); - } - if (num_regex_insts <= RX_LARGE_INSTS) { - return f.template operator()(std::forward(args)...); - } - - return f.template operator()(std::forward(args)...); -} - -} // namespace detail -} // namespace strings -} // namespace cudf diff --git a/cpp/src/strings/regex/regex.cuh b/cpp/src/strings/regex/regex.cuh index bcdd15bceda..5ccc70222d5 100644 --- a/cpp/src/strings/regex/regex.cuh +++ b/cpp/src/strings/regex/regex.cuh @@ -39,23 +39,9 @@ struct relist; using match_pair = thrust::pair; using match_result = thrust::optional; -constexpr int32_t RX_STACK_SMALL = 112; ///< fastest stack size -constexpr int32_t RX_STACK_MEDIUM = 1104; ///< faster stack size -constexpr int32_t RX_STACK_LARGE = 2560; ///< fast stack size -constexpr int32_t RX_STACK_ANY = 8; ///< slowest: uses global memory - -/** - * @brief Mapping the number of instructions to device code stack memory size. 
- * - * ``` - * 10128 ≈ 1000 instructions - * Formula is based on relist::data_size_for() calculation; - * Stack ≈ (8+2)*x + (x/8) = 10.125x < 11x where x is number of instructions - * ``` - */ -constexpr int32_t RX_SMALL_INSTS = (RX_STACK_SMALL / 11); -constexpr int32_t RX_MEDIUM_INSTS = (RX_STACK_MEDIUM / 11); -constexpr int32_t RX_LARGE_INSTS = (RX_STACK_LARGE / 11); +constexpr int32_t MAX_SHARED_MEM = 2048; ///< Memory size for storing prog instruction data +constexpr std::size_t MAX_WORKING_MEM = 0x01FFFFFFFF; ///< Memory size for state data +constexpr int32_t MINIMUM_THREADS = 256; // Minimum threads for computing working memory /** * @brief Regex class stored on the device and executed by reprog_device. @@ -75,6 +61,12 @@ struct alignas(16) reclass_device { * * Once created, the find/extract methods are used to evaluate the regex instructions * against a single string. + * + * An instance of the class requires working memory for evaluating the regex + * instructions for the string. Determine the size of the required memory by + * calling either `working_memory_size()` or `compute_strided_working_memory()`. + * Once the buffer is allocated, pass the device pointer to the `set_working_memory()` + * member function. */ class reprog_device { public: @@ -92,33 +84,22 @@ class reprog_device { * regex. * * @param pattern The regex pattern to compile. - * @param codepoint_flags The code point lookup table for character types. - * @param strings_count Number of strings that will be evaluated. * @param stream CUDA stream used for device memory operations and kernel launches. * @return The program device object. */ static std::unique_ptr> create( - std::string const& pattern, - uint8_t const* codepoint_flags, - size_type strings_count, - rmm::cuda_stream_view stream); + std::string const& pattern, rmm::cuda_stream_view stream); /** * @brief Create the device program instance from a regex pattern. * * @param pattern The regex pattern to compile. 
* @param re_flags Regex flags for interpreting special characters in the pattern. - * @param codepoint_flags The code point lookup table for character types. - * @param strings_count Number of strings that will be evaluated. * @param stream CUDA stream used for device memory operations and kernel launches * @return The program device object. */ static std::unique_ptr> create( - std::string const& pattern, - regex_flags const re_flags, - uint8_t const* codepoint_flags, - size_type strings_count, - rmm::cuda_stream_view stream); + std::string const& pattern, regex_flags const re_flags, rmm::cuda_stream_view stream); /** * @brief Called automatically by the unique_ptr returned from create(). @@ -143,12 +124,75 @@ class reprog_device { */ [[nodiscard]] __device__ inline bool is_empty() const; + /** + * @brief Returns the size needed for working memory for the given thread count. + * + * @param num_threads Number of threads to be executed in parallel + * @return Size of working memory in bytes + */ + [[nodiscard]] std::size_t working_memory_size(int32_t num_threads) const; + + /** + * @brief Compute working memory for the given thread count with a maximum size. + * + * The `min_rows` overrules the `requested_max_size`. + * That is, the `requested_max_size` may be + * exceeded to keep the number of rows greater than `min_rows`. + * Also, if `rows < min_rows` then `min_rows` is not enforced. + * + * @param rows Number of rows to execute in parallel + * @param min_rows The least number of rows to meet `max_size` + * @param requested_max_size Requested maximum bytes for the working memory + * @return The size of the working memory and the number of parallel rows it will support + */ + [[nodiscard]] std::pair compute_strided_working_memory( + int32_t rows, + int32_t min_rows = MINIMUM_THREADS, + std::size_t requested_max_size = MAX_WORKING_MEM) const; + + /** + * @brief Set the device working memory buffer to use for the regex execution. 
+ * + * @param buffer Device memory pointer. + * @param thread_count Number of threads the memory buffer will support. + * @param max_insts Set to the maximum instruction count if reusing the + * memory buffer for other regex calls. + */ + void set_working_memory(void* buffer, int32_t thread_count, int32_t max_insts = 0); + + /** + * @brief Returns the size of shared memory required to hold this instance. + * + * This can be called on the CPU for specifying the shared-memory size in the + * kernel launch parameters. + * This may return 0 if the MAX_SHARED_MEM value is exceeded. + */ + [[nodiscard]] int32_t compute_shared_memory_size() const; + + /** + * @brief Returns the thread count passed on `set_working_memory`. + */ + [[nodiscard]] __device__ inline int32_t thread_count() const { return _thread_count; } + + /** + * @brief Store this object into the given device pointer (e.g. shared memory). + * + * No data is stored if MAX_SHARED_MEM is exceeded for this object. + */ + __device__ inline void store(void* buffer) const; + + /** + * @brief Load an instance of this class from a device buffer (e.g. shared memory). + * + * Data is loaded from the given buffer if MAX_SHARED_MEM is not exceeded for the given object. + * Otherwise, a copy of the object is returned. + */ + [[nodiscard]] __device__ static inline reprog_device load(reprog_device const prog, void* buffer); + /** * @brief Does a find evaluation using the compiled expression on the given string. * - * @tparam stack_size One of the `RX_STACK_` values based on the `insts_count`. - * @param idx The string index used for mapping the state memory for this string in global memory - * (if necessary). + * @param thread_idx The index used for mapping the state memory for this string in global memory. * @param d_str The string to search. * @param[in,out] begin Position index to begin the search. If found, returns the position found * in the string. @@ -156,8 +200,7 @@ class reprog_device { * matching in the string. 
* @return Returns 0 if no match is found. */ - template - __device__ inline int32_t find(int32_t idx, + __device__ inline int32_t find(int32_t const thread_idx, string_view const d_str, cudf::size_type& begin, cudf::size_type& end) const; @@ -169,9 +212,7 @@ class reprog_device { * The find() function should be called first to locate the begin/end bounds of the * the matched section. * - * @tparam stack_size One of the `RX_STACK_` values based on the `insts_count`. - * @param idx The string index used for mapping the state memory for this string in global - * memory (if necessary). + * @param thread_idx The index used for mapping the state memory for this string in global memory. * @param d_str The string to search. * @param begin Position index to begin the search. If found, returns the position found * in the string. @@ -180,8 +221,7 @@ class reprog_device { * @param group_id The specific group to return its matching position values. * @return If valid, returns the character position of the matched group in the given string, */ - template - __device__ inline match_result extract(cudf::size_type idx, + __device__ inline match_result extract(int32_t const thread_idx, string_view const d_str, cudf::size_type begin, cudf::size_type end, @@ -220,8 +260,7 @@ class reprog_device { /** * @brief Utility wrapper to setup state memory structures for calling regexec */ - template - __device__ inline int32_t call_regexec(int32_t idx, + __device__ inline int32_t call_regexec(int32_t const thread_idx, string_view const d_str, cudf::size_type& begin, cudf::size_type& end, @@ -234,13 +273,16 @@ class reprog_device { int32_t _insts_count; // number of instructions int32_t _starts_count; // number of start-insts ids int32_t _classes_count; // number of classes + int32_t _max_insts; // for partitioning working memory uint8_t const* _codepoint_flags{}; // table of character types reinst const* _insts{}; // array of regex instructions int32_t const* _startinst_ids{}; // array of start 
instruction ids reclass_device const* _classes{}; // array of regex classes - void* _relists_mem{}; // runtime relist memory for regexec() + std::size_t _prog_size{}; // total size of this instance + void* _buffer{}; // working memory buffer + int32_t _thread_count{}; // threads available in working memory }; } // namespace detail diff --git a/cpp/src/strings/regex/regex.inl b/cpp/src/strings/regex/regex.inl index 9fe4440d7ec..bae6fb275f6 100644 --- a/cpp/src/strings/regex/regex.inl +++ b/cpp/src/strings/regex/regex.inl @@ -45,10 +45,9 @@ struct alignas(8) relist { /** * @brief Compute the aligned memory allocation size. */ - constexpr inline static std::size_t alloc_size(int32_t insts) + constexpr inline static std::size_t alloc_size(int32_t insts, int32_t num_threads) { - return cudf::util::round_up_unsafe(data_size_for(insts) + sizeof(relist), - sizeof(ranges[0])); + return cudf::util::round_up_unsafe(data_size_for(insts) * num_threads, sizeof(restate)); } struct alignas(16) restate { @@ -57,16 +56,16 @@ struct alignas(8) relist { int32_t reserved; }; - __device__ __forceinline__ relist(int16_t insts, u_char* data = nullptr) - : masksize(cudf::util::div_rounding_up_unsafe(insts, 8)) + __device__ __forceinline__ + relist(int16_t insts, int32_t num_threads, u_char* gp_ptr, int32_t index) + : masksize(cudf::util::div_rounding_up_unsafe(insts, 8)), stride(num_threads) { - auto ptr = data == nullptr ? 
reinterpret_cast(this) + sizeof(relist) : data; - ranges = reinterpret_cast(ptr); - ptr += insts * sizeof(ranges[0]); - inst_ids = reinterpret_cast(ptr); - ptr += insts * sizeof(inst_ids[0]); - mask = ptr; - reset(); + auto const rdata_size = sizeof(ranges[0]); + auto const idata_size = sizeof(inst_ids[0]); + ranges = reinterpret_cast(gp_ptr + (index * rdata_size)); + inst_ids = + reinterpret_cast(gp_ptr + (rdata_size * stride * insts) + (index * idata_size)); + mask = gp_ptr + ((rdata_size + idata_size) * stride * insts) + (index * masksize); } __device__ __forceinline__ void reset() @@ -79,15 +78,15 @@ struct alignas(8) relist { { if (readMask(id)) { return false; } writeMask(id); - inst_ids[size] = static_cast(id); - ranges[size] = int2{begin, end}; + inst_ids[size * stride] = static_cast(id); + ranges[size * stride] = int2{begin, end}; ++size; return true; } __device__ __forceinline__ restate get_state(int16_t idx) const { - return restate{ranges[idx], inst_ids[idx]}; + return restate{ranges[idx * stride], inst_ids[idx * stride]}; } __device__ __forceinline__ int16_t get_size() const { return size; } @@ -95,7 +94,7 @@ struct alignas(8) relist { private: int16_t size{}; int16_t const masksize; - int32_t reserved; + int32_t const stride; int2* __restrict__ ranges; // pair per instruction int16_t* __restrict__ inst_ids; // one per instruction u_char* __restrict__ mask; // bit per instruction @@ -177,6 +176,49 @@ __device__ __forceinline__ bool reprog_device::is_empty() const return insts_counts() == 0 || get_inst(0).type == END; } +__device__ __forceinline__ void reprog_device::store(void* buffer) const +{ + if (_prog_size > MAX_SHARED_MEM) { return; } + + auto ptr = static_cast(buffer); + + // create instance inside the given buffer + auto result = new (ptr) reprog_device(*this); + + // add the insts array + ptr += sizeof(reprog_device); + auto insts = reinterpret_cast(ptr); + result->_insts = insts; + for (int idx = 0; idx < _insts_count; ++idx) + *insts++ = 
_insts[idx]; + + // add the startinst_ids array + ptr += cudf::util::round_up_unsafe(_insts_count * sizeof(_insts[0]), sizeof(_startinst_ids[0])); + auto ids = reinterpret_cast(ptr); + result->_startinst_ids = ids; + for (int idx = 0; idx < _starts_count; ++idx) + *ids++ = _startinst_ids[idx]; + + // add the classes array + ptr += cudf::util::round_up_unsafe(_starts_count * sizeof(int32_t), sizeof(_classes[0])); + auto classes = reinterpret_cast(ptr); + result->_classes = classes; + // fill in each class + auto d_ptr = reinterpret_cast(classes + _classes_count); + for (int idx = 0; idx < _classes_count; ++idx) { + classes[idx] = _classes[idx]; + classes[idx].literals = d_ptr; + for (int jdx = 0; jdx < _classes[idx].count * 2; ++jdx) + *d_ptr++ = _classes[idx].literals[jdx]; + } +} + +__device__ __forceinline__ reprog_device reprog_device::load(reprog_device const prog, void* buffer) +{ + return (prog._prog_size > MAX_SHARED_MEM) ? reprog_device(prog) + : reinterpret_cast(buffer)[0]; +} + /** * @brief Evaluate a specific string against regex pattern compiled to this instance. * @@ -352,65 +394,43 @@ __device__ __forceinline__ int32_t reprog_device::regexec(string_view const dstr return match; } -template -__device__ __forceinline__ int32_t reprog_device::find(int32_t idx, +__device__ __forceinline__ int32_t reprog_device::find(int32_t const thread_idx, string_view const dstr, cudf::size_type& begin, cudf::size_type& end) const { - int32_t rtn = call_regexec(idx, dstr, begin, end); + auto const rtn = call_regexec(thread_idx, dstr, begin, end); if (rtn <= 0) begin = end = -1; return rtn; } -template -__device__ __forceinline__ match_result reprog_device::extract(cudf::size_type idx, +__device__ __forceinline__ match_result reprog_device::extract(int32_t const thread_idx, string_view const dstr, cudf::size_type begin, cudf::size_type end, cudf::size_type const group_id) const { end = begin + 1; - return call_regexec(idx, dstr, begin, end, group_id + 1) > 0 - ? 
match_result({begin, end}) - : thrust::nullopt; + return call_regexec(thread_idx, dstr, begin, end, group_id + 1) > 0 ? match_result({begin, end}) + : thrust::nullopt; } -template -__device__ __forceinline__ int32_t reprog_device::call_regexec(int32_t idx, +__device__ __forceinline__ int32_t reprog_device::call_regexec(int32_t const thread_idx, string_view const dstr, cudf::size_type& begin, cudf::size_type& end, cudf::size_type const group_id) const { - u_char data1[stack_size], data2[stack_size]; + auto gp_ptr = reinterpret_cast(_buffer); + relist list1(static_cast(_max_insts), _thread_count, gp_ptr, thread_idx); - relist list1(static_cast(_insts_count), data1); - relist list2(static_cast(_insts_count), data2); + gp_ptr += relist::alloc_size(_max_insts, _thread_count); + relist list2(static_cast(_max_insts), _thread_count, gp_ptr, thread_idx); reljunk jnk(&list1, &list2, get_inst(_startinst_id)); return regexec(dstr, jnk, begin, end, group_id); } -template <> -__device__ __forceinline__ int32_t -reprog_device::call_regexec(int32_t idx, - string_view const dstr, - cudf::size_type& begin, - cudf::size_type& end, - cudf::size_type const group_id) const -{ - auto const relists_size = relist::alloc_size(_insts_count); - auto* listmem = reinterpret_cast(_relists_mem); // beginning of relist buffer; - listmem += (idx * relists_size * 2); // two relist ptrs in reljunk: - - auto* list1 = new (listmem) relist(static_cast(_insts_count)); - auto* list2 = new (listmem + relists_size) relist(static_cast(_insts_count)); - - reljunk jnk(list1, list2, get_inst(_startinst_id)); - return regexec(dstr, jnk, begin, end, group_id); -} - } // namespace detail } // namespace strings } // namespace cudf diff --git a/cpp/src/strings/regex/regexec.cu b/cpp/src/strings/regex/regexec.cu index 70d6079972a..4b58d9d8a88 100644 --- a/cpp/src/strings/regex/regexec.cu +++ b/cpp/src/strings/regex/regexec.cu @@ -16,6 +16,7 @@ #include #include +#include #include #include @@ -35,27 +36,21 @@ 
reprog_device::reprog_device(reprog& prog) _num_capturing_groups{prog.groups_count()}, _insts_count{prog.insts_count()}, _starts_count{prog.starts_count()}, - _classes_count{prog.classes_count()} + _classes_count{prog.classes_count()}, + _max_insts{prog.insts_count()}, + _codepoint_flags{get_character_flags_table()} { } std::unique_ptr> reprog_device::create( - std::string const& pattern, - uint8_t const* codepoint_flags, - size_type strings_count, - rmm::cuda_stream_view stream) + std::string const& pattern, rmm::cuda_stream_view stream) { - return reprog_device::create( - pattern, regex_flags::MULTILINE, codepoint_flags, strings_count, stream); + return reprog_device::create(pattern, regex_flags::MULTILINE, stream); } // Create instance of the reprog that can be passed into a device kernel std::unique_ptr> reprog_device::create( - std::string const& pattern, - regex_flags const flags, - uint8_t const* codepoint_flags, - size_type strings_count, - rmm::cuda_stream_view stream) + std::string const& pattern, regex_flags const flags, rmm::cuda_stream_view stream) { // compile pattern into host object reprog h_prog = reprog::create_from(pattern, flags); @@ -82,7 +77,7 @@ std::unique_ptr> reprog_devic auto d_buffer = new rmm::device_buffer(memsize, stream); // output device memory; auto d_ptr = reinterpret_cast(d_buffer->data()); // running device pointer - // put everything into a flat host buffer first + // create our device object; this is managed separately and returned to the caller reprog_device* d_prog = new reprog_device(h_prog); // copy the instructions array first (fixed-sized structs) @@ -120,32 +115,58 @@ std::unique_ptr> reprog_devic } // initialize the rest of the elements - d_prog->_codepoint_flags = codepoint_flags; - - // allocate execute memory if needed - rmm::device_buffer* d_relists{}; - if (insts_count > RX_LARGE_INSTS) { - // two relist state structures are needed for execute per string - auto const rlm_size = relist::alloc_size(insts_count) * 2 
* strings_count; - d_relists = new rmm::device_buffer(rlm_size, stream); - d_prog->_relists_mem = d_relists->data(); - } + d_prog->_max_insts = insts_count; + d_prog->_prog_size = memsize + sizeof(reprog_device); // copy flat prog to device memory CUDF_CUDA_TRY(cudaMemcpyAsync( d_buffer->data(), h_buffer.data(), memsize, cudaMemcpyHostToDevice, stream.value())); // build deleter to cleanup device memory - auto deleter = [d_buffer, d_relists](reprog_device* t) { + auto deleter = [d_buffer](reprog_device* t) { t->destroy(); delete d_buffer; - delete d_relists; }; + return std::unique_ptr>(d_prog, deleter); } void reprog_device::destroy() { delete this; } +std::size_t reprog_device::working_memory_size(int32_t num_threads) const +{ + return relist::alloc_size(_insts_count, num_threads) * 2; +} + +std::pair reprog_device::compute_strided_working_memory( + int32_t rows, int32_t min_rows, std::size_t requested_max_size) const +{ + auto thread_count = rows; + auto buffer_size = working_memory_size(thread_count); + while ((buffer_size > requested_max_size) && (thread_count > min_rows)) { + thread_count = thread_count / 2; + buffer_size = working_memory_size(thread_count); + } + // clamp to min_rows but only if rows is greater than min_rows + if (rows > min_rows && thread_count < min_rows) { + thread_count = min_rows; + buffer_size = working_memory_size(thread_count); + } + return std::make_pair(buffer_size, thread_count); +} + +void reprog_device::set_working_memory(void* buffer, int32_t thread_count, int32_t max_insts) +{ + _buffer = buffer; + _thread_count = thread_count; + _max_insts = max_insts > 0 ? max_insts : _insts_count; // 0 => use this prog's own count +} + +int32_t reprog_device::compute_shared_memory_size() const +{ + return _prog_size <= MAX_SHARED_MEM ? 
static_cast(_prog_size) : 0; +} + } // namespace detail } // namespace strings } // namespace cudf diff --git a/cpp/src/strings/regex/utilities.cuh b/cpp/src/strings/regex/utilities.cuh new file mode 100644 index 00000000000..9a80be25b3b --- /dev/null +++ b/cpp/src/strings/regex/utilities.cuh @@ -0,0 +1,148 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include +#include +#include +#include + +#include +#include + +#include + +namespace cudf { +namespace strings { +namespace detail { + +constexpr auto regex_launch_kernel_block_size = 256; + +template +__global__ void for_each_kernel(ForEachFunction fn, reprog_device const d_prog, size_type size) +{ + extern __shared__ u_char shmem[]; + if (threadIdx.x == 0) { d_prog.store(shmem); } + __syncthreads(); + auto const s_prog = reprog_device::load(d_prog, shmem); + + auto const thread_idx = threadIdx.x + blockIdx.x * blockDim.x; + auto const stride = s_prog.thread_count(); + for (auto idx = thread_idx; (thread_idx < stride) && (idx < size); idx += stride) { + fn(idx, s_prog, thread_idx); + } +} + +template +void launch_for_each_kernel(ForEachFunction fn, + reprog_device& d_prog, + size_type size, + rmm::cuda_stream_view stream) +{ + auto [buffer_size, thread_count] = d_prog.compute_strided_working_memory(size); + + auto d_buffer = rmm::device_buffer(buffer_size, stream); + d_prog.set_working_memory(d_buffer.data(), thread_count); + + auto const 
shmem_size = d_prog.compute_shared_memory_size(); + cudf::detail::grid_1d grid{thread_count, regex_launch_kernel_block_size}; + for_each_kernel<<>>( + fn, d_prog, size); +} + +template +__global__ void transform_kernel(TransformFunction fn, + reprog_device const d_prog, + OutputType* d_output, + size_type size) +{ + extern __shared__ u_char shmem[]; + if (threadIdx.x == 0) { d_prog.store(shmem); } + __syncthreads(); + auto const s_prog = reprog_device::load(d_prog, shmem); + + auto const thread_idx = threadIdx.x + blockIdx.x * blockDim.x; + auto const stride = s_prog.thread_count(); + for (auto idx = thread_idx; (thread_idx < stride) && (idx < size); idx += stride) { + d_output[idx] = fn(idx, s_prog, thread_idx); + } +} + +template +void launch_transform_kernel(TransformFunction fn, + reprog_device& d_prog, + OutputType* d_output, + size_type size, + rmm::cuda_stream_view stream) +{ + auto [buffer_size, thread_count] = d_prog.compute_strided_working_memory(size); + + auto d_buffer = rmm::device_buffer(buffer_size, stream); + d_prog.set_working_memory(d_buffer.data(), thread_count); + + auto const shmem_size = d_prog.compute_shared_memory_size(); + cudf::detail::grid_1d grid{thread_count, regex_launch_kernel_block_size}; + transform_kernel<<>>( + fn, d_prog, d_output, size); +} + +template +auto make_strings_children(SizeAndExecuteFunction size_and_exec_fn, + reprog_device& d_prog, + size_type strings_count, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + auto offsets = make_numeric_column( + data_type{type_id::INT32}, strings_count + 1, mask_state::UNALLOCATED, stream, mr); + auto d_offsets = offsets->mutable_view().template data(); + size_and_exec_fn.d_offsets = d_offsets; + + auto [buffer_size, thread_count] = d_prog.compute_strided_working_memory(strings_count); + + auto d_buffer = rmm::device_buffer(buffer_size, stream); + d_prog.set_working_memory(d_buffer.data(), thread_count); + auto const shmem_size = 
d_prog.compute_shared_memory_size(); + 
cudf::detail::grid_1d grid{thread_count, 256}; + + // Compute the output size for each row + if (strings_count > 0) { + for_each_kernel<<>>( + size_and_exec_fn, d_prog, strings_count); + } + + // Convert sizes to offsets + thrust::exclusive_scan( + rmm::exec_policy(stream), d_offsets, d_offsets + strings_count + 1, d_offsets); + + // Now build the chars column + auto const char_bytes = cudf::detail::get_value(offsets->view(), strings_count, stream); + std::unique_ptr chars = create_chars_child_column(char_bytes, stream, mr); + if (char_bytes > 0) { + size_and_exec_fn.d_chars = chars->mutable_view().template data(); + for_each_kernel<<>>( + size_and_exec_fn, d_prog, strings_count); + } + + return std::make_pair(std::move(offsets), std::move(chars)); +} + +} // namespace detail +} // namespace strings +} // namespace cudf diff --git a/cpp/src/strings/replace/backref_re.cu b/cpp/src/strings/replace/backref_re.cu index 384813d6e3d..107adf07263 100644 --- a/cpp/src/strings/replace/backref_re.cu +++ b/cpp/src/strings/replace/backref_re.cu @@ -16,9 +16,7 @@ #include "backref_re.cuh" -#include -#include -#include +#include #include #include @@ -43,7 +41,7 @@ namespace { * @brief Return the capturing group index pattern to use with the given replacement string. * * Only two patterns are supported at this time `\d` and `${d}` where `d` is an integer in - * the range 1-99. The `\d` pattern is returned by default unless no `\d` pattern is found in + * the range 0-99. 
The `\d` pattern is returned by default unless no `\d` pattern is found in * the `repl` string, * * Reference: https://www.regular-expressions.info/refreplacebackref.html @@ -98,45 +96,15 @@ std::pair> parse_backrefs(std::string con return {rtn, backrefs}; } -template -struct replace_dispatch_fn { - reprog_device d_prog; - - template - std::unique_ptr operator()(strings_column_view const& input, - string_view const& d_repl_template, - Iterator backrefs_begin, - Iterator backrefs_end, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) - { - auto const d_strings = column_device_view::create(input.parent(), stream); - - auto children = make_strings_children( - backrefs_fn{ - *d_strings, d_prog, d_repl_template, backrefs_begin, backrefs_end}, - input.size(), - stream, - mr); - - return make_strings_column(input.size(), - std::move(children.first), - std::move(children.second), - input.null_count(), - cudf::detail::copy_bitmask(input.parent(), stream, mr)); - } -}; - } // namespace // -std::unique_ptr replace_with_backrefs( - strings_column_view const& input, - std::string const& pattern, - std::string const& replacement, - regex_flags const flags, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) +std::unique_ptr replace_with_backrefs(strings_column_view const& input, + std::string const& pattern, + std::string const& replacement, + regex_flags const flags, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) { if (input.is_empty()) return make_empty_column(type_id::STRING); @@ -144,8 +112,7 @@ std::unique_ptr replace_with_backrefs( CUDF_EXPECTS(!replacement.empty(), "Parameter replacement must not be empty"); // compile regex into device object - auto d_prog = - reprog_device::create(pattern, flags, get_character_flags_table(), input.size(), stream); + auto d_prog = reprog_device::create(pattern, flags, stream); // parse the repl string for back-ref indicators auto 
group_count = std::min(99, d_prog->group_counts()); // group count should NOT exceed 99 @@ -155,15 +122,21 @@ std::unique_ptr replace_with_backrefs( string_scalar repl_scalar(parse_result.first, true, stream); string_view const d_repl_template = repl_scalar.value(); + auto const d_strings = column_device_view::create(input.parent(), stream); + using BackRefIterator = decltype(backrefs.begin()); - return regex_dispatcher(*d_prog, - replace_dispatch_fn{*d_prog}, - input, - d_repl_template, - backrefs.begin(), - backrefs.end(), - stream, - mr); + auto children = make_strings_children( + backrefs_fn{*d_strings, d_repl_template, backrefs.begin(), backrefs.end()}, + *d_prog, + input.size(), + stream, + mr); + + return make_strings_column(input.size(), + std::move(children.first), + std::move(children.second), + input.null_count(), + cudf::detail::copy_bitmask(input.parent(), stream, mr)); } } // namespace detail diff --git a/cpp/src/strings/replace/backref_re.cuh b/cpp/src/strings/replace/backref_re.cuh index 13a67e3b4d7..db5b8a1eb17 100644 --- a/cpp/src/strings/replace/backref_re.cuh +++ b/cpp/src/strings/replace/backref_re.cuh @@ -14,13 +14,13 @@ * limitations under the License. */ +#include + #include #include #include #include -#include - #include #include @@ -39,17 +39,16 @@ using backref_type = thrust::pair; * * The logic includes computing the size of each string and also writing the output. 
*/ -template +template struct backrefs_fn { column_device_view const d_strings; - reprog_device prog; string_view const d_repl; // string replacement template Iterator backrefs_begin; Iterator backrefs_end; int32_t* d_offsets{}; char* d_chars{}; - __device__ void operator()(size_type idx) + __device__ void operator()(size_type const idx, reprog_device const prog, int32_t const prog_idx) { if (d_strings.is_null(idx)) { if (!d_chars) d_offsets[idx] = 0; @@ -65,7 +64,7 @@ struct backrefs_fn { size_type end = nchars; // last character position (exclusive) // copy input to output replacing strings as we go - while (prog.find(idx, d_str, begin, end) > 0) // inits the begin/end vars + while (prog.find(prog_idx, d_str, begin, end) > 0) // inits the begin/end vars { auto spos = d_str.byte_offset(begin); // get offset for the auto epos = d_str.byte_offset(end); // character position values; @@ -84,7 +83,7 @@ struct backrefs_fn { lpos_template += copy_length; } // extract the specific group's string for this backref's index - auto extracted = prog.extract(idx, d_str, begin, end, backref.first - 1); + auto extracted = prog.extract(prog_idx, d_str, begin, end, backref.first - 1); if (!extracted || (extracted.value().second <= extracted.value().first)) { return; // no value for this backref number; that is ok } diff --git a/cpp/src/strings/replace/multi_re.cu b/cpp/src/strings/replace/multi_re.cu index 3189739e492..a3f2631f424 100644 --- a/cpp/src/strings/replace/multi_re.cu +++ b/cpp/src/strings/replace/multi_re.cu @@ -14,9 +14,7 @@ * limitations under the License. */ -#include #include -#include #include #include @@ -32,6 +30,7 @@ #include #include +#include #include #include @@ -47,7 +46,6 @@ using found_range = thrust::pair; * @brief This functor handles replacing strings by applying the compiled regex patterns * and inserting the corresponding new string within the matched range of characters. 
*/ -template struct replace_multi_regex_fn { column_device_view const d_strings; device_span progs; // array of regex progs @@ -84,9 +82,9 @@ struct replace_multi_regex_fn { continue; // or later in the string reprog_device prog = progs[ptn_idx]; - auto begin = static_cast(ch_pos); - auto end = static_cast(nchars); - if (!prog.is_empty() && prog.find(idx, d_str, begin, end) > 0) + auto begin = ch_pos; + auto end = nchars; + if (!prog.is_empty() && prog.find(idx, d_str, begin, end) > 0) d_ranges[ptn_idx] = found_range{begin, end}; // found a match else d_ranges[ptn_idx] = found_range{nchars, nchars}; // this pattern is done @@ -123,33 +121,6 @@ struct replace_multi_regex_fn { } }; -struct replace_dispatch_fn { - template - std::unique_ptr operator()(strings_column_view const& input, - device_span d_progs, - strings_column_view const& replacements, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) - { - auto const d_strings = column_device_view::create(input.parent(), stream); - auto const d_repls = column_device_view::create(replacements.parent(), stream); - - auto found_ranges = rmm::device_uvector(d_progs.size() * input.size(), stream); - - auto children = make_strings_children( - replace_multi_regex_fn{*d_strings, d_progs, found_ranges.data(), *d_repls}, - input.size(), - stream, - mr); - - return make_strings_column(input.size(), - std::move(children.first), - std::move(children.second), - input.null_count(), - cudf::detail::copy_bitmask(input.parent(), stream, mr)); - } -}; - } // namespace std::unique_ptr replace_re( @@ -168,15 +139,12 @@ std::unique_ptr replace_re( CUDF_EXPECTS(!replacements.has_nulls(), "Parameter replacements must not have any nulls"); // compile regexes into device objects - auto const d_char_table = get_character_flags_table(); auto h_progs = std::vector>>( patterns.size()); - std::transform(patterns.begin(), - patterns.end(), - h_progs.begin(), - [flags, d_char_table, input, stream](auto const& ptn) { - return 
reprog_device::create(ptn, flags, d_char_table, input.size(), stream); - }); + std::transform( + patterns.begin(), patterns.end(), h_progs.begin(), [flags, stream](auto const& ptn) { + return reprog_device::create(ptn, flags, stream); + }); // get the longest regex for the dispatcher auto const max_prog = @@ -184,15 +152,37 @@ std::unique_ptr replace_re( return lhs->insts_counts() < rhs->insts_counts(); }); + auto d_max_prog = **max_prog; + auto const buffer_size = d_max_prog.working_memory_size(input.size()); + auto d_buffer = rmm::device_buffer(buffer_size, stream); + // copy all the reprog_device instances to a device memory array std::vector progs; - std::transform(h_progs.begin(), h_progs.end(), std::back_inserter(progs), [](auto const& d_prog) { - return *d_prog; - }); + std::transform(h_progs.begin(), + h_progs.end(), + std::back_inserter(progs), + [d_buffer = d_buffer.data(), size = input.size()](auto& prog) { + prog->set_working_memory(d_buffer, size); + return *prog; + }); auto d_progs = cudf::detail::make_device_uvector_async(progs, stream); - return regex_dispatcher( - **max_prog, replace_dispatch_fn{}, input, d_progs, replacements, stream, mr); + auto const d_strings = column_device_view::create(input.parent(), stream); + auto const d_repls = column_device_view::create(replacements.parent(), stream); + + auto found_ranges = rmm::device_uvector(d_progs.size() * input.size(), stream); + + auto children = make_strings_children( + replace_multi_regex_fn{*d_strings, d_progs, found_ranges.data(), *d_repls}, + input.size(), + stream, + mr); + + return make_strings_column(input.size(), + std::move(children.first), + std::move(children.second), + input.null_count(), + cudf::detail::copy_bitmask(input.parent(), stream, mr)); } } // namespace detail diff --git a/cpp/src/strings/replace/replace_re.cu b/cpp/src/strings/replace/replace_re.cu index af74d8bdb92..159f83453bd 100644 --- a/cpp/src/strings/replace/replace_re.cu +++ b/cpp/src/strings/replace/replace_re.cu 
@@ -14,9 +14,7 @@ * limitations under the License. */ -#include -#include -#include +#include #include #include @@ -38,16 +36,14 @@ namespace { * @brief This functor handles replacing strings by applying the compiled regex pattern * and inserting the new string within the matched range of characters. */ -template struct replace_regex_fn { column_device_view const d_strings; - reprog_device prog; string_view const d_repl; size_type const maxrepl; int32_t* d_offsets{}; char* d_chars{}; - __device__ void operator()(size_type idx) + __device__ void operator()(size_type const idx, reprog_device const prog, int32_t const prog_idx) { if (d_strings.is_null(idx)) { if (!d_chars) d_offsets[idx] = 0; @@ -62,13 +58,13 @@ struct replace_regex_fn { auto out_ptr = d_chars ? d_chars + d_offsets[idx] // output pointer (o) : nullptr; size_type last_pos = 0; - int32_t begin = 0; // these are for calling prog.find - int32_t end = -1; // matches final word-boundary if at the end of the string + size_type begin = 0; // these are for calling prog.find + size_type end = -1; // matches final word-boundary if at the end of the string // copy input to output replacing strings as we go while (mxn-- > 0 && begin <= nchars) { // maximum number of replaces - if (prog.is_empty() || prog.find(idx, d_str, begin, end) <= 0) { + if (prog.is_empty() || prog.find(prog_idx, d_str, begin, end) <= 0) { break; // no more matches } @@ -100,32 +96,6 @@ struct replace_regex_fn { } }; -struct replace_dispatch_fn { - reprog_device d_prog; - - template - std::unique_ptr operator()(strings_column_view const& input, - string_view const& d_replacement, - size_type max_replace_count, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) - { - auto const d_strings = column_device_view::create(input.parent(), stream); - - auto children = make_strings_children( - replace_regex_fn{*d_strings, d_prog, d_replacement, max_replace_count}, - input.size(), - stream, - mr); - - return 
make_strings_column(input.size(), - std::move(children.first), - std::move(children.second), - input.null_count(), - cudf::detail::copy_bitmask(input.parent(), stream, mr)); - } -}; - } // namespace // @@ -144,13 +114,20 @@ std::unique_ptr replace_re( string_view d_repl(replacement.data(), replacement.size()); // compile regex into device object - auto d_prog = - reprog_device::create(pattern, flags, get_character_flags_table(), input.size(), stream); + auto d_prog = reprog_device::create(pattern, flags, stream); auto const maxrepl = max_replace_count.value_or(-1); - return regex_dispatcher( - *d_prog, replace_dispatch_fn{*d_prog}, input, d_repl, maxrepl, stream, mr); + auto const d_strings = column_device_view::create(input.parent(), stream); + + auto children = make_strings_children( + replace_regex_fn{*d_strings, d_repl, maxrepl}, *d_prog, input.size(), stream, mr); + + return make_strings_column(input.size(), + std::move(children.first), + std::move(children.second), + input.null_count(), + cudf::detail::copy_bitmask(input.parent(), stream, mr)); } } // namespace detail diff --git a/cpp/src/strings/search/findall.cu b/cpp/src/strings/search/findall.cu index 323ad2cbc09..64e46d07e25 100644 --- a/cpp/src/strings/search/findall.cu +++ b/cpp/src/strings/search/findall.cu @@ -15,9 +15,7 @@ */ #include -#include -#include -#include +#include #include #include @@ -33,7 +31,6 @@ #include #include -#include #include #include @@ -52,14 +49,12 @@ namespace { * For strings with fewer matches, null entries are appended into `d_indices` * up to the maximum column count. 
*/ -template struct findall_fn { column_device_view const d_strings; - reprog_device prog; size_type const* d_counts; ///< match counts for each string indices_span d_indices; ///< 2D-span: output matches added here - __device__ void operator()(size_type idx) + __device__ void operator()(size_type const idx, reprog_device const prog, int32_t const prog_idx) { auto const match_count = d_counts[idx]; @@ -72,7 +67,7 @@ struct findall_fn { int32_t begin = 0; int32_t end = -1; for (auto col_idx = 0; col_idx < match_count; ++col_idx) { - if (prog.find(idx, d_str, begin, end) > 0) { + if (prog.find(prog_idx, d_str, begin, end) > 0) { auto const begin_offset = d_str.byte_offset(begin); auto const end_offset = d_str.byte_offset(end); d_output[col_idx] = @@ -82,28 +77,12 @@ struct findall_fn { end = nchars; } } - // fill the remaining entries for this row with nulls thrust::fill( thrust::seq, d_output.begin() + match_count, d_output.end(), string_index_pair{nullptr, 0}); } }; -struct findall_dispatch_fn { - reprog_device d_prog; - - template - void operator()(column_device_view const& d_strings, - size_type const* d_find_counts, - indices_span& d_indices, - rmm::cuda_stream_view stream) - { - thrust::for_each_n(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - d_strings.size(), - findall_fn{d_strings, d_prog, d_find_counts, d_indices}); - } -}; } // namespace std::unique_ptr
findall(strings_column_view const& input, @@ -115,11 +94,10 @@ std::unique_ptr
findall(strings_column_view const& input, auto const strings_count = input.size(); // compile regex into device object - auto const d_prog = - reprog_device::create(pattern, flags, get_character_flags_table(), strings_count, stream); + auto const d_prog = reprog_device::create(pattern, flags, stream); auto const d_strings = column_device_view::create(input.parent(), stream); - auto find_counts = count_matches(*d_strings, *d_prog, strings_count + 1, stream); + auto find_counts = count_matches(*d_strings, *d_prog, strings_count, stream); auto d_find_counts = find_counts->view().data(); size_type const columns_count = thrust::reduce( @@ -139,9 +117,8 @@ std::unique_ptr
findall(strings_column_view const& input, } else { // place all matching strings into the indices vector auto d_indices = indices_span(indices.data(), strings_count, columns_count); - regex_dispatcher( - *d_prog, findall_dispatch_fn{*d_prog}, *d_strings, d_find_counts, d_indices, stream); - + launch_for_each_kernel( + findall_fn{*d_strings, d_find_counts, d_indices}, *d_prog, strings_count, stream); results.resize(columns_count); } diff --git a/cpp/src/strings/search/findall_record.cu b/cpp/src/strings/search/findall_record.cu index 46155bd7cf5..2f4b9ce5b24 100644 --- a/cpp/src/strings/search/findall_record.cu +++ b/cpp/src/strings/search/findall_record.cu @@ -15,9 +15,7 @@ */ #include -#include -#include -#include +#include #include #include @@ -32,8 +30,6 @@ #include #include -#include -#include #include #include @@ -49,55 +45,48 @@ namespace { * @brief This functor handles extracting matched strings by applying the compiled regex pattern * and creating string_index_pairs for all the substrings. 
*/ -template struct findall_fn { column_device_view const d_strings; - reprog_device prog; offset_type const* d_offsets; string_index_pair* d_indices; - __device__ void operator()(size_type const idx) + __device__ void operator()(size_type const idx, reprog_device const prog, int32_t const prog_idx) { if (d_strings.is_null(idx)) { return; } - auto const d_str = d_strings.element(idx); + auto const d_str = d_strings.element(idx); + auto const nchars = d_str.length(); auto d_output = d_indices + d_offsets[idx]; size_type output_idx = 0; - int32_t begin = 0; - int32_t end = d_str.length(); - while ((begin < end) && (prog.find(idx, d_str, begin, end) > 0)) { + size_type begin = 0; + size_type end = nchars; + while ((begin < end) && (prog.find(prog_idx, d_str, begin, end) > 0)) { auto const spos = d_str.byte_offset(begin); // convert auto const epos = d_str.byte_offset(end); // to bytes d_output[output_idx++] = string_index_pair{d_str.data() + spos, (epos - spos)}; begin = end + (begin == end); - end = d_str.length(); + end = nchars; } } }; -struct findall_dispatch_fn { - reprog_device d_prog; - - template - std::unique_ptr operator()(column_device_view const& d_strings, +std::unique_ptr findall_util(column_device_view const& d_strings, + reprog_device& d_prog, size_type total_matches, offset_type const* d_offsets, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) - { - rmm::device_uvector indices(total_matches, stream); +{ + rmm::device_uvector indices(total_matches, stream); - thrust::for_each_n(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - d_strings.size(), - findall_fn{d_strings, d_prog, d_offsets, indices.data()}); + launch_for_each_kernel( + findall_fn{d_strings, d_offsets, indices.data()}, d_prog, d_strings.size(), stream); - return make_strings_column(indices.begin(), indices.end(), stream, mr); - } -}; + return make_strings_column(indices.begin(), indices.end(), stream, mr); +} } // namespace @@ -113,8 +102,7 @@ 
std::unique_ptr findall_record( auto const d_strings = column_device_view::create(input.parent(), stream); // compile regex into device object - auto const d_prog = - reprog_device::create(pattern, flags, get_character_flags_table(), strings_count, stream); + auto const d_prog = reprog_device::create(pattern, flags, stream); // Create lists offsets column auto offsets = count_matches(*d_strings, *d_prog, strings_count + 1, stream, mr); @@ -128,8 +116,7 @@ std::unique_ptr findall_record( auto const total_matches = cudf::detail::get_value(offsets->view(), strings_count, stream); - auto strings_output = regex_dispatcher( - *d_prog, findall_dispatch_fn{*d_prog}, *d_strings, total_matches, d_offsets, stream, mr); + auto strings_output = findall_util(*d_strings, *d_prog, total_matches, d_offsets, stream, mr); // Build the lists column from the offsets and the strings return make_lists_column(strings_count, diff --git a/cpp/src/strings/split/split_re.cu b/cpp/src/strings/split/split_re.cu index 3ec6df058c6..16edd0606e9 100644 --- a/cpp/src/strings/split/split_re.cu +++ b/cpp/src/strings/split/split_re.cu @@ -15,9 +15,7 @@ */ #include -#include -#include -#include +#include #include #include @@ -28,12 +26,10 @@ #include #include #include -#include #include #include -#include #include #include #include @@ -59,18 +55,17 @@ enum class split_direction { * The `d_token_offsets` specifies the output position within `d_tokens` * for each string. 
*/ -template struct token_reader_fn { column_device_view const d_strings; - reprog_device prog; split_direction const direction; offset_type const* d_token_offsets; string_index_pair* d_tokens; - __device__ void operator()(size_type idx) + __device__ void operator()(size_type const idx, reprog_device const prog, int32_t const prog_idx) { if (d_strings.is_null(idx)) { return; } - auto const d_str = d_strings.element(idx); + auto const d_str = d_strings.element(idx); + auto const nchars = d_str.length(); auto const token_offset = d_token_offsets[idx]; auto const token_count = d_token_offsets[idx + 1] - token_offset; @@ -78,9 +73,9 @@ struct token_reader_fn { size_type token_idx = 0; size_type begin = 0; // characters - size_type end = d_str.length(); + size_type end = nchars; size_type last_pos = 0; // bytes - while (prog.find(idx, d_str, begin, end) > 0) { + while (prog.find(prog_idx, d_str, begin, end) > 0) { // get the token (characters just before this match) auto const token = string_index_pair{d_str.data() + last_pos, d_str.byte_offset(begin) - last_pos}; @@ -97,7 +92,7 @@ struct token_reader_fn { // setup for next match last_pos = d_str.byte_offset(end); begin = end + (begin == end); - end = d_str.length(); + end = nchars; } // set the last token to the remainder of the string @@ -116,28 +111,6 @@ struct token_reader_fn { } }; -struct generate_dispatch_fn { - reprog_device d_prog; - - template - rmm::device_uvector operator()(column_device_view const& d_strings, - size_type total_tokens, - split_direction direction, - offset_type const* d_offsets, - rmm::cuda_stream_view stream) - { - rmm::device_uvector tokens(total_tokens, stream); - - thrust::for_each_n( - rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - d_strings.size(), - token_reader_fn{d_strings, d_prog, direction, d_offsets, tokens.data()}); - - return tokens; - } -}; - /** * @brief Call regex to split each input string into tokens. 
* @@ -176,8 +149,15 @@ rmm::device_uvector generate_tokens(column_device_view const& // the last offset entry is the total number of tokens to be generated auto const total_tokens = cudf::detail::get_value(offsets, strings_count, stream); - return regex_dispatcher( - d_prog, generate_dispatch_fn{d_prog}, d_strings, total_tokens, direction, d_offsets, stream); + rmm::device_uvector tokens(total_tokens, stream); + if (total_tokens == 0) { return tokens; } + + launch_for_each_kernel(token_reader_fn{d_strings, direction, d_offsets, tokens.data()}, + d_prog, + d_strings.size(), + stream); + + return tokens; } /** @@ -221,7 +201,7 @@ std::unique_ptr
split_re(strings_column_view const& input, } // create the regex device prog from the given pattern - auto d_prog = reprog_device::create(pattern, get_character_flags_table(), strings_count, stream); + auto d_prog = reprog_device::create(pattern, stream); auto d_strings = column_device_view::create(input.parent(), stream); // count the number of delimiters matched in each string @@ -283,7 +263,7 @@ std::unique_ptr split_record_re(strings_column_view const& input, auto const strings_count = input.size(); // create the regex device prog from the given pattern - auto d_prog = reprog_device::create(pattern, get_character_flags_table(), strings_count, stream); + auto d_prog = reprog_device::create(pattern, stream); auto d_strings = column_device_view::create(input.parent(), stream); // count the number of delimiters matched in each string