From 39a6f2a857b8a0f51a2accdde20d8af4f8659dcd Mon Sep 17 00:00:00 2001 From: davidwendt Date: Mon, 21 Jun 2021 15:19:29 -0400 Subject: [PATCH 1/2] Fix bug in replace_with_backrefs when group has greedy quantifier --- cpp/src/strings/extract.cu | 97 +++++++++++------------ cpp/src/strings/regex/regex.cuh | 33 ++++---- cpp/src/strings/regex/regex.inl | 38 ++++----- cpp/src/strings/replace/backref_re.cu | 12 +-- cpp/src/strings/replace/backref_re.cuh | 25 +++--- cpp/tests/strings/replace_regex_tests.cpp | 17 ++++ 6 files changed, 109 insertions(+), 113 deletions(-) diff --git a/cpp/src/strings/extract.cu b/cpp/src/strings/extract.cu index 423bfff0cbc..d12f5c534a5 100644 --- a/cpp/src/strings/extract.cu +++ b/cpp/src/strings/extract.cu @@ -22,22 +22,21 @@ #include #include #include -#include #include #include #include #include -#include #include -#include - namespace cudf { namespace strings { namespace detail { namespace { + +using string_index_pair = thrust::pair; + /** * @brief This functor handles extracting strings by applying the compiled regex pattern * and creating string_index_pairs for all the substrings. @@ -49,23 +48,25 @@ template struct extract_fn { reprog_device prog; column_device_view d_strings; - cudf::detail::device_2dspan d_indices; + size_type column_index; - __device__ void operator()(size_type idx) + __device__ string_index_pair operator()(size_type idx) { - auto groups = prog.group_counts(); - auto d_output = d_indices[idx]; - if (d_strings.is_valid(idx)) { - string_view d_str = d_strings.element(idx); - int32_t begin = 0; - int32_t end = -1; // handles empty strings automatically - if ((prog.find(idx, d_str, begin, end) > 0) && - prog.extract(idx, d_str, begin, end, d_output)) { - return; + if (d_strings.is_null(idx)) return string_index_pair{nullptr, 0}; + string_view d_str = d_strings.element(idx); + string_index_pair result{nullptr, 0}; + int32_t begin = 0; + int32_t end = -1; // handles empty strings automatically + if (prog.find(idx, d_str, begin, end) > 0) { + auto extracted = prog.extract(idx, d_str, begin, end, column_index); + if (extracted) { + auto const offset = d_str.byte_offset(extracted.value().first); + // build index-pair + result = string_index_pair{d_str.data() + offset, + d_str.byte_offset(extracted.value().second) - offset}; } } - // fill output with null entries - thrust::fill(thrust::seq, d_output.begin(), d_output.end(), string_index_pair{nullptr, 0}); + return result; } }; @@ -93,41 +94,37 @@ std::unique_ptr extract( std::vector> results; auto regex_insts = d_prog.insts_counts(); - rmm::device_uvector indices(strings_count * groups, stream); - cudf::detail::device_2dspan d_indices(indices.data(), strings_count, groups); - - if (regex_insts <= RX_SMALL_INSTS) { - thrust::for_each(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(strings_count), - extract_fn{d_prog, d_strings, d_indices}); - } else if (regex_insts <= RX_MEDIUM_INSTS) { - thrust::for_each(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(strings_count), - extract_fn{d_prog, d_strings, d_indices}); - } else if (regex_insts <= RX_LARGE_INSTS) { - thrust::for_each(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(strings_count), - extract_fn{d_prog, d_strings, d_indices}); - } else { // supports any number of instructions - thrust::for_each(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(strings_count), - extract_fn{d_prog, d_strings, d_indices}); - } - for (int32_t column_index = 0; column_index < groups; ++column_index) { - auto indices_itr = thrust::make_permutation_iterator( - indices.begin(), - thrust::make_transform_iterator(thrust::make_counting_iterator(0), - [column_index, groups] __device__(size_type idx) { - return (idx * groups) + column_index; - })); - results.emplace_back(make_strings_column(indices_itr, indices_itr + strings_count, stream, mr)); - } + rmm::device_uvector indices(strings_count, stream); + + if (regex_insts <= RX_SMALL_INSTS) { + thrust::transform(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(strings_count), + indices.begin(), + extract_fn{d_prog, d_strings, column_index}); + } else if (regex_insts <= RX_MEDIUM_INSTS) { + thrust::transform(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(strings_count), + indices.begin(), + extract_fn{d_prog, d_strings, column_index}); + } else if (regex_insts <= RX_LARGE_INSTS) { + thrust::transform(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(strings_count), + indices.begin(), + extract_fn{d_prog, d_strings, column_index}); + } else { + thrust::transform(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(strings_count), + indices.begin(), + extract_fn{d_prog, d_strings, column_index}); + } + results.emplace_back(make_strings_column(indices, stream, mr)); + } return std::make_unique
(std::move(results)); } diff --git a/cpp/src/strings/regex/regex.cuh b/cpp/src/strings/regex/regex.cuh index 66b90abc393..564f742b2cd 100644 --- a/cpp/src/strings/regex/regex.cuh +++ b/cpp/src/strings/regex/regex.cuh @@ -18,10 +18,10 @@ #include #include -#include #include +#include #include #include @@ -38,7 +38,8 @@ struct reljunk; struct reinst; class reprog; -using string_index_pair = thrust::pair; +using match_pair = thrust::pair; +using match_result = thrust::optional; constexpr int32_t RX_STACK_SMALL = 112; ///< fastest stack size constexpr int32_t RX_STACK_MEDIUM = 1104; ///< faster stack size @@ -176,15 +177,15 @@ class reprog_device { * in the string. * @param end Position index to end the search. If found, returns the last position * matching in the string. - * @param indices All extracted groups - * @return Returns true if successful. + * @param group_id The specific group to return its matching position values. + * @return If valid, returns the character position of the matched group in the given string, */ template - __device__ inline bool extract(int32_t idx, - string_view const& d_str, - int32_t begin, - int32_t end, - device_span indices); + __device__ inline match_result extract(cudf::size_type idx, + string_view const& d_str, + cudf::size_type begin, + cudf::size_type end, + cudf::size_type group_id); private: int32_t _startinst_id, _num_capturing_groups; @@ -198,21 +199,15 @@ class reprog_device { /** * @brief Executes the regex pattern on the given string. */ - __device__ inline int32_t regexec(string_view const& d_str, - reljunk& jnk, - int32_t& begin, - int32_t& end, - string_index_pair* indices = nullptr); + __device__ inline int32_t regexec( + string_view const& d_str, reljunk& jnk, int32_t& begin, int32_t& end, int32_t group_id = 0); /** * @brief Utility wrapper to setup state memory structures for calling regexec */ template - __device__ inline int32_t call_regexec(int32_t idx, - string_view const& d_str, - int32_t& begin, - int32_t& end, - string_index_pair* indices = nullptr); + __device__ inline int32_t call_regexec( + int32_t idx, string_view const& d_str, int32_t& begin, int32_t& end, int32_t group_id = 0); reprog_device(reprog&); // must use create() }; diff --git a/cpp/src/strings/regex/regex.inl b/cpp/src/strings/regex/regex.inl index 5c9f389827f..eddda3fe0eb 100644 --- a/cpp/src/strings/regex/regex.inl +++ b/cpp/src/strings/regex/regex.inl @@ -194,7 +194,7 @@ __device__ inline int32_t* reprog_device::startinst_ids() const { return _starti * @return >0 if match found */ __device__ inline int32_t reprog_device::regexec( - string_view const& dstr, reljunk& jnk, int32_t& begin, int32_t& end, string_index_pair* indices) + string_view const& dstr, reljunk& jnk, int32_t& begin, int32_t& end, int32_t group_id) { int32_t match = 0; auto checkstart = jnk.starttype; @@ -231,7 +231,7 @@ __device__ inline int32_t reprog_device::regexec( if (((eos < 0) || (pos < eos)) && match == 0) { int32_t i = 0; auto ids = startinst_ids(); - while (ids[i] >= 0) jnk.list1->activate(ids[i++], (indices == nullptr ? pos : -1), -1); + while (ids[i] >= 0) jnk.list1->activate(ids[i++], (group_id == 0 ? pos : -1), -1); } c = static_cast(pos >= txtlen ? 0 : *itr); @@ -256,20 +256,14 @@ __device__ inline int32_t reprog_device::regexec( case NCCLASS: case END: id_activate = inst_id; break; case LBRA: - if (indices && inst->u1.subid == _num_capturing_groups) range.x = pos; + if (inst->u1.subid == group_id) range.x = pos; id_activate = inst->u2.next_id; expanded = true; - if (indices) { indices[inst->u1.subid - 1].first = dstr.data() + itr.byte_offset(); } break; case RBRA: - if (indices && inst->u1.subid == _num_capturing_groups) range.y = pos; + if (inst->u1.subid == group_id) range.y = pos; id_activate = inst->u2.next_id; expanded = true; - if (indices) { - auto const ptr_offset = indices[inst->u1.subid - 1].first - dstr.data(); - indices[inst->u1.subid - 1].second = - itr.byte_offset() - static_cast(ptr_offset); - } break; case BOL: if ((pos == 0) || @@ -352,7 +346,7 @@ __device__ inline int32_t reprog_device::regexec( case END: match = 1; begin = range.x; - end = indices == nullptr ? pos : range.y; + end = group_id == 0 ? pos : range.y; continue_execute = false; break; @@ -382,19 +376,21 @@ __device__ inline int32_t reprog_device::find(int32_t idx, } template -__device__ inline bool reprog_device::extract(int32_t idx, - string_view const& dstr, - int32_t begin, - int32_t end, - device_span indices) +__device__ inline match_result reprog_device::extract(cudf::size_type idx, + string_view const& dstr, + cudf::size_type begin, + cudf::size_type end, + cudf::size_type group_id) { end = begin + 1; - return call_regexec(idx, dstr, begin, end, indices.data()) > 0; + return call_regexec(idx, dstr, begin, end, group_id + 1) > 0 + ? match_result({begin, end}) + : thrust::nullopt; } template __device__ inline int32_t reprog_device::call_regexec( - int32_t idx, string_view const& dstr, int32_t& begin, int32_t& end, string_index_pair* indices) + int32_t idx, string_view const& dstr, int32_t& begin, int32_t& end, int32_t group_id) { u_char data1[stack_size], data2[stack_size]; @@ -405,12 +401,12 @@ __device__ inline int32_t reprog_device::call_regexec( relist list2(static_cast(_insts_count), data2); reljunk jnk(&list1, &list2, stype, schar); - return regexec(dstr, jnk, begin, end, indices); + return regexec(dstr, jnk, begin, end, group_id); } template <> __device__ inline int32_t reprog_device::call_regexec( - int32_t idx, string_view const& dstr, int32_t& begin, int32_t& end, string_index_pair* indices) + int32_t idx, string_view const& dstr, int32_t& begin, int32_t& end, int32_t group_id) { auto const stype = get_inst(_startinst_id)->type; auto const schar = get_inst(_startinst_id)->u1.c; @@ -423,7 +419,7 @@ __device__ inline int32_t reprog_device::call_regexec( relist* list2 = new (listmem + relists_size) relist(static_cast(_insts_count)); reljunk jnk(list1, list2, stype, schar); - return regexec(dstr, jnk, begin, end, indices); + return regexec(dstr, jnk, begin, end, group_id); } } // namespace detail diff --git a/cpp/src/strings/replace/backref_re.cu b/cpp/src/strings/replace/backref_re.cu index 99a2f4f78c7..462efedffe5 100644 --- a/cpp/src/strings/replace/backref_re.cu +++ b/cpp/src/strings/replace/backref_re.cu @@ -111,35 +111,31 @@ std::unique_ptr replace_with_backrefs( // create child columns auto [offsets, chars] = [&] { - rmm::device_uvector indices(strings.size() * d_prog->group_counts(), stream); - cudf::detail::device_2dspan d_indices( - indices.data(), strings.size(), d_prog->group_counts()); - if (regex_insts <= RX_SMALL_INSTS) { return make_strings_children( backrefs_fn{ - *d_strings, *d_prog, d_repl_template, backrefs.begin(), backrefs.end(), d_indices}, + *d_strings, *d_prog, d_repl_template, backrefs.begin(), backrefs.end()}, strings.size(), stream, mr); } else if (regex_insts <= RX_MEDIUM_INSTS) { return make_strings_children( backrefs_fn{ - *d_strings, *d_prog, d_repl_template, backrefs.begin(), backrefs.end(), d_indices}, + *d_strings, *d_prog, d_repl_template, backrefs.begin(), backrefs.end()}, strings.size(), stream, mr); } else if (regex_insts <= RX_LARGE_INSTS) { return make_strings_children( backrefs_fn{ - *d_strings, *d_prog, d_repl_template, backrefs.begin(), backrefs.end(), d_indices}, + *d_strings, *d_prog, d_repl_template, backrefs.begin(), backrefs.end()}, strings.size(), stream, mr); } else { return make_strings_children( backrefs_fn{ - *d_strings, *d_prog, d_repl_template, backrefs.begin(), backrefs.end(), d_indices}, + *d_strings, *d_prog, d_repl_template, backrefs.begin(), backrefs.end()}, strings.size(), stream, mr); diff --git a/cpp/src/strings/replace/backref_re.cuh b/cpp/src/strings/replace/backref_re.cuh index 0456f9a5998..9b84b21a44e 100644 --- a/cpp/src/strings/replace/backref_re.cuh +++ b/cpp/src/strings/replace/backref_re.cuh @@ -18,7 +18,6 @@ #include #include #include -#include #include @@ -43,7 +42,6 @@ struct backrefs_fn { string_view const d_repl; // string replacement template Iterator backrefs_begin; Iterator backrefs_end; - cudf::detail::device_2dspan d_indices; int32_t* d_offsets{}; char* d_chars{}; @@ -62,9 +60,6 @@ struct backrefs_fn { size_type begin = 0; // first character position matching regex size_type end = nchars; // last character position (exclusive) - // working memory for extract on this string - auto d_extracts = d_indices[idx]; - // copy input to output replacing strings as we go while (prog.find(idx, d_str, begin, end) > 0) // inits the begin/end vars { @@ -77,23 +72,23 @@ struct backrefs_fn { size_type lpos_template = 0; // last end pos of replace template auto const repl_ptr = d_repl.data(); // replace template pattern - // extracts all groups for this string into d_extracts - prog.extract(idx, d_str, begin, end, d_extracts); - thrust::for_each( thrust::seq, backrefs_begin, backrefs_end, [&] __device__(backref_type backref) { - // copy the static data at the beginning of the template if (out_ptr) { auto const copy_length = backref.second - lpos_template; out_ptr = copy_and_increment(out_ptr, repl_ptr + lpos_template, copy_length); lpos_template += copy_length; } - // retrieve the string for this backref - auto const extracted_string = d_extracts[backref.first - 1]; - nbytes += extracted_string.second; - if (out_ptr) { - out_ptr = copy_and_increment(out_ptr, extracted_string.first, extracted_string.second); - } + // extract the specific group's string for this backref's index + auto extracted = prog.extract(idx, d_str, begin, end, backref.first - 1); + if (!extracted || (extracted.value().second <= extracted.value().first)) + return; // no value for this backref number; that is ok + auto spos_extract = d_str.byte_offset(extracted.value().first); // convert + auto epos_extract = d_str.byte_offset(extracted.value().second); // to bytes + nbytes += epos_extract - spos_extract; + if (out_ptr) + out_ptr = + copy_and_increment(out_ptr, in_ptr + spos_extract, (epos_extract - spos_extract)); }); // copy remainder of template diff --git a/cpp/tests/strings/replace_regex_tests.cpp b/cpp/tests/strings/replace_regex_tests.cpp index c7543e29b0a..a2486d60051 100644 --- a/cpp/tests/strings/replace_regex_tests.cpp +++ b/cpp/tests/strings/replace_regex_tests.cpp @@ -186,6 +186,23 @@ TEST_F(StringsReplaceTests, ReplaceBackrefsRegexReversedTest) CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); } +TEST_F(StringsReplaceTests, BackrefWithGreedyQuantifier) +{ + cudf::test::strings_column_wrapper input( + {"

