From f0c62cb591412a36824a76efc64302957c0261d9 Mon Sep 17 00:00:00 2001 From: David Wendt <45795991+davidwendt@users.noreply.github.com> Date: Fri, 23 Jun 2023 17:28:50 -0400 Subject: [PATCH] Performance improvement for some libcudf regex functions for long strings (#13322) Changes the internal regex logic to minimize character counting to help performance with longer strings. The improvement applies mainly to libcudf regex functions that return strings (i.e. extract, replace, split). The changes here also improve the internal device APIs for clarity to improve maintenance. The most significant change makes the position variables input-only and returning an optional pair to indicate a successful match. There are some more optimizations that are possible here where character positions are passed back and forth that could be replaced with byte positions to further reduce counting. Initial measurements showed this noticeably slowed down small strings so more analysis is required before continuing this optimization. Reference: https://github.com/rapidsai/cudf/pull/13480 ### More Detail First, there is a change to some internal regex function signatures. Notable the `reprog_device::find()` and `reprog_device::extract()` member functions declared in `cpp/src/strings/regex/regex.cuh` that are used by all the libcudf regex functions. The in/out parameters are now input-only parameters (pass by value) and the return is an optional pair that includes the match result. Also, the `begin` parameter is now an iterator and the `end` parameter now has a default. This change requires updating all the definitions and uses of the `find` and `extract` member functions. Using an iterator as the `begin` parameter allows for some optimizations in the calling code to minimize character counting that may be needed for processing multi-byte UTF-8 characters. Rather than using the `cudf::string_view::byte_offset()` member function to convert character positions to byte positions, an iterator can be incremented as we traverse through the string which helps reduce some character counting. So the changes here involve removing some calls to `byte_offset()` and incrementing (really moving) iterators with a pattern like `itr += (new_pos - itr.position());` There is another PR #13428 to make a `move_to` iterator member function. It is possible to reduce the character counting even more as mentioned above but further optimization requires some deeper analysis. Authors: - David Wendt (https://github.com/davidwendt) Approvers: - Mark Harris (https://github.com/harrism) - MithunR (https://github.com/mythrocks) URL: https://github.com/rapidsai/cudf/pull/13322 --- cpp/include/cudf/strings/detail/utilities.cuh | 14 ++-- cpp/src/strings/contains.cu | 7 +- cpp/src/strings/count_matches.cu | 12 +-- cpp/src/strings/extract/extract.cu | 23 ++--- cpp/src/strings/extract/extract_all.cu | 38 +++++---- cpp/src/strings/regex/regex.cuh | 75 ++++++++++------- cpp/src/strings/regex/regex.inl | 58 ++++++------- cpp/src/strings/regex/regex_program_impl.h | 3 +- cpp/src/strings/replace/backref_re.cuh | 52 ++++++------ cpp/src/strings/replace/multi_re.cu | 84 ++++++++++--------- cpp/src/strings/replace/replace_re.cu | 66 +++++++-------- cpp/src/strings/search/findall.cu | 15 ++-- cpp/src/strings/split/split_re.cu | 27 +++--- cpp/tests/strings/extract_tests.cpp | 6 +- cpp/tests/strings/findall_tests.cpp | 4 +- 15 files changed, 257 insertions(+), 227 deletions(-) diff --git a/cpp/include/cudf/strings/detail/utilities.cuh b/cpp/include/cudf/strings/detail/utilities.cuh index 64f5d3f0450..5c719cd25d2 100644 --- a/cpp/include/cudf/strings/detail/utilities.cuh +++ b/cpp/include/cudf/strings/detail/utilities.cuh @@ -18,6 +18,9 @@ #include #include +#include +#include + #include #include @@ -29,14 +32,15 @@ namespace detail { * @brief Copies input string data into a buffer and increments the pointer by the number of bytes * copied. * - * @param buffer Device buffer to copy to. - * @param input Data to copy from. - * @param bytes Number of bytes to copy. - * @return Pointer to the end of the output buffer after the copy. + * @param buffer Device buffer to copy to + * @param input Data to copy from + * @param bytes Number of bytes to copy + * @return Pointer to the end of the output buffer after the copy */ __device__ inline char* copy_and_increment(char* buffer, char const* input, size_type bytes) { - memcpy(buffer, input, bytes); + // this can be slightly faster than memcpy + thrust::copy_n(thrust::seq, input, bytes, buffer); return buffer + bytes; } diff --git a/cpp/src/strings/contains.cu b/cpp/src/strings/contains.cu index 44b3faeb38a..22534870409 100644 --- a/cpp/src/strings/contains.cu +++ b/cpp/src/strings/contains.cu @@ -50,10 +50,9 @@ struct contains_fn { if (d_strings.is_null(idx)) return false; auto const d_str = d_strings.element(idx); - size_type begin = 0; - size_type end = beginning_only ? 1 // match only the beginning of the string; - : -1; // match anywhere in the string - return static_cast(prog.find(thread_idx, d_str, begin, end)); + size_type end = beginning_only ? 1 // match only the beginning of the string; + : -1; // match anywhere in the string + return prog.find(thread_idx, d_str, d_str.begin(), end).has_value(); } }; diff --git a/cpp/src/strings/count_matches.cu b/cpp/src/strings/count_matches.cu index 1fde3a54089..6de5d43dc94 100644 --- a/cpp/src/strings/count_matches.cu +++ b/cpp/src/strings/count_matches.cu @@ -41,12 +41,14 @@ struct count_fn { auto const nchars = d_str.length(); int32_t count = 0; - size_type begin = 0; - size_type end = -1; - while ((begin <= nchars) && (prog.find(thread_idx, d_str, begin, end) > 0)) { + auto itr = d_str.begin(); + while (itr.position() <= nchars) { + auto result = prog.find(thread_idx, d_str, itr); + if (!result) { break; } ++count; - begin = end + (begin == end); - end = -1; + // increment the iterator is faster than creating a new one + // +1 if the match was on a virtual position (e.g. word boundary) + itr += (result->second - itr.position()) + (result->first == result->second); } return count; } diff --git a/cpp/src/strings/extract/extract.cu b/cpp/src/strings/extract/extract.cu index ccfc007e7ed..532053e750e 100644 --- a/cpp/src/strings/extract/extract.cu +++ b/cpp/src/strings/extract/extract.cu @@ -61,18 +61,19 @@ struct extract_fn { if (d_strings.is_valid(idx)) { auto const d_str = d_strings.element(idx); - - size_type begin = 0; - size_type end = -1; // handles empty strings automatically - if (d_prog.find(prog_idx, d_str, begin, end) > 0) { + auto const match = d_prog.find(prog_idx, d_str, d_str.begin()); + if (match) { + auto const itr = d_str.begin() + match->first; + auto last_pos = itr; for (auto col_idx = 0; col_idx < groups; ++col_idx) { - auto const extracted = d_prog.extract(prog_idx, d_str, begin, end, col_idx); - d_output[col_idx] = [&] { - if (!extracted) return string_index_pair{nullptr, 0}; - auto const offset = d_str.byte_offset((*extracted).first); - return string_index_pair{d_str.data() + offset, - d_str.byte_offset((*extracted).second) - offset}; - }(); + auto const extracted = d_prog.extract(prog_idx, d_str, itr, match->second, col_idx); + if (extracted) { + auto const d_extracted = string_from_match(*extracted, d_str, last_pos); + d_output[col_idx] = string_index_pair{d_extracted.data(), d_extracted.size_bytes()}; + last_pos += (extracted->second - last_pos.position()); + } else { + d_output[col_idx] = string_index_pair{nullptr, 0}; + } } return; } diff --git a/cpp/src/strings/extract/extract_all.cu b/cpp/src/strings/extract/extract_all.cu index 1252e79be90..fcd05ee9dc6 100644 --- a/cpp/src/strings/extract/extract_all.cu +++ b/cpp/src/strings/extract/extract_all.cu @@ -59,32 +59,36 @@ struct extract_fn { { if (d_strings.is_null(idx)) { return; } + auto const d_str = d_strings.element(idx); + auto const nchars = d_str.length(); + auto const groups = d_prog.group_counts(); auto d_output = d_indices + d_offsets[idx]; size_type output_idx = 0; - auto const d_str = d_strings.element(idx); - auto const nchars = d_str.length(); + auto itr = d_str.begin(); - size_type begin = 0; - size_type end = nchars; - // match the regex - while ((begin < end) && d_prog.find(prog_idx, d_str, begin, end) > 0) { + while (itr.position() < nchars) { + // first, match the regex + auto const match = d_prog.find(prog_idx, d_str, itr); + if (!match) { break; } + itr += (match->first - itr.position()); // position to beginning of the match + auto last_pos = itr; // extract each group into the output for (auto group_idx = 0; group_idx < groups; ++group_idx) { // result is an optional containing the bounds of the extracted string at group_idx - auto const extracted = d_prog.extract(prog_idx, d_str, begin, end, group_idx); - - d_output[group_idx + output_idx] = [&] { - if (!extracted) { return string_index_pair{nullptr, 0}; } - auto const start_offset = d_str.byte_offset(extracted->first); - auto const end_offset = d_str.byte_offset(extracted->second); - return string_index_pair{d_str.data() + start_offset, end_offset - start_offset}; - }(); + auto const extracted = d_prog.extract(prog_idx, d_str, itr, match->second, group_idx); + if (extracted) { + auto const d_result = string_from_match(*extracted, d_str, last_pos); + d_output[group_idx + output_idx] = + string_index_pair{d_result.data(), d_result.size_bytes()}; + } else { + d_output[group_idx + output_idx] = string_index_pair{nullptr, 0}; + } + last_pos += (extracted->second - last_pos.position()); } - // continue to next match - begin = end; - end = nchars; + // point to the end of this match to start the next match + itr += (match->second - itr.position()); output_idx += groups; } } diff --git a/cpp/src/strings/regex/regex.cuh b/cpp/src/strings/regex/regex.cuh index 4d18af69b9c..19d82380350 100644 --- a/cpp/src/strings/regex/regex.cuh +++ b/cpp/src/strings/regex/regex.cuh @@ -18,6 +18,7 @@ #include #include +#include #include #include @@ -30,9 +31,6 @@ #include namespace cudf { - -class string_view; - namespace strings { namespace detail { @@ -184,36 +182,33 @@ class reprog_device { * * @param thread_idx The index used for mapping the state memory for this string in global memory. * @param d_str The string to search. - * @param[in,out] begin Position index to begin the search. If found, returns the position found - * in the string. - * @param[in,out] end Position index to end the search. If found, returns the last position - * matching in the string. - * @return Returns 0 if no match is found. + * @param begin Position to begin the search within `d_str`. + * @param end Character position index to end the search within `d_str`. + * Specify -1 to match any virtual positions past the end of the string. + * @return If match found, returns character positions of the matches. */ - __device__ inline int32_t find(int32_t const thread_idx, - string_view const d_str, - cudf::size_type& begin, - cudf::size_type& end) const; + __device__ inline match_result find(int32_t const thread_idx, + string_view const d_str, + string_view::const_iterator begin, + cudf::size_type end = -1) const; /** * @brief Does an extract evaluation using the compiled expression on the given string. * - * This will find a specific match within the string when more than match occurs. + * This will find a specific capture group within the string. * The find() function should be called first to locate the begin/end bounds of the * the matched section. * * @param thread_idx The index used for mapping the state memory for this string in global memory. * @param d_str The string to search. - * @param begin Position index to begin the search. If found, returns the position found - * in the string. - * @param end Position index to end the search. If found, returns the last position - * matching in the string. + * @param begin Position to begin the search within `d_str`. + * @param end Character position index to end the search within `d_str`. * @param group_id The specific group to return its matching position values. * @return If valid, returns the character position of the matched group in the given string, */ __device__ inline match_result extract(int32_t const thread_idx, string_view const d_str, - cudf::size_type begin, + string_view::const_iterator begin, cudf::size_type end, cudf::size_type const group_id) const; @@ -241,20 +236,20 @@ class reprog_device { /** * @brief Executes the regex pattern on the given string. */ - __device__ inline int32_t regexec(string_view const d_str, - reljunk jnk, - cudf::size_type& begin, - cudf::size_type& end, - cudf::size_type const group_id = 0) const; + __device__ inline match_result regexec(string_view const d_str, + reljunk jnk, + string_view::const_iterator begin, + cudf::size_type end, + cudf::size_type const group_id = 0) const; /** * @brief Utility wrapper to setup state memory structures for calling regexec */ - __device__ inline int32_t call_regexec(int32_t const thread_idx, - string_view const d_str, - cudf::size_type& begin, - cudf::size_type& end, - cudf::size_type const group_id = 0) const; + __device__ inline match_result call_regexec(int32_t const thread_idx, + string_view const d_str, + string_view::const_iterator begin, + cudf::size_type end, + cudf::size_type const group_id = 0) const; reprog_device(reprog const&); @@ -285,6 +280,30 @@ class reprog_device { */ std::size_t compute_working_memory_size(int32_t num_threads, int32_t insts_count); +/** + * @brief Converts a match_pair from character positions to byte positions + */ +__device__ __forceinline__ match_pair match_positions_to_bytes(match_pair const result, + string_view d_str, + string_view::const_iterator last) +{ + if (d_str.length() == d_str.size_bytes()) { return result; } + auto const begin = (last + (result.first - last.position())).byte_offset(); + auto const end = (last + (result.second - last.position())).byte_offset(); + return {begin, end}; +} + +/** + * @brief Creates a string_view from a match result + */ +__device__ __forceinline__ string_view string_from_match(match_pair const result, + string_view d_str, + string_view::const_iterator last) +{ + auto const [begin, end] = match_positions_to_bytes(result, d_str, last); + return string_view(d_str.data() + begin, end - begin); +} + } // namespace detail } // namespace strings } // namespace cudf diff --git a/cpp/src/strings/regex/regex.inl b/cpp/src/strings/regex/regex.inl index d25a0888f32..c5205ae7789 100644 --- a/cpp/src/strings/regex/regex.inl +++ b/cpp/src/strings/regex/regex.inl @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,10 +16,6 @@ #include #include -#include -#include - -#include namespace cudf { namespace strings { @@ -235,21 +231,19 @@ __device__ __forceinline__ reprog_device reprog_device::load(reprog_device const * @param group_id Index of the group to match in a multi-group regex pattern. * @return >0 if match found */ -__device__ __forceinline__ int32_t reprog_device::regexec(string_view const dstr, - reljunk jnk, - cudf::size_type& begin, - cudf::size_type& end, - cudf::size_type const group_id) const +__device__ __forceinline__ match_result reprog_device::regexec(string_view const dstr, + reljunk jnk, + string_view::const_iterator itr, + cudf::size_type end, + cudf::size_type const group_id) const { int32_t match = 0; + auto begin = itr.position(); auto pos = begin; auto eos = end; - char_utf8 c = 0; auto checkstart = jnk.starttype != 0; auto last_character = false; - string_view::const_iterator itr = string_view::const_iterator(dstr, pos); - jnk.list1->reset(); do { // fast check for first CHAR or BOL @@ -258,12 +252,12 @@ __device__ __forceinline__ int32_t reprog_device::regexec(string_view const dstr switch (jnk.starttype) { case BOL: if (pos == 0) break; - if (jnk.startchar != '^') { return match; } + if (jnk.startchar != '^') { return thrust::nullopt; } --pos; startchar = static_cast('\n'); case CHAR: { auto const fidx = dstr.find(startchar, pos); - if (fidx == string_view::npos) { return match; } + if (fidx == string_view::npos) { return thrust::nullopt; } pos = fidx + (jnk.starttype == BOL); break; } @@ -279,7 +273,7 @@ __device__ __forceinline__ int32_t reprog_device::regexec(string_view const dstr last_character = itr.byte_offset() >= dstr.size_bytes(); - c = last_character ? 0 : *itr; + char_utf8 const c = last_character ? 0 : *itr; // expand the non-character types like: LBRA, RBRA, BOL, EOL, BOW, NBOW, and OR bool expanded = false; @@ -394,35 +388,33 @@ __device__ __forceinline__ int32_t reprog_device::regexec(string_view const dstr checkstart = jnk.list1->get_size() == 0; } while (!last_character && (!checkstart || !match)); - return match; + return match ? match_result({begin, end}) : thrust::nullopt; } -__device__ __forceinline__ int32_t reprog_device::find(int32_t const thread_idx, - string_view const dstr, - cudf::size_type& begin, - cudf::size_type& end) const +__device__ __forceinline__ match_result reprog_device::find(int32_t const thread_idx, + string_view const dstr, + string_view::const_iterator begin, + cudf::size_type end) const { - auto const rtn = call_regexec(thread_idx, dstr, begin, end); - if (rtn <= 0) begin = end = -1; - return rtn; + return call_regexec(thread_idx, dstr, begin, end); } __device__ __forceinline__ match_result reprog_device::extract(int32_t const thread_idx, string_view const dstr, - cudf::size_type begin, + string_view::const_iterator begin, cudf::size_type end, cudf::size_type const group_id) const { - end = begin + 1; - return call_regexec(thread_idx, dstr, begin, end, group_id + 1) > 0 ? match_result({begin, end}) - : thrust::nullopt; + end = begin.position() + 1; + return call_regexec(thread_idx, dstr, begin, end, group_id + 1); } -__device__ __forceinline__ int32_t reprog_device::call_regexec(int32_t const thread_idx, - string_view const dstr, - cudf::size_type& begin, - cudf::size_type& end, - cudf::size_type const group_id) const +__device__ __forceinline__ match_result +reprog_device::call_regexec(int32_t const thread_idx, + string_view const dstr, + string_view::const_iterator begin, + cudf::size_type end, + cudf::size_type const group_id) const { auto gp_ptr = reinterpret_cast(_buffer); relist list1(static_cast(_max_insts), _thread_count, gp_ptr, thread_idx); diff --git a/cpp/src/strings/regex/regex_program_impl.h b/cpp/src/strings/regex/regex_program_impl.h index eede2225bce..74cc1902739 100644 --- a/cpp/src/strings/regex/regex_program_impl.h +++ b/cpp/src/strings/regex/regex_program_impl.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#pragma once #include "regcomp.h" #include "regex.cuh" diff --git a/cpp/src/strings/replace/backref_re.cuh b/cpp/src/strings/replace/backref_re.cuh index a5f3ace2141..aeaea40358f 100644 --- a/cpp/src/strings/replace/backref_re.cuh +++ b/cpp/src/strings/replace/backref_re.cuh @@ -45,7 +45,7 @@ struct backrefs_fn { string_view const d_repl; // string replacement template Iterator backrefs_begin; Iterator backrefs_end; - int32_t* d_offsets{}; + size_type* d_offsets{}; char* d_chars{}; __device__ void operator()(size_type const idx, reprog_device const prog, int32_t const prog_idx) @@ -59,23 +59,27 @@ struct backrefs_fn { auto const nchars = d_str.length(); // number of characters in input string auto nbytes = d_str.size_bytes(); // number of bytes for the output string auto out_ptr = d_chars ? (d_chars + d_offsets[idx]) : nullptr; - size_type lpos = 0; // last byte position processed in d_str - size_type begin = 0; // first character position matching regex - size_type end = -1; // match through the end of the string + auto itr = d_str.begin(); + auto last_pos = itr; // copy input to output replacing strings as we go - while ((begin <= nchars) && - (prog.find(prog_idx, d_str, begin, end) > 0)) // inits the begin/end vars + while (itr.position() <= nchars) // inits the begin/end vars { - auto spos = d_str.byte_offset(begin); // get offset for the - auto epos = d_str.byte_offset(end); // character position values; - nbytes += d_repl.size_bytes() - (epos - spos); // compute the output size + auto const match = prog.find(prog_idx, d_str, itr); + if (!match) { break; } + + auto const [start_pos, end_pos] = match_positions_to_bytes(*match, d_str, itr); + nbytes += d_repl.size_bytes() - (end_pos - start_pos); // compute the output size // copy the string data before the matched section - if (out_ptr) { out_ptr = copy_and_increment(out_ptr, in_ptr + lpos, spos - lpos); } + if (out_ptr) { + out_ptr = copy_and_increment( + out_ptr, in_ptr + last_pos.byte_offset(), start_pos - last_pos.byte_offset()); + } size_type lpos_template = 0; // last end pos of replace template auto const repl_ptr = d_repl.data(); // replace template pattern + itr += (match->first - itr.position()); thrust::for_each( thrust::seq, backrefs_begin, backrefs_end, [&] __device__(backref_type backref) { if (out_ptr) { @@ -84,17 +88,13 @@ struct backrefs_fn { lpos_template += copy_length; } // extract the specific group's string for this backref's index - auto extracted = prog.extract(prog_idx, d_str, begin, end, backref.first - 1); - if (!extracted || (extracted.value().second < extracted.value().first)) { + auto extracted = prog.extract(prog_idx, d_str, itr, match->second, backref.first - 1); + if (!extracted || (extracted->second < extracted->first)) { return; // no value for this backref number; that is ok } - auto spos_extract = d_str.byte_offset(extracted.value().first); // convert - auto epos_extract = d_str.byte_offset(extracted.value().second); // to bytes - nbytes += epos_extract - spos_extract; - if (out_ptr) { - out_ptr = - copy_and_increment(out_ptr, in_ptr + spos_extract, (epos_extract - spos_extract)); - } + auto const d_str_ex = string_from_match(*extracted, d_str, itr); + nbytes += d_str_ex.size_bytes(); + if (out_ptr) { out_ptr = copy_string(out_ptr, d_str_ex); } }); // copy remainder of template @@ -104,16 +104,16 @@ struct backrefs_fn { } // setup to match the next section - lpos = epos; - begin = end + (begin == end); - end = -1; + last_pos += (match->second - last_pos.position()); + itr = last_pos + (match->first == match->second); } // finally, copy remainder of input string - if (out_ptr && (lpos < d_str.size_bytes())) { - memcpy(out_ptr, in_ptr + lpos, d_str.size_bytes() - lpos); - } else if (!out_ptr) { - d_offsets[idx] = static_cast(nbytes); + if (out_ptr) { + thrust::copy_n( + thrust::seq, in_ptr + itr.byte_offset(), d_str.size_bytes() - itr.byte_offset(), out_ptr); + } else { + d_offsets[idx] = nbytes; } } }; diff --git a/cpp/src/strings/replace/multi_re.cu b/cpp/src/strings/replace/multi_re.cu index b554d0a815c..867b443c036 100644 --- a/cpp/src/strings/replace/multi_re.cu +++ b/cpp/src/strings/replace/multi_re.cu @@ -55,7 +55,7 @@ struct replace_multi_regex_fn { device_span progs; // array of regex progs found_range* d_found_ranges; // working array matched (begin,end) values column_device_view const d_repls; // replacement strings - int32_t* d_offsets{}; + size_type* d_offsets{}; char* d_chars{}; __device__ void operator()(size_type idx) @@ -67,61 +67,69 @@ struct replace_multi_regex_fn { auto const number_of_patterns = static_cast(progs.size()); - auto const d_str = d_strings.element(idx); - auto const nchars = d_str.length(); // number of characters in input string - auto nbytes = d_str.size_bytes(); // number of bytes in input string - auto in_ptr = d_str.data(); // input pointer - auto out_ptr = d_chars ? d_chars + d_offsets[idx] : nullptr; + auto const d_str = d_strings.element(idx); + auto const nchars = d_str.length(); // number of characters in input string + auto nbytes = d_str.size_bytes(); // number of bytes in input string + auto in_ptr = d_str.data(); // input pointer + auto out_ptr = d_chars ? d_chars + d_offsets[idx] : nullptr; + auto itr = d_str.begin(); + auto last_pos = itr; + found_range* d_ranges = d_found_ranges + (idx * number_of_patterns); - size_type lpos = 0; - size_type ch_pos = 0; + // initialize the working ranges memory to -1's thrust::fill(thrust::seq, d_ranges, d_ranges + number_of_patterns, found_range{-1, 1}); + // process string one character at a time - while (ch_pos < nchars) { + while (itr.position() < nchars) { // this minimizes the regex-find calls by only calling it for stale patterns // -- those that have not previously matched up to this point (ch_pos) for (size_type ptn_idx = 0; ptn_idx < number_of_patterns; ++ptn_idx) { - if (d_ranges[ptn_idx].first >= ch_pos) // previously matched here - continue; // or later in the string + if (d_ranges[ptn_idx].first >= itr.position()) { // previously matched here + continue; // or later in the string + } reprog_device prog = progs[ptn_idx]; - auto begin = ch_pos; - auto end = nchars; - if (!prog.is_empty() && prog.find(idx, d_str, begin, end) > 0) - d_ranges[ptn_idx] = found_range{begin, end}; // found a match - else - d_ranges[ptn_idx] = found_range{nchars, nchars}; // this pattern is done + auto const result = !prog.is_empty() ? prog.find(idx, d_str, itr) : thrust::nullopt; + d_ranges[ptn_idx] = + result ? found_range{result->first, result->second} : found_range{nchars, nchars}; } // all the ranges have been updated from each regex match; // look for any that match at this character position (ch_pos) - auto itr = - thrust::find_if(thrust::seq, d_ranges, d_ranges + number_of_patterns, [ch_pos](auto range) { - return range.first == ch_pos; - }); - if (itr != d_ranges + number_of_patterns) { + auto const ptn_itr = + thrust::find_if(thrust::seq, + d_ranges, + d_ranges + number_of_patterns, + [ch_pos = itr.position()](auto range) { return range.first == ch_pos; }); + if (ptn_itr != d_ranges + number_of_patterns) { // match found, compute and replace the string in the output - size_type ptn_idx = static_cast(itr - d_ranges); - size_type begin = d_ranges[ptn_idx].first; - size_type end = d_ranges[ptn_idx].second; - string_view d_repl = d_repls.size() > 1 ? d_repls.element(ptn_idx) - : d_repls.element(0); - auto spos = d_str.byte_offset(begin); - auto epos = d_str.byte_offset(end); - nbytes += d_repl.size_bytes() - (epos - spos); + auto const ptn_idx = static_cast(thrust::distance(d_ranges, ptn_itr)); + + auto d_repl = d_repls.size() > 1 ? d_repls.element(ptn_idx) + : d_repls.element(0); + + auto const d_range = d_ranges[ptn_idx]; + auto const [start_pos, end_pos] = + match_positions_to_bytes({d_range.first, d_range.second}, d_str, last_pos); + nbytes += d_repl.size_bytes() - (end_pos - start_pos); if (out_ptr) { // copy unmodified content plus new replacement string - out_ptr = copy_and_increment(out_ptr, in_ptr + lpos, spos - lpos); + out_ptr = copy_and_increment( + out_ptr, in_ptr + last_pos.byte_offset(), start_pos - last_pos.byte_offset()); out_ptr = copy_string(out_ptr, d_repl); - lpos = epos; } - ch_pos = end - 1; + last_pos += (d_range.second - last_pos.position()); + itr = last_pos - 1; } - ++ch_pos; + ++itr; + } + if (out_ptr) { // copy the remainder + thrust::copy_n(thrust::seq, + in_ptr + last_pos.byte_offset(), + d_str.size_bytes() - last_pos.byte_offset(), + out_ptr); + } else { + d_offsets[idx] = nbytes; } - if (out_ptr) // copy the remainder - memcpy(out_ptr, in_ptr + lpos, d_str.size_bytes() - lpos); - else - d_offsets[idx] = static_cast(nbytes); } }; diff --git a/cpp/src/strings/replace/replace_re.cu b/cpp/src/strings/replace/replace_re.cu index c334d2b2013..460074a5296 100644 --- a/cpp/src/strings/replace/replace_re.cu +++ b/cpp/src/strings/replace/replace_re.cu @@ -42,7 +42,7 @@ struct replace_regex_fn { column_device_view const d_strings; string_view const d_repl; size_type const maxrepl; - int32_t* d_offsets{}; + size_type* d_offsets{}; char* d_chars{}; __device__ void operator()(size_type const idx, reprog_device const prog, int32_t const prog_idx) @@ -54,46 +54,42 @@ struct replace_regex_fn { auto const d_str = d_strings.element(idx); auto const nchars = d_str.length(); - auto nbytes = d_str.size_bytes(); // number of bytes in input string - auto mxn = maxrepl < 0 ? nchars + 1 : maxrepl; // max possible replaces for this string - auto in_ptr = d_str.data(); // input pointer (i) - auto out_ptr = d_chars ? d_chars + d_offsets[idx] // output pointer (o) - : nullptr; - size_type last_pos = 0; - size_type begin = 0; // these are for calling prog.find - size_type end = -1; // matches final word-boundary if at the end of the string + auto nbytes = d_str.size_bytes(); // number of bytes in input string + auto mxn = maxrepl < 0 ? nchars + 1 : maxrepl; // max possible replaces for this string + auto in_ptr = d_str.data(); // input pointer (i) + auto out_ptr = d_chars ? d_chars + d_offsets[idx] // output pointer (o) + : nullptr; + auto itr = d_str.begin(); + auto last_pos = itr; // copy input to output replacing strings as we go - while (mxn-- > 0 && begin <= nchars) { // maximum number of replaces - - if (prog.is_empty() || prog.find(prog_idx, d_str, begin, end) <= 0) { - break; // no more matches - } - - auto const start_pos = d_str.byte_offset(begin); // get offset for these - auto const end_pos = d_str.byte_offset(end); // character position values - nbytes += d_repl.size_bytes() - (end_pos - start_pos); // and compute new size - - if (out_ptr) { // replace: - // i:bbbbsssseeee - out_ptr = copy_and_increment(out_ptr, // ^ - in_ptr + last_pos, // o:bbbb - start_pos - last_pos); // ^ - out_ptr = copy_string(out_ptr, d_repl); // o:bbbbrrrrrr - // out_ptr ---^ - last_pos = end_pos; // i:bbbbsssseeee - } // in_ptr --^ - - begin = end + (begin == end); - end = -1; + while (mxn-- > 0 && itr.position() <= nchars && !prog.is_empty()) { + auto const match = prog.find(prog_idx, d_str, itr); + if (!match) { break; } // no more matches + + auto const [start_pos, end_pos] = match_positions_to_bytes(*match, d_str, last_pos); + nbytes += d_repl.size_bytes() - (end_pos - start_pos); // add new size + + if (out_ptr) { // replace: + // i:bbbbsssseeee + out_ptr = copy_and_increment(out_ptr, // ^ + in_ptr + last_pos.byte_offset(), // o:bbbb + start_pos - last_pos.byte_offset()); // ^ + out_ptr = copy_string(out_ptr, d_repl); // o:bbbbrrrrrr + } // out_ptr ---^ + last_pos += (match->second - last_pos.position()); // i:bbbbsssseeee + // in_ptr --^ + + itr = last_pos + (match->first == match->second); } if (out_ptr) { - memcpy(out_ptr, // copy the remainder - in_ptr + last_pos, // o:bbbbrrrrrreeee - d_str.size_bytes() - last_pos); // ^ ^ + thrust::copy_n(thrust::seq, // copy the remainder + in_ptr + last_pos.byte_offset(), // o:bbbbrrrrrreeee + d_str.size_bytes() - last_pos.byte_offset(), // ^ ^ + out_ptr); } else { - d_offsets[idx] = static_cast(nbytes); + d_offsets[idx] = nbytes; } } }; diff --git a/cpp/src/strings/search/findall.cu b/cpp/src/strings/search/findall.cu index 0c8359928a5..596fbb39d15 100644 --- a/cpp/src/strings/search/findall.cu +++ b/cpp/src/strings/search/findall.cu @@ -62,16 +62,15 @@ struct findall_fn { auto d_output = d_indices + d_offsets[idx]; size_type output_idx = 0; - size_type begin = 0; - size_type end = nchars; - while ((begin < end) && (prog.find(prog_idx, d_str, begin, end) > 0)) { - auto const spos = d_str.byte_offset(begin); // convert - auto const epos = d_str.byte_offset(end); // to bytes + auto itr = d_str.begin(); + while (itr.position() < nchars) { + auto const match = prog.find(prog_idx, d_str, itr); + if (!match) { break; } - d_output[output_idx++] = string_index_pair{d_str.data() + spos, (epos - spos)}; + auto const d_result = string_from_match(*match, d_str, itr); + d_output[output_idx++] = string_index_pair{d_result.data(), d_result.size_bytes()}; - begin = end + (begin == end); - end = nchars; + itr += (match->second - itr.position()); } } }; diff --git a/cpp/src/strings/split/split_re.cu b/cpp/src/strings/split/split_re.cu index 25fe4d00336..f0829eb08ba 100644 --- a/cpp/src/strings/split/split_re.cu +++ b/cpp/src/strings/split/split_re.cu @@ -66,20 +66,25 @@ struct token_reader_fn { __device__ void operator()(size_type const idx, reprog_device const prog, int32_t const prog_idx) { if (d_strings.is_null(idx)) { return; } - auto const d_str = d_strings.element(idx); + auto const d_str = d_strings.element(idx); + auto const nchars = d_str.length(); auto const token_offset = d_token_offsets[idx]; auto const token_count = d_token_offsets[idx + 1] - token_offset; auto const d_result = d_tokens + token_offset; // store tokens here size_type token_idx = 0; - size_type begin = 0; // characters - size_type end = -1; - size_type last_pos = 0; // bytes - while (prog.find(prog_idx, d_str, begin, end) > 0) { + auto itr = d_str.begin(); + auto last_pos = itr; + while (itr.position() <= nchars) { + auto const match = prog.find(prog_idx, d_str, itr); + if (!match) { break; } + + auto const [start_pos, end_pos] = match_positions_to_bytes(*match, d_str, last_pos); + // get the token (characters just before this match) - auto const token = - string_index_pair{d_str.data() + last_pos, d_str.byte_offset(begin) - last_pos}; + auto const token = string_index_pair{d_str.data() + last_pos.byte_offset(), + start_pos - last_pos.byte_offset()}; // store it if we have space if (token_idx < token_count - 1) { d_result[token_idx++] = token; @@ -91,13 +96,13 @@ struct token_reader_fn { d_result[token_idx - 1] = token; } // setup for next match - last_pos = d_str.byte_offset(end); - begin = end + (begin == end); - end = -1; + last_pos += (match->second - last_pos.position()); + itr = last_pos + (match->first == match->second); } // set the last token to the remainder of the string - d_result[token_idx] = string_index_pair{d_str.data() + last_pos, d_str.size_bytes() - last_pos}; + d_result[token_idx] = string_index_pair{d_str.data() + last_pos.byte_offset(), + d_str.size_bytes() - last_pos.byte_offset()}; if (direction == split_direction::BACKWARD) { // update first entry -- this happens when max_tokens is hit before the end of the string diff --git a/cpp/tests/strings/extract_tests.cpp b/cpp/tests/strings/extract_tests.cpp index 312341d6559..70112f7ca75 100644 --- a/cpp/tests/strings/extract_tests.cpp +++ b/cpp/tests/strings/extract_tests.cpp @@ -226,7 +226,7 @@ TEST_F(StringsExtractTests, EmptyExtractTest) TEST_F(StringsExtractTests, ExtractAllTest) { std::vector h_input( - {"123 banana 7 eleven", "41 apple", "6 pear 0 pair", nullptr, "", "bees", "4 pare"}); + {"123 banana 7 eleven", "41 apple", "6 péar 0 pair", nullptr, "", "bees", "4 paré"}); auto validity = thrust::make_transform_iterator(h_input.begin(), [](auto str) { return str != nullptr; }); cudf::test::strings_column_wrapper input(h_input.begin(), h_input.end(), validity); @@ -238,11 +238,11 @@ TEST_F(StringsExtractTests, ExtractAllTest) using LCW = cudf::test::lists_column_wrapper; LCW expected({LCW{"123", "banana", "7", "eleven"}, LCW{"41", "apple"}, - LCW{"6", "pear", "0", "pair"}, + LCW{"6", "péar", "0", "pair"}, LCW{}, LCW{}, LCW{}, - LCW{"4", "pare"}}, + LCW{"4", "paré"}}, valids); auto prog = cudf::strings::regex_program::create(pattern); auto results = cudf::strings::extract_all_record(sv, *prog); diff --git a/cpp/tests/strings/findall_tests.cpp b/cpp/tests/strings/findall_tests.cpp index c7eddb69ee7..fe27beed197 100644 --- a/cpp/tests/strings/findall_tests.cpp +++ b/cpp/tests/strings/findall_tests.cpp @@ -69,12 +69,12 @@ TEST_F(StringsFindallTests, Multiline) TEST_F(StringsFindallTests, DotAll) { - cudf::test::strings_column_wrapper input({"abc\nfa\nef", "fff\nabbc\nfff", "abcdef", ""}); + cudf::test::strings_column_wrapper input({"abc\nfa\nef", "fff\nabbc\nfff", "abcdéf", ""}); auto view = cudf::strings_column_view(input); auto pattern = std::string("(b.*f)"); using LCW = cudf::test::lists_column_wrapper; - LCW expected({LCW{"bc\nfa\nef"}, LCW{"bbc\nfff"}, LCW{"bcdef"}, LCW{}}); + LCW expected({LCW{"bc\nfa\nef"}, LCW{"bbc\nfff"}, LCW{"bcdéf"}, LCW{}}); auto prog = cudf::strings::regex_program::create(pattern, cudf::strings::regex_flags::DOTALL); auto results = cudf::strings::findall(view, *prog); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(results->view(), expected);