diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 12b812d0bbe..c50464762c0 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -547,6 +547,7 @@ add_library( src/strings/regex/regex_program.cpp src/strings/repeat_strings.cu src/strings/replace/backref_re.cu + src/strings/replace/multi.cu src/strings/replace/multi_re.cu src/strings/replace/replace.cu src/strings/replace/replace_re.cu diff --git a/cpp/benchmarks/string/replace.cpp b/cpp/benchmarks/string/replace.cpp index b25af14ec2a..cb570020f0e 100644 --- a/cpp/benchmarks/string/replace.cpp +++ b/cpp/benchmarks/string/replace.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -69,7 +69,7 @@ static void generate_bench_args(benchmark::internal::Benchmark* b) int const row_mult = 8; int const min_rowlen = 1 << 5; int const max_rowlen = 1 << 13; - int const len_mult = 4; + int const len_mult = 2; generate_string_bench_args(b, min_rows, max_rows, row_mult, min_rowlen, max_rowlen, len_mult); } diff --git a/cpp/src/strings/replace/multi.cu b/cpp/src/strings/replace/multi.cu new file mode 100644 index 00000000000..92ace4e7bc7 --- /dev/null +++ b/cpp/src/strings/replace/multi.cu @@ -0,0 +1,500 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace cudf { +namespace strings { +namespace detail { +namespace { + +/** + * @brief Threshold to decide on using string or character-parallel functions. + * + * If the average byte length of a string in a column exceeds this value then + * the character-parallel function is used. + * Otherwise, a regular string-parallel function is used. + * + * This value was found using the replace-multi benchmark results using an + * RTX A6000. + */ +constexpr size_type AVG_CHAR_BYTES_THRESHOLD = 256; + +/** + * @brief Type used for holding the target position (first) and the + * target index (second). + */ +using target_pair = thrust::pair; + +/** + * @brief Helper functions for performing character-parallel replace + */ +struct replace_multi_parallel_fn { + __device__ char const* get_base_ptr() const + { + return d_strings.child(strings_column_view::chars_column_index).data(); + } + + __device__ size_type const* get_offsets_ptr() const + { + return d_strings.child(strings_column_view::offsets_column_index).data() + + d_strings.offset(); + } + + __device__ string_view const get_string(size_type idx) const + { + return d_strings.element(idx); + } + + __device__ string_view const get_replacement_string(size_type idx) const + { + return d_replacements.size() == 1 ? d_replacements[0] : d_replacements[idx]; + } + + __device__ bool is_valid(size_type idx) const { return d_strings.is_valid(idx); } + + /** + * @brief Returns the index of the target string found at the given byte position + * in the input strings column + * + * @param idx Index of the byte position in the chars column + * @param chars_bytes Number of bytes in the chars column + */ + __device__ thrust::optional has_target(size_type idx, size_type chars_bytes) const + { + auto const d_offsets = get_offsets_ptr(); + auto const d_chars = get_base_ptr() + d_offsets[0] + idx; + size_type str_idx = -1; + for (std::size_t t = 0; t < d_targets.size(); ++t) { + auto const d_tgt = d_targets[t]; + if (!d_tgt.empty() && (idx + d_tgt.size_bytes() <= chars_bytes) && + (d_tgt.compare(d_chars, d_tgt.size_bytes()) == 0)) { + if (str_idx < 0) { + auto const idx_itr = + thrust::upper_bound(thrust::seq, d_offsets, d_offsets + d_strings.size(), idx); + str_idx = thrust::distance(d_offsets, idx_itr) - 1; + } + auto const d_str = get_string(str_idx - d_offsets[0]); + if ((d_chars + d_tgt.size_bytes()) <= (d_str.data() + d_str.size_bytes())) { return t; } + } + } + return thrust::nullopt; + } + + /** + * @brief Count the number of strings that will be produced by the replace + * + * This includes segments of the string that are not replaced as well as those + * that are replaced. + * + * @param idx Index of the row in d_strings to be processed + * @param d_positions Positions of the targets found in the chars column + * @param d_targets_offsets Offsets identify which target positions go with the current string + * @return Number of substrings resulting from the replace operations on this row + */ + __device__ size_type count_strings(size_type idx, + target_pair const* d_positions, + size_type const* d_targets_offsets) const + { + if (!is_valid(idx)) { return 0; } + + auto const d_str = get_string(idx); + auto const d_str_end = d_str.data() + d_str.size_bytes(); + auto const base_ptr = get_base_ptr(); + auto const targets_positions = cudf::device_span( + d_positions + d_targets_offsets[idx], d_targets_offsets[idx + 1] - d_targets_offsets[idx]); + + size_type count = 1; // always at least one string + auto str_ptr = d_str.data(); + for (auto d_pair : targets_positions) { + auto const d_pos = d_pair.first; + auto const d_tgt = d_targets[d_pair.second]; + auto const tgt_ptr = base_ptr + d_pos; + if (str_ptr <= tgt_ptr && tgt_ptr < d_str_end) { + auto const keep_size = static_cast(thrust::distance(str_ptr, tgt_ptr)); + if (keep_size > 0) { count++; } // don't bother counting empty strings + + auto const d_repl = get_replacement_string(d_pair.second); + if (!d_repl.empty()) { count++; } + + str_ptr += keep_size + d_tgt.size_bytes(); + } + } + + return count; + } + + /** + * @brief Retrieve the strings for each row + * + * This will return string segments as string_index_pair objects for + * parts of the string that are not replaced interlaced with the + * appropriate replacement string where replacement targets are found. + * + * This function is called only once to produce both the string_index_pair objects + * and the output row size in bytes. + * + * @param idx Index of the row in d_strings + * @param d_offsets Offsets to identify where to store the results of the replace for this string + * @param d_positions The target positions found in the chars column + * @param d_targets_offsets The offsets to identify which target positions go with this string + * @param d_all_strings The output of all the produced string segments + * @return The size in bytes of the output string for this row + */ + __device__ size_type get_strings(size_type idx, + size_type const* d_offsets, + target_pair const* d_positions, + size_type const* d_targets_offsets, + string_index_pair* d_all_strings) const + { + if (!is_valid(idx)) { return 0; } + + auto const d_output = d_all_strings + d_offsets[idx]; + auto const d_str = get_string(idx); + auto const d_str_end = d_str.data() + d_str.size_bytes(); + auto const base_ptr = get_base_ptr(); + + auto const targets_positions = cudf::device_span( + d_positions + d_targets_offsets[idx], d_targets_offsets[idx + 1] - d_targets_offsets[idx]); + + size_type output_idx = 0; + size_type output_size = 0; + auto str_ptr = d_str.data(); + for (auto d_pair : targets_positions) { + auto const d_pos = d_pair.first; + auto const d_tgt = d_targets[d_pair.second]; + auto const tgt_ptr = base_ptr + d_pos; + if (str_ptr <= tgt_ptr && tgt_ptr < d_str_end) { + auto const keep_size = static_cast(thrust::distance(str_ptr, tgt_ptr)); + if (keep_size > 0) { d_output[output_idx++] = string_index_pair{str_ptr, keep_size}; } + output_size += keep_size; + + auto const d_repl = get_replacement_string(d_pair.second); + if (!d_repl.empty()) { + d_output[output_idx++] = string_index_pair{d_repl.data(), d_repl.size_bytes()}; + } + output_size += d_repl.size_bytes(); + + str_ptr += keep_size + d_tgt.size_bytes(); + } + } + // include any leftover parts of the string + if (str_ptr <= d_str_end) { + auto const left_size = static_cast(thrust::distance(str_ptr, d_str_end)); + d_output[output_idx] = string_index_pair{str_ptr, left_size}; + output_size += left_size; + } + return output_size; + } + + replace_multi_parallel_fn(column_device_view const& d_strings, + device_span d_targets, + device_span d_replacements) + : d_strings(d_strings), d_targets{d_targets}, d_replacements{d_replacements} + { + } + + protected: + column_device_view d_strings; + device_span d_targets; + device_span d_replacements; +}; + +/** + * @brief Used by the copy-if function to produce target_pair objects + * + * Using an inplace lambda caused a runtime crash in thrust::copy_if + * (this happens sometimes when passing device lambdas to thrust algorithms) + */ +struct pair_generator { + __device__ target_pair operator()(int idx) const + { + auto pos = fn.has_target(idx, chars_bytes); + return target_pair{idx, pos.value_or(-1)}; + } + replace_multi_parallel_fn fn; + size_type chars_bytes; +}; + +struct copy_if_fn { + __device__ bool operator()(target_pair pos) { return pos.second >= 0; } +}; + +std::unique_ptr replace_character_parallel(strings_column_view const& input, + strings_column_view const& targets, + strings_column_view const& repls, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + auto d_strings = column_device_view::create(input.parent(), stream); + + auto const strings_count = input.size(); + auto const chars_bytes = + cudf::detail::get_value(input.offsets(), input.offset() + strings_count, stream) - + cudf::detail::get_value(input.offsets(), input.offset(), stream); + + auto d_targets = + create_string_vector_from_column(targets, stream, rmm::mr::get_current_device_resource()); + auto d_replacements = + create_string_vector_from_column(repls, stream, rmm::mr::get_current_device_resource()); + + replace_multi_parallel_fn fn{*d_strings, d_targets, d_replacements}; + + // count the number of targets in the entire column + auto const target_count = thrust::count_if(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(chars_bytes), + [fn, chars_bytes] __device__(size_type idx) { + return fn.has_target(idx, chars_bytes).has_value(); + }); + // Create a vector of every target position in the chars column. + // These may include overlapping targets which will be resolved later. + auto targets_positions = rmm::device_uvector(target_count, stream); + auto d_positions = targets_positions.data(); + + auto const copy_itr = + cudf::detail::make_counting_transform_iterator(0, pair_generator{fn, chars_bytes}); + auto const copy_end = thrust::copy_if( + rmm::exec_policy(stream), copy_itr, copy_itr + chars_bytes, d_positions, copy_if_fn{}); + + // create a vector of offsets to each string's set of target positions + auto const targets_offsets = [&] { + auto string_indices = rmm::device_uvector(target_count, stream); + + auto const pos_itr = cudf::detail::make_counting_transform_iterator( + 0, [d_positions] __device__(auto idx) -> size_type { return d_positions[idx].first; }); + auto pos_count = std::distance(d_positions, copy_end); + + thrust::upper_bound(rmm::exec_policy(stream), + input.offsets_begin(), + input.offsets_end(), + pos_itr, + pos_itr + pos_count, + string_indices.begin()); + + // compute offsets per string + auto targets_offsets = rmm::device_uvector(strings_count + 1, stream); + auto d_targets_offsets = targets_offsets.data(); + + // memset to zero-out the target counts for any null-entries or strings with no targets + thrust::uninitialized_fill( + rmm::exec_policy(stream), targets_offsets.begin(), targets_offsets.end(), 0); + + // next, count the number of targets per string + auto d_string_indices = string_indices.data(); + thrust::for_each_n(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + target_count, + [d_string_indices, d_targets_offsets] __device__(size_type idx) { + auto const str_idx = d_string_indices[idx] - 1; + atomicAdd(d_targets_offsets + str_idx, 1); + }); + // finally, convert the counts into offsets + thrust::exclusive_scan(rmm::exec_policy(stream), + targets_offsets.begin(), + targets_offsets.end(), + targets_offsets.begin()); + return targets_offsets; + }(); + auto const d_targets_offsets = targets_offsets.data(); + + // compute the number of string segments produced by replace in each string + auto counts = rmm::device_uvector(strings_count, stream); + thrust::transform(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(strings_count), + counts.begin(), + [fn, d_positions, d_targets_offsets] __device__(size_type idx) -> size_type { + return fn.count_strings(idx, d_positions, d_targets_offsets); + }); + + // create offsets from the counts + auto offsets = + std::get<0>(cudf::detail::make_offsets_child_column(counts.begin(), counts.end(), stream, mr)); + auto const total_strings = + cudf::detail::get_value(offsets->view(), strings_count, stream); + auto const d_strings_offsets = offsets->view().data(); + + // build a vector of all the positions for all the strings + auto indices = rmm::device_uvector(total_strings, stream); + auto d_indices = indices.data(); + auto d_sizes = counts.data(); // reusing this vector to hold output sizes now + thrust::for_each_n( + rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + strings_count, + [fn, d_strings_offsets, d_positions, d_targets_offsets, d_indices, d_sizes] __device__( + size_type idx) { + d_sizes[idx] = + fn.get_strings(idx, d_strings_offsets, d_positions, d_targets_offsets, d_indices); + }); + + // use this utility to gather the string parts into a contiguous chars column + auto chars = make_strings_column(indices.begin(), indices.end(), stream, mr); + + // create offsets from the sizes + offsets = + std::get<0>(cudf::detail::make_offsets_child_column(counts.begin(), counts.end(), stream, mr)); + + // build the strings columns from the chars and offsets + return make_strings_column(strings_count, + std::move(offsets), + std::move(chars->release().children.back()), + input.null_count(), + copy_bitmask(input.parent(), stream, mr)); +} + +/** + * @brief Function logic for the replace_string_parallel + * + * Performs the multi-replace operation with a thread per string. + * This performs best on smaller strings. @see AVG_CHAR_BYTES_THRESHOLD + */ +struct replace_multi_fn { + column_device_view const d_strings; + column_device_view const d_targets; + column_device_view const d_repls; + int32_t* d_offsets{}; + char* d_chars{}; + + __device__ void operator()(size_type idx) + { + if (d_strings.is_null(idx)) { + if (!d_chars) { d_offsets[idx] = 0; } + return; + } + auto const d_str = d_strings.element(idx); + char const* in_ptr = d_str.data(); + + size_type bytes = d_str.size_bytes(); + size_type spos = 0; + size_type lpos = 0; + char* out_ptr = d_chars ? d_chars + d_offsets[idx] : nullptr; + + // check each character against each target + while (spos < d_str.size_bytes()) { + for (int tgt_idx = 0; tgt_idx < d_targets.size(); ++tgt_idx) { + auto const d_tgt = d_targets.element(tgt_idx); + if ((d_tgt.size_bytes() <= (d_str.size_bytes() - spos)) && // check fit + (d_tgt.compare(in_ptr + spos, d_tgt.size_bytes()) == 0)) // and match + { + auto const d_repl = (d_repls.size() == 1) ? d_repls.element(0) + : d_repls.element(tgt_idx); + bytes += d_repl.size_bytes() - d_tgt.size_bytes(); + if (out_ptr) { + out_ptr = copy_and_increment(out_ptr, in_ptr + lpos, spos - lpos); + out_ptr = copy_string(out_ptr, d_repl); + lpos = spos + d_tgt.size_bytes(); + } + spos += d_tgt.size_bytes() - 1; + break; + } + } + ++spos; + } + if (out_ptr) // copy remainder + memcpy(out_ptr, in_ptr + lpos, d_str.size_bytes() - lpos); + else + d_offsets[idx] = bytes; + } +}; + +std::unique_ptr replace_string_parallel(strings_column_view const& input, + strings_column_view const& targets, + strings_column_view const& repls, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + auto d_strings = column_device_view::create(input.parent(), stream); + auto d_targets = column_device_view::create(targets.parent(), stream); + auto d_replacements = column_device_view::create(repls.parent(), stream); + + auto children = cudf::strings::detail::make_strings_children( + replace_multi_fn{*d_strings, *d_targets, *d_replacements}, input.size(), stream, mr); + + return make_strings_column(input.size(), + std::move(children.first), + std::move(children.second), + input.null_count(), + cudf::detail::copy_bitmask(input.parent(), stream, mr)); +} + +} // namespace + +std::unique_ptr replace(strings_column_view const& input, + strings_column_view const& targets, + strings_column_view const& repls, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + if (input.is_empty()) { return make_empty_column(type_id::STRING); } + CUDF_EXPECTS(((targets.size() > 0) && (targets.null_count() == 0)), + "Parameters targets must not be empty and must not have nulls"); + CUDF_EXPECTS(((repls.size() > 0) && (repls.null_count() == 0)), + "Parameters repls must not be empty and must not have nulls"); + if (repls.size() > 1) + CUDF_EXPECTS(repls.size() == targets.size(), "Sizes for targets and repls must match"); + + return (input.size() == input.null_count() || + ((input.chars_size() / (input.size() - input.null_count())) < AVG_CHAR_BYTES_THRESHOLD)) + ? replace_string_parallel(input, targets, repls, stream, mr) + : replace_character_parallel(input, targets, repls, stream, mr); +} + +} // namespace detail + +// external API + +std::unique_ptr replace(strings_column_view const& strings, + strings_column_view const& targets, + strings_column_view const& repls, + rmm::mr::device_memory_resource* mr) +{ + CUDF_FUNC_RANGE(); + return detail::replace(strings, targets, repls, cudf::get_default_stream(), mr); +} + +} // namespace strings +} // namespace cudf diff --git a/cpp/src/strings/replace/replace.cu b/cpp/src/strings/replace/replace.cu index d1a377a4bda..3fc969a4c1f 100644 --- a/cpp/src/strings/replace/replace.cu +++ b/cpp/src/strings/replace/replace.cu @@ -704,92 +704,6 @@ std::unique_ptr replace_slice(strings_column_view const& strings, cudf::detail::copy_bitmask(strings.parent(), stream, mr)); } -namespace { -/** - * @brief Function logic for the replace_multi API. - * - * This will perform the multi-replace operation on each string. - */ -struct replace_multi_fn { - column_device_view const d_strings; - column_device_view const d_targets; - column_device_view const d_repls; - int32_t* d_offsets{}; - char* d_chars{}; - - __device__ void operator()(size_type idx) - { - if (d_strings.is_null(idx)) { - if (!d_chars) d_offsets[idx] = 0; - return; - } - auto const d_str = d_strings.element(idx); - char const* in_ptr = d_str.data(); - - size_type bytes = d_str.size_bytes(); - size_type spos = 0; - size_type lpos = 0; - char* out_ptr = d_chars ? d_chars + d_offsets[idx] : nullptr; - - // check each character against each target - while (spos < d_str.size_bytes()) { - for (int tgt_idx = 0; tgt_idx < d_targets.size(); ++tgt_idx) { - auto const d_tgt = d_targets.element(tgt_idx); - if ((d_tgt.size_bytes() <= (d_str.size_bytes() - spos)) && // check fit - (d_tgt.compare(in_ptr + spos, d_tgt.size_bytes()) == 0)) // and match - { - auto const d_repl = (d_repls.size() == 1) ? d_repls.element(0) - : d_repls.element(tgt_idx); - bytes += d_repl.size_bytes() - d_tgt.size_bytes(); - if (out_ptr) { - out_ptr = copy_and_increment(out_ptr, in_ptr + lpos, spos - lpos); - out_ptr = copy_string(out_ptr, d_repl); - lpos = spos + d_tgt.size_bytes(); - } - spos += d_tgt.size_bytes() - 1; - break; - } - } - ++spos; - } - if (out_ptr) // copy remainder - memcpy(out_ptr, in_ptr + lpos, d_str.size_bytes() - lpos); - else - d_offsets[idx] = bytes; - } -}; - -} // namespace - -std::unique_ptr replace(strings_column_view const& strings, - strings_column_view const& targets, - strings_column_view const& repls, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) -{ - if (strings.is_empty()) return make_empty_column(type_id::STRING); - CUDF_EXPECTS(((targets.size() > 0) && (targets.null_count() == 0)), - "Parameters targets must not be empty and must not have nulls"); - CUDF_EXPECTS(((repls.size() > 0) && (repls.null_count() == 0)), - "Parameters repls must not be empty and must not have nulls"); - if (repls.size() > 1) - CUDF_EXPECTS(repls.size() == targets.size(), "Sizes for targets and repls must match"); - - auto d_strings = column_device_view::create(strings.parent(), stream); - auto d_targets = column_device_view::create(targets.parent(), stream); - auto d_repls = column_device_view::create(repls.parent(), stream); - - // this utility calls the given functor to build the offsets and chars columns - auto children = cudf::strings::detail::make_strings_children( - replace_multi_fn{*d_strings, *d_targets, *d_repls}, strings.size(), stream, mr); - - return make_strings_column(strings.size(), - std::move(children.first), - std::move(children.second), - strings.null_count(), - cudf::detail::copy_bitmask(strings.parent(), stream, mr)); -} - std::unique_ptr replace_nulls(strings_column_view const& strings, string_scalar const& repl, rmm::cuda_stream_view stream, @@ -854,14 +768,5 @@ std::unique_ptr replace_slice(strings_column_view const& strings, return detail::replace_slice(strings, repl, start, stop, cudf::get_default_stream(), mr); } -std::unique_ptr replace(strings_column_view const& strings, - strings_column_view const& targets, - strings_column_view const& repls, - rmm::mr::device_memory_resource* mr) -{ - CUDF_FUNC_RANGE(); - return detail::replace(strings, targets, repls, cudf::get_default_stream(), mr); -} - } // namespace strings } // namespace cudf diff --git a/cpp/tests/strings/replace_tests.cpp b/cpp/tests/strings/replace_tests.cpp index 32e097838c0..85185b2deab 100644 --- a/cpp/tests/strings/replace_tests.cpp +++ b/cpp/tests/strings/replace_tests.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -290,28 +290,22 @@ TEST_F(StringsReplaceTest, ReplaceSlice) TEST_F(StringsReplaceTest, ReplaceSliceError) { - std::vector h_strings{"Héllo", "thesé", nullptr, "are not", "important", ""}; - cudf::test::strings_column_wrapper strings( - h_strings.begin(), - h_strings.end(), - thrust::make_transform_iterator(h_strings.begin(), [](auto str) { return str != nullptr; })); - auto strings_view = cudf::strings_column_view(strings); - EXPECT_THROW(cudf::strings::replace_slice(strings_view, cudf::string_scalar(""), 4, 1), - cudf::logic_error); + cudf::test::strings_column_wrapper input({"Héllo", "thesé", "are not", "important", ""}); + EXPECT_THROW( + cudf::strings::replace_slice(cudf::strings_column_view(input), cudf::string_scalar(""), 4, 1), + cudf::logic_error); } TEST_F(StringsReplaceTest, ReplaceMulti) { - auto strings = build_corpus(); - auto strings_view = cudf::strings_column_view(strings); + auto input = build_corpus(); + auto strings_view = cudf::strings_column_view(input); - std::vector h_targets{"the ", "a ", "to "}; - cudf::test::strings_column_wrapper targets(h_targets.begin(), h_targets.end()); + cudf::test::strings_column_wrapper targets({"the ", "a ", "to "}); auto targets_view = cudf::strings_column_view(targets); { - std::vector h_repls{"_ ", "A ", "2 "}; - cudf::test::strings_column_wrapper repls(h_repls.begin(), h_repls.end()); + cudf::test::strings_column_wrapper repls({"_ ", "A ", "2 "}); auto repls_view = cudf::strings_column_view(repls); auto results = cudf::strings::replace(strings_view, targets_view, repls_view); @@ -331,8 +325,7 @@ TEST_F(StringsReplaceTest, ReplaceMulti) } { - std::vector h_repls{"* "}; - cudf::test::strings_column_wrapper repls(h_repls.begin(), h_repls.end()); + cudf::test::strings_column_wrapper repls({"* "}); auto repls_view = cudf::strings_column_view(repls); auto results = cudf::strings::replace(strings_view, targets_view, repls_view); @@ -352,6 +345,129 @@ TEST_F(StringsReplaceTest, ReplaceMulti) } } +TEST_F(StringsReplaceTest, ReplaceMultiLong) +{ + // The length of the strings are to trigger the code path governed by the AVG_CHAR_BYTES_THRESHOLD + // setting in the multi.cu. + auto input = cudf::test::strings_column_wrapper( + {"This string needs to be very long to trigger the long-replace internal functions. " + "This string needs to be very long to trigger the long-replace internal functions. " + "This string needs to be very long to trigger the long-replace internal functions. " + "This string needs to be very long to trigger the long-replace internal functions.", + "012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012" + "345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345" + "678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678" + "901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901" + "2345678901234567890123456789", + "012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012" + "345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345" + "678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678" + "901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901" + "2345678901234567890123456789", + "Test string for overlap check: bananaápple bananá ápplebananá banápple ápple bananá " + "Test string for overlap check: bananaápple bananá ápplebananá banápple ápple bananá " + "Test string for overlap check: bananaápple bananá ápplebananá banápple ápple bananá " + "Test string for overlap check: bananaápple bananá ápplebananá banápple ápple bananá " + "Test string for overlap check: bananaápple bananá ápplebananá banápple ápple bananá", + "", + ""}, + {1, 1, 1, 1, 0, 1}); + auto strings_view = cudf::strings_column_view(input); + + auto targets = cudf::test::strings_column_wrapper({"78901", "bananá", "ápple", "78"}); + auto targets_view = cudf::strings_column_view(targets); + + { + cudf::test::strings_column_wrapper repls({"x", "PEAR", "avocado", "$$"}); + auto repls_view = cudf::strings_column_view(repls); + + auto results = cudf::strings::replace(strings_view, targets_view, repls_view); + + cudf::test::strings_column_wrapper expected( + {"This string needs to be very long to trigger the long-replace internal functions. " + "This string needs to be very long to trigger the long-replace internal functions. " + "This string needs to be very long to trigger the long-replace internal functions. " + "This string needs to be very long to trigger the long-replace internal functions.", + "0123456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456" + "x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x" + "23456x23456x23456x23456x23456x23456x23456x23456x23456x23456$$9", + "0123456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456" + "x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x" + "23456x23456x23456x23456x23456x23456x23456x23456x23456x23456$$9", + "Test string for overlap check: bananaavocado PEAR avocadoPEAR banavocado avocado PEAR " + "Test string for overlap check: bananaavocado PEAR avocadoPEAR banavocado avocado PEAR " + "Test string for overlap check: bananaavocado PEAR avocadoPEAR banavocado avocado PEAR " + "Test string for overlap check: bananaavocado PEAR avocadoPEAR banavocado avocado PEAR " + "Test string for overlap check: bananaavocado PEAR avocadoPEAR banavocado avocado PEAR", + "", + ""}, + {1, 1, 1, 1, 0, 1}); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); + } + + { + cudf::test::strings_column_wrapper repls({"*"}); + auto repls_view = cudf::strings_column_view(repls); + + auto results = cudf::strings::replace(strings_view, targets_view, repls_view); + + cudf::test::strings_column_wrapper expected( + {"This string needs to be very long to trigger the long-replace internal functions. " + "This string needs to be very long to trigger the long-replace internal functions. " + "This string needs to be very long to trigger the long-replace internal functions. " + "This string needs to be very long to trigger the long-replace internal functions.", + "0123456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*" + "23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*" + "23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*9", + "0123456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*" + "23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*" + "23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*9", + "Test string for overlap check: banana* * ** ban* * * Test string for overlap check: " + "banana* * ** ban* * * Test string for overlap check: banana* * ** ban* * * Test string for " + "overlap check: banana* * ** ban* * * Test string for overlap check: banana* * ** ban* * *", + "", + ""}, + {1, 1, 1, 1, 0, 1}); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); + } + + { + targets = + cudf::test::strings_column_wrapper({"01234567890123456789012345678901234567890123456789012345" + "6789012345678901234567890123456789012" + "34567890123456789012345678901234567890123456789012345678" + "9012345678901234567890123456789012345" + "67890123456789012345678901234567890123456789012345678901" + "2345678901234567890123456789012345678" + "90123456789012345678901234567890123456789012345678901234" + "5678901234567890123456789012345678901" + "2345678901234567890123456789", + "78"}); + targets_view = cudf::strings_column_view(targets); + auto repls = cudf::test::strings_column_wrapper({""}); + auto repls_view = cudf::strings_column_view(repls); + + auto results = cudf::strings::replace(strings_view, targets_view, repls_view); + + cudf::test::strings_column_wrapper expected( + {"This string needs to be very long to trigger the long-replace internal functions. " + "This string needs to be very long to trigger the long-replace internal functions. " + "This string needs to be very long to trigger the long-replace internal functions. " + "This string needs to be very long to trigger the long-replace internal functions.", + "", + "", + "Test string for overlap check: bananaápple bananá ápplebananá banápple ápple bananá " + "Test string for overlap check: bananaápple bananá ápplebananá banápple ápple bananá " + "Test string for overlap check: bananaápple bananá ápplebananá banápple ápple bananá " + "Test string for overlap check: bananaápple bananá ápplebananá banápple ápple bananá " + "Test string for overlap check: bananaápple bananá ápplebananá banápple ápple bananá", + "", + ""}, + {1, 1, 1, 1, 0, 1}); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); + } +} + TEST_F(StringsReplaceTest, EmptyStringsColumn) { cudf::column_view zero_size_strings_column(