title

ABC

", "

1234567

XYZ

"}); + std::string replacement = "

\\1

\\2

"; + + auto results = cudf::strings::replace_with_backrefs( + cudf::strings_column_view(input), "

(.*)

(.*)

", replacement); + cudf::test::strings_column_wrapper expected( + {"

title

ABC

", "

1234567

XYZ

"}); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); + + results = cudf::strings::replace_with_backrefs( + cudf::strings_column_view(input), "

([a-z\\d]+)

([A-Z]+)

", replacement); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); +} + TEST_F(StringsReplaceTests, MediumReplaceRegex) { // This results in 95 regex instructions and falls in the 'medium' range. From 1b07d05d6f961ad4ddf29feaff1efc1d01f80742 Mon Sep 17 00:00:00 2001 From: davidwendt Date: Thu, 24 Jun 2021 11:13:42 -0400 Subject: [PATCH 2/2] add more brackets --- cpp/src/strings/replace/backref_re.cuh | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/cpp/src/strings/replace/backref_re.cuh b/cpp/src/strings/replace/backref_re.cuh index 9b84b21a44e..eba5c3f1044 100644 --- a/cpp/src/strings/replace/backref_re.cuh +++ b/cpp/src/strings/replace/backref_re.cuh @@ -68,7 +68,7 @@ struct backrefs_fn { nbytes += d_repl.size_bytes() - (epos - spos); // compute the output size // copy the string data before the matched section - if (out_ptr) out_ptr = copy_and_increment(out_ptr, in_ptr + lpos, spos - lpos); + if (out_ptr) { out_ptr = copy_and_increment(out_ptr, in_ptr + lpos, spos - lpos); } size_type lpos_template = 0; // last end pos of replace template auto const repl_ptr = d_repl.data(); // replace template pattern @@ -81,20 +81,23 @@ struct backrefs_fn { } // extract the specific group's string for this backref's index auto extracted = prog.extract(idx, d_str, begin, end, backref.first - 1); - if (!extracted || (extracted.value().second <= extracted.value().first)) + if (!extracted || (extracted.value().second <= extracted.value().first)) { return; // no value for this backref number; that is ok + } auto spos_extract = d_str.byte_offset(extracted.value().first); // convert auto epos_extract = d_str.byte_offset(extracted.value().second); // to bytes nbytes += epos_extract - spos_extract; - if (out_ptr) + if (out_ptr) { out_ptr = copy_and_increment(out_ptr, in_ptr + spos_extract, (epos_extract - spos_extract)); + } }); // copy remainder of template - if (out_ptr && (lpos_template < d_repl.size_bytes())) + if (out_ptr && (lpos_template < d_repl.size_bytes())) { out_ptr = copy_and_increment( out_ptr, repl_ptr + lpos_template, d_repl.size_bytes() - lpos_template); + } // setup to match the next section lpos = epos; @@ -103,10 +106,11 @@ struct backrefs_fn { } // finally, copy remainder of input string - if (out_ptr && (lpos < d_str.size_bytes())) + if (out_ptr && (lpos < d_str.size_bytes())) { memcpy(out_ptr, in_ptr + lpos, d_str.size_bytes() - lpos); - else if (!out_ptr) + } else if (!out_ptr) { d_offsets[idx] = static_cast(nbytes); + } } };