diff --git a/cpp/src/strings/replace/replace_re.cu b/cpp/src/strings/replace/replace_re.cu index 9fd1768453a..2c594bb86a8 100644 --- a/cpp/src/strings/replace/replace_re.cu +++ b/cpp/src/strings/replace/replace_re.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -62,37 +62,49 @@ struct replace_regex_fn { if (!d_chars) d_offsets[idx] = 0; return; } - auto const d_str = d_strings.element(idx); - auto const nchars = d_str.length(); // number of characters in input string - auto nbytes = d_str.size_bytes(); // number of bytes in input string - auto mxn = maxrepl < 0 ? nchars : maxrepl; // max possible replaces for this string - auto in_ptr = d_str.data(); // input pointer (i) - auto out_ptr = d_chars ? d_chars + d_offsets[idx] : nullptr; // output pointer (o) - size_type lpos = 0; - int32_t begin = 0; - int32_t end = static_cast(nchars); + + auto const d_str = d_strings.element(idx); + auto nbytes = d_str.size_bytes(); // number of bytes in input string + auto mxn = maxrepl < 0 ? d_str.length() + 1 : maxrepl; // max possible replaces for this string + auto in_ptr = d_str.data(); // input pointer (i) + auto out_ptr = d_chars ? d_chars + d_offsets[idx] // output pointer (o) + : nullptr; + size_type last_pos = 0; + int32_t begin = 0; // these are for calling prog.find + int32_t end = -1; // matches final word-boundary if at the end of the string + // copy input to output replacing strings as we go - while (mxn-- > 0) // maximum number of replaces - { - if (prog.is_empty() || prog.find(idx, d_str, begin, end) <= 0) - break; // no more matches - auto spos = d_str.byte_offset(begin); // get offset for these - auto epos = d_str.byte_offset(end); // character position values - nbytes += d_repl.size_bytes() - (epos - spos); // compute new size - if (out_ptr) // replace - { // i:bbbbsssseeee - out_ptr = copy_and_increment(out_ptr, in_ptr + lpos, spos - lpos); // o:bbbb - out_ptr = copy_string(out_ptr, d_repl); // o:bbbbrrrrrr - // out_ptr ---^ - lpos = epos; // i:bbbbsssseeee - } // in_ptr --^ - begin = end; - end = static_cast(nchars); + while (mxn-- > 0) { // maximum number of replaces + + if (prog.is_empty() || prog.find(idx, d_str, begin, end) <= 0) { + break; // no more matches + } + + auto const start_pos = d_str.byte_offset(begin); // get offset for these + auto const end_pos = d_str.byte_offset(end); // character position values + nbytes += d_repl.size_bytes() - (end_pos - start_pos); // and compute new size + + if (out_ptr) { // replace: + // i:bbbbsssseeee + out_ptr = copy_and_increment(out_ptr, // ^ + in_ptr + last_pos, // o:bbbb + start_pos - last_pos); // ^ + out_ptr = copy_string(out_ptr, d_repl); // o:bbbbrrrrrr + // out_ptr ---^ + last_pos = end_pos; // i:bbbbsssseeee + } // in_ptr --^ + + begin = end + (begin == end); + end = -1; } - if (out_ptr) // copy the remainder - memcpy(out_ptr, in_ptr + lpos, d_str.size_bytes() - lpos); // o:bbbbrrrrrreeee - else + + if (out_ptr) { + memcpy(out_ptr, // copy the remainder + in_ptr + last_pos, // o:bbbbrrrrrreeee + d_str.size_bytes() - last_pos); // ^ ^ + } else { d_offsets[idx] = static_cast(nbytes); + } } }; diff --git a/cpp/tests/strings/replace_regex_tests.cpp b/cpp/tests/strings/replace_regex_tests.cpp index eac06fa4588..ddbd9f5b3d6 100644 --- a/cpp/tests/strings/replace_regex_tests.cpp +++ b/cpp/tests/strings/replace_regex_tests.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -145,6 +145,16 @@ TEST_F(StringsReplaceRegexTest, MultiReplacement) CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*results, input); } +TEST_F(StringsReplaceRegexTest, WordBoundary) +{ + cudf::test::strings_column_wrapper input({"aba bcd\naba", "zéz", "A1B2-é3", "e é"}); + auto results = + cudf::strings::replace_re(cudf::strings_column_view(input), "\\b", cudf::string_scalar("X")); + cudf::test::strings_column_wrapper expected( + {"XabaX XbcdX\nXabaX", "XzézX", "XA1B2X-Xé3X", "XeX XéX"}); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*results, expected); +} + TEST_F(StringsReplaceRegexTest, Multiline) { auto const multiline = cudf::strings::regex_flags::MULTILINE;