From 3ea64000c1a6dc610f5d2f8435c1bb24f8f8493b Mon Sep 17 00:00:00 2001 From: David Wendt Date: Mon, 27 Feb 2023 17:52:13 -0500 Subject: [PATCH 1/9] Improve performance for replace-multi for long strings --- cpp/CMakeLists.txt | 1 + cpp/src/strings/replace/multi.cu | 422 ++++++++++++++++++++++++++++ cpp/src/strings/replace/replace.cu | 95 ------- cpp/tests/strings/replace_tests.cpp | 91 ++++-- 4 files changed, 498 insertions(+), 111 deletions(-) create mode 100644 cpp/src/strings/replace/multi.cu diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index d402a47628c..1928dceae21 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -541,6 +541,7 @@ add_library( src/strings/regex/regex_program.cpp src/strings/repeat_strings.cu src/strings/replace/backref_re.cu + src/strings/replace/multi.cu src/strings/replace/multi_re.cu src/strings/replace/replace.cu src/strings/replace/replace_re.cu diff --git a/cpp/src/strings/replace/multi.cu b/cpp/src/strings/replace/multi.cu new file mode 100644 index 00000000000..5a4602998db --- /dev/null +++ b/cpp/src/strings/replace/multi.cu @@ -0,0 +1,422 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace cudf { +namespace strings { +namespace detail { +namespace { + +/** + * @brief + */ +constexpr size_type AVG_CHAR_BYTES_THRESHOLD = 64; + +using target_pair = thrust::pair; + +struct replace_multi_parallel_fn { + __device__ char const* get_base_ptr() const + { + return d_strings.child(strings_column_view::chars_column_index).data(); + } + + __device__ string_view const get_string(size_type idx) const + { + return d_strings.element(idx); + } + + __device__ bool is_valid(size_type idx) const { return d_strings.is_valid(idx); } + + __device__ thrust::optional has_target(size_type idx, + size_type const* d_offsets, + size_type chars_bytes) const + { + auto const d_chars = get_base_ptr() + d_offsets[0] + idx; + size_type str_idx = -1; + for (std::size_t t = 0; t < d_targets.size(); ++t) { + auto const d_tgt = d_targets[t]; + if (!d_tgt.empty() && (idx + d_tgt.size_bytes() <= chars_bytes) && + (d_tgt.compare(d_chars, d_tgt.size_bytes()) == 0)) { + if (str_idx < 0) { + auto const idx_itr = + thrust::upper_bound(thrust::seq, d_offsets, d_offsets + d_strings.size(), idx); + str_idx = thrust::distance(d_offsets, idx_itr) - 1; + } + auto d_str = get_string(str_idx - d_offsets[0]); + if ((d_chars + d_tgt.size_bytes()) <= (d_str.data() + d_str.size_bytes())) { return t; } + } + } + return thrust::nullopt; + } + + __device__ size_type count_strings(size_type idx, + target_pair const* d_positions, + size_type const* d_targets_offsets) const + { + if (!is_valid(idx)) { return 0; } + + auto const d_str = get_string(idx); + auto const d_str_end = d_str.data() + d_str.size_bytes(); + auto const base_ptr = get_base_ptr(); //+ delim_size - 1; + auto const targets_positions = cudf::device_span( + d_positions + d_targets_offsets[idx], d_targets_offsets[idx + 1] - d_targets_offsets[idx]); + + size_type count = 1; + auto str_ptr = d_str.data(); + for (auto d_pair : targets_positions) { + auto const d_pos = d_pair.first; + auto const d_tgt = d_targets[d_pair.second]; + auto const tgt_ptr = base_ptr + d_pos; + if (str_ptr <= tgt_ptr && tgt_ptr < d_str_end) { + auto const keep_size = static_cast(thrust::distance(str_ptr, tgt_ptr)); + if (keep_size > 0) { count++; } + + auto const d_repl = + d_replacements.size() == 1 ? d_replacements[0] : d_replacements[d_pair.second]; + if (!d_repl.empty()) { count++; } + + str_ptr += keep_size + d_tgt.size_bytes(); + } + } + // if (str_ptr + 1 < d_str_end) count++; + return count; + } + + __device__ size_type get_strings(size_type idx, + size_type const* d_offsets, + target_pair const* d_positions, + size_type const* d_targets_offsets, + string_index_pair* d_all_strings) const + { + if (!is_valid(idx)) { return 0; } + + auto const d_output = d_all_strings + d_offsets[idx]; + auto const d_output_count = d_offsets[idx + 1] - d_offsets[idx]; + + auto const d_str = get_string(idx); + if (d_output_count == 1) { + d_output[0] = string_index_pair{d_str.data(), d_str.size_bytes()}; + return d_str.size_bytes(); + } + + auto const d_str_end = d_str.data() + d_str.size_bytes(); + auto const base_ptr = get_base_ptr(); //+ delim_size - 1; + auto const targets_positions = cudf::device_span( + d_positions + d_targets_offsets[idx], d_targets_offsets[idx + 1] - d_targets_offsets[idx]); + + size_type output_idx = 0; + size_type output_size = 0; + auto str_ptr = d_str.data(); + for (auto d_pair : targets_positions) { + auto const d_pos = d_pair.first; + auto const d_tgt = d_targets[d_pair.second]; + auto const tgt_ptr = base_ptr + d_pos; + if (str_ptr <= tgt_ptr && tgt_ptr < d_str_end) { + auto const keep_size = static_cast(thrust::distance(str_ptr, tgt_ptr)); + if (keep_size > 0) { d_output[output_idx++] = string_index_pair{str_ptr, keep_size}; } + output_size += keep_size; + + auto const d_repl = + d_replacements.size() == 1 ? d_replacements[0] : d_replacements[d_pair.second]; + if (!d_repl.empty()) { + d_output[output_idx++] = string_index_pair{d_repl.data(), d_repl.size_bytes()}; + } + output_size += d_repl.size_bytes(); + + str_ptr += keep_size + d_tgt.size_bytes(); + } + } + // include any leftover parts of the string + if (str_ptr <= d_str_end) { + auto const left_size = static_cast(thrust::distance(str_ptr, d_str_end)); + d_output[output_idx++] = string_index_pair{str_ptr, left_size}; + output_size += left_size; + } + return output_size; + } + + replace_multi_parallel_fn(column_device_view const& d_strings, + device_span const d_targets, + device_span const d_replacements) + : d_strings(d_strings), d_targets{d_targets}, d_replacements{d_replacements} + { + } + + protected: + column_device_view d_strings; + device_span d_targets; + device_span d_replacements; +}; + +/** + * @brief Function logic for the replace_multi API. + * + * This will perform the multi-replace operation on each string. + */ +struct replace_multi_fn { + column_device_view const d_strings; + column_device_view const d_targets; + column_device_view const d_repls; + int32_t* d_offsets{}; + char* d_chars{}; + + __device__ void operator()(size_type idx) + { + if (d_strings.is_null(idx)) { + if (!d_chars) d_offsets[idx] = 0; + return; + } + auto const d_str = d_strings.element(idx); + char const* in_ptr = d_str.data(); + + size_type bytes = d_str.size_bytes(); + size_type spos = 0; + size_type lpos = 0; + char* out_ptr = d_chars ? d_chars + d_offsets[idx] : nullptr; + + // check each character against each target + while (spos < d_str.size_bytes()) { + for (int tgt_idx = 0; tgt_idx < d_targets.size(); ++tgt_idx) { + auto const d_tgt = d_targets.element(tgt_idx); + if ((d_tgt.size_bytes() <= (d_str.size_bytes() - spos)) && // check fit + (d_tgt.compare(in_ptr + spos, d_tgt.size_bytes()) == 0)) // and match + { + auto const d_repl = (d_repls.size() == 1) ? d_repls.element(0) + : d_repls.element(tgt_idx); + bytes += d_repl.size_bytes() - d_tgt.size_bytes(); + if (out_ptr) { + out_ptr = copy_and_increment(out_ptr, in_ptr + lpos, spos - lpos); + out_ptr = copy_string(out_ptr, d_repl); + lpos = spos + d_tgt.size_bytes(); + } + spos += d_tgt.size_bytes() - 1; + break; + } + } + ++spos; + } + if (out_ptr) // copy remainder + memcpy(out_ptr, in_ptr + lpos, d_str.size_bytes() - lpos); + else + d_offsets[idx] = bytes; + } +}; + +} // namespace + +std::unique_ptr replace(strings_column_view const& input, + strings_column_view const& targets, + strings_column_view const& repls, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + if (input.is_empty()) return make_empty_column(type_id::STRING); + CUDF_EXPECTS(((targets.size() > 0) && (targets.null_count() == 0)), + "Parameters targets must not be empty and must not have nulls"); + CUDF_EXPECTS(((repls.size() > 0) && (repls.null_count() == 0)), + "Parameters repls must not be empty and must not have nulls"); + if (repls.size() > 1) + CUDF_EXPECTS(repls.size() == targets.size(), "Sizes for targets and repls must match"); + + auto d_strings = column_device_view::create(input.parent(), stream); + + if (input.size() == input.null_count() || + ((input.chars_size() / (input.size() - input.null_count())) < AVG_CHAR_BYTES_THRESHOLD)) { + auto d_targets = column_device_view::create(targets.parent(), stream); + auto d_replacements = column_device_view::create(repls.parent(), stream); + + auto children = cudf::strings::detail::make_strings_children( + replace_multi_fn{*d_strings, *d_targets, *d_replacements}, input.size(), stream, mr); + + return make_strings_column(input.size(), + std::move(children.first), + std::move(children.second), + input.null_count(), + cudf::detail::copy_bitmask(input.parent(), stream, mr)); + } + + auto const strings_count = input.size(); + auto const chars_bytes = + cudf::detail::get_value(input.offsets(), input.offset() + strings_count, stream) - + cudf::detail::get_value(input.offsets(), input.offset(), stream); + + auto d_offsets = input.offsets_begin(); + + auto d_targets = + create_string_vector_from_column(targets, stream, rmm::mr::get_current_device_resource()); + auto d_replacements = + create_string_vector_from_column(repls, stream, rmm::mr::get_current_device_resource()); + + replace_multi_parallel_fn fn{*d_strings, d_targets, d_replacements}; + + // count the number of targets in the entire column + auto const target_count = + thrust::count_if(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(chars_bytes), + [fn, d_offsets, chars_bytes] __device__(size_type idx) { + return fn.has_target(idx, d_offsets, chars_bytes).has_value(); + }); + // Create a vector of every target position in the chars column. + // These may include overlapping targets which will be resolved later. + auto targets_positions = rmm::device_uvector(target_count, stream); + auto d_positions = targets_positions.data(); + + auto copy_itr = cudf::detail::make_counting_transform_iterator( + 0, [fn, d_offsets, chars_bytes] __device__(auto idx) -> target_pair { + auto pos = fn.has_target(idx, d_offsets, chars_bytes); + return target_pair{idx, pos.value_or(-1)}; + }); + auto const copy_end = thrust::copy_if(rmm::exec_policy(stream), + copy_itr, + copy_itr + chars_bytes, + targets_positions.begin(), + [] __device__(auto pos) { return pos.second >= 0; }); + + // create a vector of offsets to each string's set of target positions + auto const targets_offsets = [&] { + auto string_indices = rmm::device_uvector(target_count, stream); + + auto pos_itr = cudf::detail::make_counting_transform_iterator( + 0, [d_positions] __device__(auto idx) -> size_type { return d_positions[idx].first; }); + auto pos_count = std::distance(d_positions, copy_end); + + thrust::upper_bound(rmm::exec_policy(stream), + d_offsets, + d_offsets + strings_count, + pos_itr, + pos_itr + pos_count, + string_indices.begin()); + + // compute offsets per string + auto targets_offsets = rmm::device_uvector(strings_count + 1, stream); + auto d_targets_offsets = targets_offsets.data(); + + // memset to zero-out the target counts for any null-entries or strings with no targets + CUDF_CUDA_TRY(cudaMemsetAsync( + d_targets_offsets, 0, targets_offsets.size() * sizeof(size_type), stream.value())); + + // next, count the number of targes per string + auto d_string_indices = string_indices.data(); + thrust::for_each_n(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + target_count, + [d_string_indices, d_targets_offsets] __device__(size_type idx) { + auto const str_idx = d_string_indices[idx] - 1; + atomicAdd(d_targets_offsets + str_idx, 1); + }); + // finally, convert the counts into offsets + thrust::exclusive_scan(rmm::exec_policy(stream), + targets_offsets.begin(), + targets_offsets.end(), + targets_offsets.begin()); + return targets_offsets; + }(); + auto const d_targets_offsets = targets_offsets.data(); + + // compute the output count of each output string + auto counts = rmm::device_uvector(strings_count, stream); + thrust::transform(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(strings_count), + counts.begin(), + [fn, d_positions, d_targets_offsets] __device__(size_type idx) -> size_type { + return fn.count_strings(idx, d_positions, d_targets_offsets); + }); + + // create offsets from the counts + auto offsets = + std::get<0>(cudf::detail::make_offsets_child_column(counts.begin(), counts.end(), stream, mr)); + auto const total_strings = + cudf::detail::get_value(offsets->view(), strings_count, stream); + auto const d_strings_offsets = offsets->view().data(); + + // build a vector of all the positions for all the strings + auto indices = rmm::device_uvector(total_strings, stream); + auto d_indices = indices.data(); + auto d_sizes = counts.data(); + thrust::for_each_n( + rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + strings_count, + [fn, d_strings_offsets, d_positions, d_targets_offsets, d_indices, d_sizes] __device__( + size_type idx) { + d_sizes[idx] = + fn.get_strings(idx, d_strings_offsets, d_positions, d_targets_offsets, d_indices); + }); + + // use this utility to gather the string parts into a contiguous chars column + auto chars = make_strings_column(indices.begin(), indices.end(), stream, mr); + + // create offsets from the sizes + offsets = + std::get<0>(cudf::detail::make_offsets_child_column(counts.begin(), counts.end(), stream, mr)); + + // build the strings columns from the chars and offsets + return make_strings_column(strings_count, + std::move(offsets), + std::move(chars->release().children.back()), + input.null_count(), + copy_bitmask(input.parent(), stream, mr)); +} + +} // namespace detail + +// external API + +std::unique_ptr replace(strings_column_view const& strings, + strings_column_view const& targets, + strings_column_view const& repls, + rmm::mr::device_memory_resource* mr) +{ + CUDF_FUNC_RANGE(); + return detail::replace(strings, targets, repls, cudf::get_default_stream(), mr); +} + +} // namespace strings +} // namespace cudf diff --git a/cpp/src/strings/replace/replace.cu b/cpp/src/strings/replace/replace.cu index d1a377a4bda..3fc969a4c1f 100644 --- a/cpp/src/strings/replace/replace.cu +++ b/cpp/src/strings/replace/replace.cu @@ -704,92 +704,6 @@ std::unique_ptr replace_slice(strings_column_view const& strings, cudf::detail::copy_bitmask(strings.parent(), stream, mr)); } -namespace { -/** - * @brief Function logic for the replace_multi API. - * - * This will perform the multi-replace operation on each string. - */ -struct replace_multi_fn { - column_device_view const d_strings; - column_device_view const d_targets; - column_device_view const d_repls; - int32_t* d_offsets{}; - char* d_chars{}; - - __device__ void operator()(size_type idx) - { - if (d_strings.is_null(idx)) { - if (!d_chars) d_offsets[idx] = 0; - return; - } - auto const d_str = d_strings.element(idx); - char const* in_ptr = d_str.data(); - - size_type bytes = d_str.size_bytes(); - size_type spos = 0; - size_type lpos = 0; - char* out_ptr = d_chars ? d_chars + d_offsets[idx] : nullptr; - - // check each character against each target - while (spos < d_str.size_bytes()) { - for (int tgt_idx = 0; tgt_idx < d_targets.size(); ++tgt_idx) { - auto const d_tgt = d_targets.element(tgt_idx); - if ((d_tgt.size_bytes() <= (d_str.size_bytes() - spos)) && // check fit - (d_tgt.compare(in_ptr + spos, d_tgt.size_bytes()) == 0)) // and match - { - auto const d_repl = (d_repls.size() == 1) ? d_repls.element(0) - : d_repls.element(tgt_idx); - bytes += d_repl.size_bytes() - d_tgt.size_bytes(); - if (out_ptr) { - out_ptr = copy_and_increment(out_ptr, in_ptr + lpos, spos - lpos); - out_ptr = copy_string(out_ptr, d_repl); - lpos = spos + d_tgt.size_bytes(); - } - spos += d_tgt.size_bytes() - 1; - break; - } - } - ++spos; - } - if (out_ptr) // copy remainder - memcpy(out_ptr, in_ptr + lpos, d_str.size_bytes() - lpos); - else - d_offsets[idx] = bytes; - } -}; - -} // namespace - -std::unique_ptr replace(strings_column_view const& strings, - strings_column_view const& targets, - strings_column_view const& repls, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) -{ - if (strings.is_empty()) return make_empty_column(type_id::STRING); - CUDF_EXPECTS(((targets.size() > 0) && (targets.null_count() == 0)), - "Parameters targets must not be empty and must not have nulls"); - CUDF_EXPECTS(((repls.size() > 0) && (repls.null_count() == 0)), - "Parameters repls must not be empty and must not have nulls"); - if (repls.size() > 1) - CUDF_EXPECTS(repls.size() == targets.size(), "Sizes for targets and repls must match"); - - auto d_strings = column_device_view::create(strings.parent(), stream); - auto d_targets = column_device_view::create(targets.parent(), stream); - auto d_repls = column_device_view::create(repls.parent(), stream); - - // this utility calls the given functor to build the offsets and chars columns - auto children = cudf::strings::detail::make_strings_children( - replace_multi_fn{*d_strings, *d_targets, *d_repls}, strings.size(), stream, mr); - - return make_strings_column(strings.size(), - std::move(children.first), - std::move(children.second), - strings.null_count(), - cudf::detail::copy_bitmask(strings.parent(), stream, mr)); -} - std::unique_ptr replace_nulls(strings_column_view const& strings, string_scalar const& repl, rmm::cuda_stream_view stream, @@ -854,14 +768,5 @@ std::unique_ptr replace_slice(strings_column_view const& strings, return detail::replace_slice(strings, repl, start, stop, cudf::get_default_stream(), mr); } -std::unique_ptr replace(strings_column_view const& strings, - strings_column_view const& targets, - strings_column_view const& repls, - rmm::mr::device_memory_resource* mr) -{ - CUDF_FUNC_RANGE(); - return detail::replace(strings, targets, repls, cudf::get_default_stream(), mr); -} - } // namespace strings } // namespace cudf diff --git a/cpp/tests/strings/replace_tests.cpp b/cpp/tests/strings/replace_tests.cpp index 32e097838c0..1048334acc0 100644 --- a/cpp/tests/strings/replace_tests.cpp +++ b/cpp/tests/strings/replace_tests.cpp @@ -173,6 +173,22 @@ TEST_F(StringsReplaceTest, ReplaceTargetOverlap) CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); } +TEST_F(StringsReplaceTest, ReplaceTargetOverlap2) +{ + auto input = cudf::test::strings_column_wrapper({"banana", "nanananananana"}); + auto strings_view = cudf::strings_column_view(input); + + auto stream = cudf::get_default_stream(); + auto mr = rmm::mr::get_current_device_resource(); + + auto results = cudf::strings::detail::replace( + strings_view, cudf::string_scalar("nana"), cudf::string_scalar(""), -1, stream, mr); + cudf::test::print(results->view()); + results = cudf::strings::detail::replace( + strings_view, cudf::string_scalar("nana"), cudf::string_scalar(""), -1, stream, mr); + cudf::test::print(results->view()); +} + TEST_F(StringsReplaceTest, ReplaceTargetOverlapsStrings) { auto input = build_corpus(); @@ -290,28 +306,22 @@ TEST_F(StringsReplaceTest, ReplaceSlice) TEST_F(StringsReplaceTest, ReplaceSliceError) { - std::vector h_strings{"Héllo", "thesé", nullptr, "are not", "important", ""}; - cudf::test::strings_column_wrapper strings( - h_strings.begin(), - h_strings.end(), - thrust::make_transform_iterator(h_strings.begin(), [](auto str) { return str != nullptr; })); - auto strings_view = cudf::strings_column_view(strings); - EXPECT_THROW(cudf::strings::replace_slice(strings_view, cudf::string_scalar(""), 4, 1), - cudf::logic_error); + cudf::test::strings_column_wrapper input({"Héllo", "thesé", "are not", "important", ""}); + EXPECT_THROW( + cudf::strings::replace_slice(cudf::strings_column_view(input), cudf::string_scalar(""), 4, 1), + cudf::logic_error); } TEST_F(StringsReplaceTest, ReplaceMulti) { - auto strings = build_corpus(); - auto strings_view = cudf::strings_column_view(strings); + auto input = build_corpus(); + auto strings_view = cudf::strings_column_view(input); - std::vector h_targets{"the ", "a ", "to "}; - cudf::test::strings_column_wrapper targets(h_targets.begin(), h_targets.end()); + cudf::test::strings_column_wrapper targets({"the ", "a ", "to "}); auto targets_view = cudf::strings_column_view(targets); { - std::vector h_repls{"_ ", "A ", "2 "}; - cudf::test::strings_column_wrapper repls(h_repls.begin(), h_repls.end()); + cudf::test::strings_column_wrapper repls({"_ ", "A ", "2 "}); auto repls_view = cudf::strings_column_view(repls); auto results = cudf::strings::replace(strings_view, targets_view, repls_view); @@ -331,8 +341,7 @@ TEST_F(StringsReplaceTest, ReplaceMulti) } { - std::vector h_repls{"* "}; - cudf::test::strings_column_wrapper repls(h_repls.begin(), h_repls.end()); + cudf::test::strings_column_wrapper repls({"* "}); auto repls_view = cudf::strings_column_view(repls); auto results = cudf::strings::replace(strings_view, targets_view, repls_view); @@ -352,6 +361,56 @@ TEST_F(StringsReplaceTest, ReplaceMulti) } } +TEST_F(StringsReplaceTest, ReplaceMultiLong) +{ + auto input = cudf::test::strings_column_wrapper( + {"This string needs to be very long to trigger the long-replace internal functions.", + "01234567890123456789012345678901234567890123456789012345678901234567890123456789", + "01234567890123456789012345678901234567890123456789012345678901234567890123456789", + "Test string for overlap check: bananaápple bananá ápplebananá banápple ápple bananá", + "", + ""}, + {1, 1, 1, 1, 0, 1}); + auto strings_view = cudf::strings_column_view(input); + + cudf::test::strings_column_wrapper targets({"901", "bananá", "ápple"}); + auto targets_view = cudf::strings_column_view(targets); + + { + cudf::test::strings_column_wrapper repls({"x", "PEAR", "avocado"}); + auto repls_view = cudf::strings_column_view(repls); + + auto results = cudf::strings::replace(strings_view, targets_view, repls_view); + + cudf::test::strings_column_wrapper expected( + {"This string needs to be very long to trigger the long-replace internal functions.", + "012345678x2345678x2345678x2345678x2345678x2345678x2345678x23456789", + "012345678x2345678x2345678x2345678x2345678x2345678x2345678x23456789", + "Test string for overlap check: bananaavocado PEAR avocadoPEAR banavocado avocado PEAR", + "", + ""}, + {1, 1, 1, 1, 0, 1}); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); + } + + { + cudf::test::strings_column_wrapper repls({"*"}); + auto repls_view = cudf::strings_column_view(repls); + + auto results = cudf::strings::replace(strings_view, targets_view, repls_view); + + cudf::test::strings_column_wrapper expected( + {"This string needs to be very long to trigger the long-replace internal functions.", + "012345678*2345678*2345678*2345678*2345678*2345678*2345678*23456789", + "012345678*2345678*2345678*2345678*2345678*2345678*2345678*23456789", + "Test string for overlap check: banana* * ** ban* * *", + "", + ""}, + {1, 1, 1, 1, 0, 1}); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); + } +} + TEST_F(StringsReplaceTest, EmptyStringsColumn) { cudf::column_view zero_size_strings_column( From 530faa94d5e0923bd0229daf21360b5c155b8482 Mon Sep 17 00:00:00 2001 From: David Wendt Date: Tue, 28 Feb 2023 17:24:09 -0500 Subject: [PATCH 2/9] fix copy-if by converting lambda to functor --- cpp/benchmarks/string/replace.cpp | 2 +- cpp/src/strings/replace/multi.cu | 149 ++++++++++++++++++++-------- cpp/tests/strings/replace_tests.cpp | 2 +- 3 files changed, 110 insertions(+), 43 deletions(-) diff --git a/cpp/benchmarks/string/replace.cpp b/cpp/benchmarks/string/replace.cpp index b25af14ec2a..2794acd5945 100644 --- a/cpp/benchmarks/string/replace.cpp +++ b/cpp/benchmarks/string/replace.cpp @@ -69,7 +69,7 @@ static void generate_bench_args(benchmark::internal::Benchmark* b) int const row_mult = 8; int const min_rowlen = 1 << 5; int const max_rowlen = 1 << 13; - int const len_mult = 4; + int const len_mult = 2; generate_string_bench_args(b, min_rows, max_rows, row_mult, min_rowlen, max_rowlen, len_mult); } diff --git a/cpp/src/strings/replace/multi.cu b/cpp/src/strings/replace/multi.cu index 5a4602998db..fa8727cddcb 100644 --- a/cpp/src/strings/replace/multi.cu +++ b/cpp/src/strings/replace/multi.cu @@ -53,31 +53,61 @@ namespace detail { namespace { /** - * @brief + * @brief Threshold to decide on using string or character-parallel functions. + * + * If the average byte length of a string in a column exceeds this value then + * the character-parallel function is used. + * Otherwise, a regular string-parallel function is used. + * + * This value was found using the replace-multi benchmark results. */ -constexpr size_type AVG_CHAR_BYTES_THRESHOLD = 64; +constexpr size_type AVG_CHAR_BYTES_THRESHOLD = 256; +/** + * @brief Type used for holding the target position (first) and the + * target index (second). + */ using target_pair = thrust::pair; +/** + * @brief Helper functions for performing character-parallel replace + */ struct replace_multi_parallel_fn { __device__ char const* get_base_ptr() const { return d_strings.child(strings_column_view::chars_column_index).data(); } + __device__ size_type const* get_offsets_ptr() const + { + return d_strings.child(strings_column_view::offsets_column_index).data() + + d_strings.offset(); + } + __device__ string_view const get_string(size_type idx) const { return d_strings.element(idx); } + __device__ string_view const get_replacement_string(size_type idx) const + { + return d_replacements.size() == 1 ? d_replacements[0] : d_replacements[idx]; + } + __device__ bool is_valid(size_type idx) const { return d_strings.is_valid(idx); } - __device__ thrust::optional has_target(size_type idx, - size_type const* d_offsets, - size_type chars_bytes) const + /** + * @brief Returns the index of the target string found at the given byte position + * in the input strings column + * + * @param idx Index of the byte position in the chars column + * @param chars_bytes Number of bytes in the chars column + */ + __device__ thrust::optional has_target(size_type idx, size_type chars_bytes) const { - auto const d_chars = get_base_ptr() + d_offsets[0] + idx; - size_type str_idx = -1; + auto const d_offsets = get_offsets_ptr(); + auto const d_chars = get_base_ptr() + d_offsets[0] + idx; + size_type str_idx = -1; for (std::size_t t = 0; t < d_targets.size(); ++t) { auto const d_tgt = d_targets[t]; if (!d_tgt.empty() && (idx + d_tgt.size_bytes() <= chars_bytes) && @@ -94,6 +124,17 @@ struct replace_multi_parallel_fn { return thrust::nullopt; } + /** + * @brief Count the number of strings that will be produced by the replace + * + * This includes segments of the string that are not replaced as well as those + * that are replaced. + * + * @param idx Index of the row in d_strings to be processed + * @param d_positions Positions of the targets found in the chars column + * @param d_targets_offsets Offsets identify which target positions go with the current string + * @return Number of substrings resulting from the replace operations on this row + */ __device__ size_type count_strings(size_type idx, target_pair const* d_positions, size_type const* d_targets_offsets) const @@ -102,11 +143,11 @@ struct replace_multi_parallel_fn { auto const d_str = get_string(idx); auto const d_str_end = d_str.data() + d_str.size_bytes(); - auto const base_ptr = get_base_ptr(); //+ delim_size - 1; + auto const base_ptr = get_base_ptr(); auto const targets_positions = cudf::device_span( d_positions + d_targets_offsets[idx], d_targets_offsets[idx + 1] - d_targets_offsets[idx]); - size_type count = 1; + size_type count = 1; // always at least one string auto str_ptr = d_str.data(); for (auto d_pair : targets_positions) { auto const d_pos = d_pair.first; @@ -114,19 +155,35 @@ struct replace_multi_parallel_fn { auto const tgt_ptr = base_ptr + d_pos; if (str_ptr <= tgt_ptr && tgt_ptr < d_str_end) { auto const keep_size = static_cast(thrust::distance(str_ptr, tgt_ptr)); - if (keep_size > 0) { count++; } + if (keep_size > 0) { count++; } // don't bother counting empty strings - auto const d_repl = - d_replacements.size() == 1 ? d_replacements[0] : d_replacements[d_pair.second]; + auto const d_repl = get_replacement_string(d_pair.second); if (!d_repl.empty()) { count++; } str_ptr += keep_size + d_tgt.size_bytes(); } } - // if (str_ptr + 1 < d_str_end) count++; + return count; } + /** + * @brief Retrieve the strings for each row + * + * This will return string segments as string_index_pair objects for + * parts of the string that are not replaced interlaced with the + * appropriate replacement string where replacement targets are found. + * + * This function is called only once to produce both the string_index_pair objects + * and the output row size in bytes. + * + * @param idx Index of the row in d_strings + * @param d_offsets Offsets to identify where to store the results of the replace for this string + * @param d_positions The target positions found in the chars column + * @param d_targets_offsets The offsets to identify which target positions go with this string + * @param d_all_strings The output of all the produced string segments + * @return The size in bytes of the output string for this row + */ __device__ size_type get_strings(size_type idx, size_type const* d_offsets, target_pair const* d_positions, @@ -161,8 +218,7 @@ struct replace_multi_parallel_fn { if (keep_size > 0) { d_output[output_idx++] = string_index_pair{str_ptr, keep_size}; } output_size += keep_size; - auto const d_repl = - d_replacements.size() == 1 ? d_replacements[0] : d_replacements[d_pair.second]; + auto const d_repl = get_replacement_string(d_pair.second); if (!d_repl.empty()) { d_output[output_idx++] = string_index_pair{d_repl.data(), d_repl.size_bytes()}; } @@ -173,8 +229,8 @@ struct replace_multi_parallel_fn { } // include any leftover parts of the string if (str_ptr <= d_str_end) { - auto const left_size = static_cast(thrust::distance(str_ptr, d_str_end)); - d_output[output_idx++] = string_index_pair{str_ptr, left_size}; + auto const left_size = static_cast(thrust::distance(str_ptr, d_str_end)); + d_output[output_idx] = string_index_pair{str_ptr, left_size}; output_size += left_size; } return output_size; @@ -193,6 +249,26 @@ struct replace_multi_parallel_fn { device_span d_replacements; }; +/** + * @brief Used by the copy-if function to produce target_pair objects + * + * Using an inplace lambda caused a runtime crash in thrust::copy_if + * (this happens sometimes with passing device lambdas to thrust algorithms) + */ +struct pair_generator { + __device__ target_pair operator()(int idx) const + { + auto pos = fn.has_target(idx, chars_bytes); + return target_pair{idx, pos.value_or(-1)}; + } + replace_multi_parallel_fn fn; + size_type chars_bytes; +}; + +struct copy_if_fn { + __device__ bool operator()(target_pair pos) { return pos.second >= 0; } +}; + /** * @brief Function logic for the replace_multi API. * @@ -285,8 +361,6 @@ std::unique_ptr replace(strings_column_view const& input, cudf::detail::get_value(input.offsets(), input.offset() + strings_count, stream) - cudf::detail::get_value(input.offsets(), input.offset(), stream); - auto d_offsets = input.offsets_begin(); - auto d_targets = create_string_vector_from_column(targets, stream, rmm::mr::get_current_device_resource()); auto d_replacements = @@ -295,28 +369,21 @@ std::unique_ptr replace(strings_column_view const& input, replace_multi_parallel_fn fn{*d_strings, d_targets, d_replacements}; // count the number of targets in the entire column - auto const target_count = - thrust::count_if(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(chars_bytes), - [fn, d_offsets, chars_bytes] __device__(size_type idx) { - return fn.has_target(idx, d_offsets, chars_bytes).has_value(); - }); + auto const target_count = thrust::count_if(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(chars_bytes), + [fn, chars_bytes] __device__(size_type idx) { + return fn.has_target(idx, chars_bytes).has_value(); + }); // Create a vector of every target position in the chars column. // These may include overlapping targets which will be resolved later. auto targets_positions = rmm::device_uvector(target_count, stream); auto d_positions = targets_positions.data(); - auto copy_itr = cudf::detail::make_counting_transform_iterator( - 0, [fn, d_offsets, chars_bytes] __device__(auto idx) -> target_pair { - auto pos = fn.has_target(idx, d_offsets, chars_bytes); - return target_pair{idx, pos.value_or(-1)}; - }); - auto const copy_end = thrust::copy_if(rmm::exec_policy(stream), - copy_itr, - copy_itr + chars_bytes, - targets_positions.begin(), - [] __device__(auto pos) { return pos.second >= 0; }); + auto copy_itr = + cudf::detail::make_counting_transform_iterator(0, pair_generator{fn, chars_bytes}); + auto const copy_end = thrust::copy_if( + rmm::exec_policy(stream), copy_itr, copy_itr + chars_bytes, d_positions, copy_if_fn{}); // create a vector of offsets to each string's set of target positions auto const targets_offsets = [&] { @@ -327,8 +394,8 @@ std::unique_ptr replace(strings_column_view const& input, auto pos_count = std::distance(d_positions, copy_end); thrust::upper_bound(rmm::exec_policy(stream), - d_offsets, - d_offsets + strings_count, + input.offsets_begin(), + input.offsets_end(), pos_itr, pos_itr + pos_count, string_indices.begin()); @@ -341,7 +408,7 @@ std::unique_ptr replace(strings_column_view const& input, CUDF_CUDA_TRY(cudaMemsetAsync( d_targets_offsets, 0, targets_offsets.size() * sizeof(size_type), stream.value())); - // next, count the number of targes per string + // next, count the number of targets per string auto d_string_indices = string_indices.data(); thrust::for_each_n(rmm::exec_policy(stream), thrust::make_counting_iterator(0), @@ -359,7 +426,7 @@ std::unique_ptr replace(strings_column_view const& input, }(); auto const d_targets_offsets = targets_offsets.data(); - // compute the output count of each output string + // compute the number of string segments produced by replace in each string auto counts = rmm::device_uvector(strings_count, stream); thrust::transform(rmm::exec_policy(stream), thrust::make_counting_iterator(0), @@ -379,7 +446,7 @@ std::unique_ptr replace(strings_column_view const& input, // build a vector of all the positions for all the strings auto indices = rmm::device_uvector(total_strings, stream); auto d_indices = indices.data(); - auto d_sizes = counts.data(); + auto d_sizes = counts.data(); // reusing this vector to hold output sizes now thrust::for_each_n( rmm::exec_policy(stream), thrust::make_counting_iterator(0), diff --git a/cpp/tests/strings/replace_tests.cpp b/cpp/tests/strings/replace_tests.cpp index 1048334acc0..72300dbb851 100644 --- a/cpp/tests/strings/replace_tests.cpp +++ b/cpp/tests/strings/replace_tests.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. From a38ac59c29f34c5cd605f942730e3621d8b3abe3 Mon Sep 17 00:00:00 2001 From: David Wendt Date: Tue, 28 Feb 2023 17:46:41 -0500 Subject: [PATCH 3/9] fix style check --- cpp/benchmarks/string/replace.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/benchmarks/string/replace.cpp b/cpp/benchmarks/string/replace.cpp index 2794acd5945..cb570020f0e 100644 --- a/cpp/benchmarks/string/replace.cpp +++ b/cpp/benchmarks/string/replace.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. From 5413bbd7473fb248d3e0d7e56915fcd99059eaff Mon Sep 17 00:00:00 2001 From: David Wendt Date: Mon, 6 Mar 2023 12:47:24 -0500 Subject: [PATCH 4/9] remove temporary gtest --- cpp/src/strings/replace/multi.cu | 11 +++++------ cpp/tests/strings/replace_tests.cpp | 28 ++++++---------------------- 2 files changed, 11 insertions(+), 28 deletions(-) diff --git a/cpp/src/strings/replace/multi.cu b/cpp/src/strings/replace/multi.cu index fa8727cddcb..53b9c33f3ee 100644 --- a/cpp/src/strings/replace/multi.cu +++ b/cpp/src/strings/replace/multi.cu @@ -45,8 +45,6 @@ #include #include -#include - namespace cudf { namespace strings { namespace detail { @@ -253,7 +251,7 @@ struct replace_multi_parallel_fn { * @brief Used by the copy-if function to produce target_pair objects * * Using an inplace lambda caused a runtime crash in thrust::copy_if - * (this happens sometimes with passing device lambdas to thrust algorithms) + * (this happens sometimes when passing device lambdas to thrust algorithms) */ struct pair_generator { __device__ target_pair operator()(int idx) const @@ -272,7 +270,8 @@ struct copy_if_fn { /** * @brief Function logic for the replace_multi API. * - * This will perform the multi-replace operation on each string. + * Performs the multi-replace operation with a thread per string. + * This performs best on smaller strings. @see AVG_CHAR_BYTES_THRESHOLD */ struct replace_multi_fn { column_device_view const d_strings; @@ -380,7 +379,7 @@ std::unique_ptr replace(strings_column_view const& input, auto targets_positions = rmm::device_uvector(target_count, stream); auto d_positions = targets_positions.data(); - auto copy_itr = + auto const copy_itr = cudf::detail::make_counting_transform_iterator(0, pair_generator{fn, chars_bytes}); auto const copy_end = thrust::copy_if( rmm::exec_policy(stream), copy_itr, copy_itr + chars_bytes, d_positions, copy_if_fn{}); @@ -389,7 +388,7 @@ std::unique_ptr replace(strings_column_view const& input, auto const targets_offsets = [&] { auto string_indices = rmm::device_uvector(target_count, stream); - auto pos_itr = cudf::detail::make_counting_transform_iterator( + auto const pos_itr = cudf::detail::make_counting_transform_iterator( 0, [d_positions] __device__(auto idx) -> size_type { return d_positions[idx].first; }); auto pos_count = std::distance(d_positions, copy_end); diff --git a/cpp/tests/strings/replace_tests.cpp b/cpp/tests/strings/replace_tests.cpp index 72300dbb851..3c85f4c8528 100644 --- a/cpp/tests/strings/replace_tests.cpp +++ b/cpp/tests/strings/replace_tests.cpp @@ -173,22 +173,6 @@ TEST_F(StringsReplaceTest, ReplaceTargetOverlap) CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); } -TEST_F(StringsReplaceTest, ReplaceTargetOverlap2) -{ - auto input = cudf::test::strings_column_wrapper({"banana", "nanananananana"}); - auto strings_view = cudf::strings_column_view(input); - - auto stream = cudf::get_default_stream(); - auto mr = rmm::mr::get_current_device_resource(); - - auto results = cudf::strings::detail::replace( - strings_view, cudf::string_scalar("nana"), cudf::string_scalar(""), -1, stream, mr); - cudf::test::print(results->view()); - results = cudf::strings::detail::replace( - strings_view, cudf::string_scalar("nana"), cudf::string_scalar(""), -1, stream, mr); - cudf::test::print(results->view()); -} - TEST_F(StringsReplaceTest, ReplaceTargetOverlapsStrings) { auto input = build_corpus(); @@ -373,19 +357,19 @@ TEST_F(StringsReplaceTest, ReplaceMultiLong) {1, 1, 1, 1, 0, 1}); auto strings_view = cudf::strings_column_view(input); - cudf::test::strings_column_wrapper targets({"901", "bananá", "ápple"}); + cudf::test::strings_column_wrapper targets({"78901", "bananá", "ápple", "78"}); auto targets_view = cudf::strings_column_view(targets); { - cudf::test::strings_column_wrapper repls({"x", "PEAR", "avocado"}); + cudf::test::strings_column_wrapper repls({"x", "PEAR", "avocado", "$$"}); auto repls_view = cudf::strings_column_view(repls); auto results = cudf::strings::replace(strings_view, targets_view, repls_view); cudf::test::strings_column_wrapper expected( {"This string needs to be very long to trigger the long-replace internal functions.", - "012345678x2345678x2345678x2345678x2345678x2345678x2345678x23456789", - "012345678x2345678x2345678x2345678x2345678x2345678x2345678x23456789", + "0123456x23456x23456x23456x23456x23456x23456x23456$$9", + "0123456x23456x23456x23456x23456x23456x23456x23456$$9", "Test string for overlap check: bananaavocado PEAR avocadoPEAR banavocado avocado PEAR", "", ""}, @@ -401,8 +385,8 @@ TEST_F(StringsReplaceTest, ReplaceMultiLong) cudf::test::strings_column_wrapper expected( {"This string needs to be very long to trigger the long-replace internal functions.", - "012345678*2345678*2345678*2345678*2345678*2345678*2345678*23456789", - "012345678*2345678*2345678*2345678*2345678*2345678*2345678*23456789", + "0123456*23456*23456*23456*23456*23456*23456*23456*9", + "0123456*23456*23456*23456*23456*23456*23456*23456*9", "Test string for overlap check: banana* * ** ban* * *", "", ""}, From 1042d7be7db56d431b037345bb0e38e74888ce46 Mon Sep 17 00:00:00 2001 From: David Wendt Date: Tue, 7 Mar 2023 15:23:06 -0500 Subject: [PATCH 5/9] change memset to uninitialized_fill --- cpp/src/strings/replace/multi.cu | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/cpp/src/strings/replace/multi.cu b/cpp/src/strings/replace/multi.cu index 53b9c33f3ee..6c04fb4d8c1 100644 --- a/cpp/src/strings/replace/multi.cu +++ b/cpp/src/strings/replace/multi.cu @@ -57,7 +57,8 @@ namespace { * the character-parallel function is used. * Otherwise, a regular string-parallel function is used. * - * This value was found using the replace-multi benchmark results. + * This value was found using the replace-multi benchmark results using an + * RTX A6000. */ constexpr size_type AVG_CHAR_BYTES_THRESHOLD = 256; @@ -115,7 +116,7 @@ struct replace_multi_parallel_fn { thrust::upper_bound(thrust::seq, d_offsets, d_offsets + d_strings.size(), idx); str_idx = thrust::distance(d_offsets, idx_itr) - 1; } - auto d_str = get_string(str_idx - d_offsets[0]); + auto const d_str = get_string(str_idx - d_offsets[0]); if ((d_chars + d_tgt.size_bytes()) <= (d_str.data() + d_str.size_bytes())) { return t; } } } @@ -200,7 +201,7 @@ struct replace_multi_parallel_fn { } auto const d_str_end = d_str.data() + d_str.size_bytes(); - auto const base_ptr = get_base_ptr(); //+ delim_size - 1; + auto const base_ptr = get_base_ptr(); auto const targets_positions = cudf::device_span( d_positions + d_targets_offsets[idx], d_targets_offsets[idx + 1] - d_targets_offsets[idx]); @@ -283,7 +284,7 @@ struct replace_multi_fn { __device__ void operator()(size_type idx) { if (d_strings.is_null(idx)) { - if (!d_chars) d_offsets[idx] = 0; + if (!d_chars) { d_offsets[idx] = 0; } return; } auto const d_str = d_strings.element(idx); @@ -330,7 +331,7 @@ std::unique_ptr replace(strings_column_view const& input, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { - if (input.is_empty()) return make_empty_column(type_id::STRING); + if (input.is_empty()) { return make_empty_column(type_id::STRING); } CUDF_EXPECTS(((targets.size() > 0) && (targets.null_count() == 0)), "Parameters targets must not be empty and must not have nulls"); CUDF_EXPECTS(((repls.size() > 0) && (repls.null_count() == 0)), @@ -404,8 +405,8 @@ std::unique_ptr replace(strings_column_view const& input, auto d_targets_offsets = targets_offsets.data(); // memset to zero-out the target counts for any null-entries or strings with no targets - CUDF_CUDA_TRY(cudaMemsetAsync( - d_targets_offsets, 0, targets_offsets.size() * sizeof(size_type), stream.value())); + thrust::uninitialized_fill( + rmm::exec_policy(stream), targets_offsets.begin(), targets_offsets.end(), 0); // next, count the number of targets per string auto d_string_indices = string_indices.data(); From 2f1c9559de5754f47f44acdedcee537b0004492a Mon Sep 17 00:00:00 2001 From: David Wendt Date: Tue, 7 Mar 2023 17:19:38 -0500 Subject: [PATCH 6/9] refactor detail function into two sub-functions --- cpp/src/strings/replace/multi.cu | 187 +++++++++++++++++-------------- 1 file changed, 102 insertions(+), 85 deletions(-) diff --git a/cpp/src/strings/replace/multi.cu b/cpp/src/strings/replace/multi.cu index 6c04fb4d8c1..1deb8e55017 100644 --- a/cpp/src/strings/replace/multi.cu +++ b/cpp/src/strings/replace/multi.cu @@ -268,94 +268,14 @@ struct copy_if_fn { __device__ bool operator()(target_pair pos) { return pos.second >= 0; } }; -/** - * @brief Function logic for the replace_multi API. - * - * Performs the multi-replace operation with a thread per string. - * This performs best on smaller strings. @see AVG_CHAR_BYTES_THRESHOLD - */ -struct replace_multi_fn { - column_device_view const d_strings; - column_device_view const d_targets; - column_device_view const d_repls; - int32_t* d_offsets{}; - char* d_chars{}; - - __device__ void operator()(size_type idx) - { - if (d_strings.is_null(idx)) { - if (!d_chars) { d_offsets[idx] = 0; } - return; - } - auto const d_str = d_strings.element(idx); - char const* in_ptr = d_str.data(); - - size_type bytes = d_str.size_bytes(); - size_type spos = 0; - size_type lpos = 0; - char* out_ptr = d_chars ? d_chars + d_offsets[idx] : nullptr; - - // check each character against each target - while (spos < d_str.size_bytes()) { - for (int tgt_idx = 0; tgt_idx < d_targets.size(); ++tgt_idx) { - auto const d_tgt = d_targets.element(tgt_idx); - if ((d_tgt.size_bytes() <= (d_str.size_bytes() - spos)) && // check fit - (d_tgt.compare(in_ptr + spos, d_tgt.size_bytes()) == 0)) // and match - { - auto const d_repl = (d_repls.size() == 1) ? d_repls.element(0) - : d_repls.element(tgt_idx); - bytes += d_repl.size_bytes() - d_tgt.size_bytes(); - if (out_ptr) { - out_ptr = copy_and_increment(out_ptr, in_ptr + lpos, spos - lpos); - out_ptr = copy_string(out_ptr, d_repl); - lpos = spos + d_tgt.size_bytes(); - } - spos += d_tgt.size_bytes() - 1; - break; - } - } - ++spos; - } - if (out_ptr) // copy remainder - memcpy(out_ptr, in_ptr + lpos, d_str.size_bytes() - lpos); - else - d_offsets[idx] = bytes; - } -}; - -} // namespace - -std::unique_ptr replace(strings_column_view const& input, - strings_column_view const& targets, - strings_column_view const& repls, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) +std::unique_ptr replace_character_parallel(strings_column_view const& input, + strings_column_view const& targets, + strings_column_view const& repls, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) { - if (input.is_empty()) { return make_empty_column(type_id::STRING); } - CUDF_EXPECTS(((targets.size() > 0) && (targets.null_count() == 0)), - "Parameters targets must not be empty and must not have nulls"); - CUDF_EXPECTS(((repls.size() > 0) && (repls.null_count() == 0)), - "Parameters repls must not be empty and must not have nulls"); - if (repls.size() > 1) - CUDF_EXPECTS(repls.size() == targets.size(), "Sizes for targets and repls must match"); - auto d_strings = column_device_view::create(input.parent(), stream); - if (input.size() == input.null_count() || - ((input.chars_size() / (input.size() - input.null_count())) < AVG_CHAR_BYTES_THRESHOLD)) { - auto d_targets = column_device_view::create(targets.parent(), stream); - auto d_replacements = column_device_view::create(repls.parent(), stream); - - auto children = cudf::strings::detail::make_strings_children( - replace_multi_fn{*d_strings, *d_targets, *d_replacements}, input.size(), stream, mr); - - return make_strings_column(input.size(), - std::move(children.first), - std::move(children.second), - input.null_count(), - cudf::detail::copy_bitmask(input.parent(), stream, mr)); - } - auto const strings_count = input.size(); auto const chars_bytes = cudf::detail::get_value(input.offsets(), input.offset() + strings_count, stream) - @@ -472,6 +392,103 @@ std::unique_ptr replace(strings_column_view const& input, copy_bitmask(input.parent(), stream, mr)); } +/** + * @brief Function logic for the replace_string_parallel + * + * Performs the multi-replace operation with a thread per string. + * This performs best on smaller strings. @see AVG_CHAR_BYTES_THRESHOLD + */ +struct replace_multi_fn { + column_device_view const d_strings; + column_device_view const d_targets; + column_device_view const d_repls; + int32_t* d_offsets{}; + char* d_chars{}; + + __device__ void operator()(size_type idx) + { + if (d_strings.is_null(idx)) { + if (!d_chars) { d_offsets[idx] = 0; } + return; + } + auto const d_str = d_strings.element(idx); + char const* in_ptr = d_str.data(); + + size_type bytes = d_str.size_bytes(); + size_type spos = 0; + size_type lpos = 0; + char* out_ptr = d_chars ? d_chars + d_offsets[idx] : nullptr; + + // check each character against each target + while (spos < d_str.size_bytes()) { + for (int tgt_idx = 0; tgt_idx < d_targets.size(); ++tgt_idx) { + auto const d_tgt = d_targets.element(tgt_idx); + if ((d_tgt.size_bytes() <= (d_str.size_bytes() - spos)) && // check fit + (d_tgt.compare(in_ptr + spos, d_tgt.size_bytes()) == 0)) // and match + { + auto const d_repl = (d_repls.size() == 1) ? d_repls.element(0) + : d_repls.element(tgt_idx); + bytes += d_repl.size_bytes() - d_tgt.size_bytes(); + if (out_ptr) { + out_ptr = copy_and_increment(out_ptr, in_ptr + lpos, spos - lpos); + out_ptr = copy_string(out_ptr, d_repl); + lpos = spos + d_tgt.size_bytes(); + } + spos += d_tgt.size_bytes() - 1; + break; + } + } + ++spos; + } + if (out_ptr) // copy remainder + memcpy(out_ptr, in_ptr + lpos, d_str.size_bytes() - lpos); + else + d_offsets[idx] = bytes; + } +}; + +std::unique_ptr replace_string_parallel(strings_column_view const& input, + strings_column_view const& targets, + strings_column_view const& repls, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + auto d_strings = column_device_view::create(input.parent(), stream); + auto d_targets = column_device_view::create(targets.parent(), stream); + auto d_replacements = column_device_view::create(repls.parent(), stream); + + auto children = cudf::strings::detail::make_strings_children( + replace_multi_fn{*d_strings, *d_targets, *d_replacements}, input.size(), stream, mr); + + return make_strings_column(input.size(), + std::move(children.first), + std::move(children.second), + input.null_count(), + cudf::detail::copy_bitmask(input.parent(), stream, mr)); +} + +} // namespace + +std::unique_ptr replace(strings_column_view const& input, + strings_column_view const& targets, + strings_column_view const& repls, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + if (input.is_empty()) { return make_empty_column(type_id::STRING); } + CUDF_EXPECTS(((targets.size() > 0) && (targets.null_count() == 0)), + "Parameters targets must not be empty and must not have nulls"); + CUDF_EXPECTS(((repls.size() > 0) && (repls.null_count() == 0)), + "Parameters repls must not be empty and must not have nulls"); + if (repls.size() > 1) + CUDF_EXPECTS(repls.size() == targets.size(), "Sizes for targets and repls must match"); + + return (input.size() == input.null_count() || + ((input.chars_size() / (input.size() - input.null_count())) < AVG_CHAR_BYTES_THRESHOLD)) + ? replace_string_parallel(input, targets, repls, stream, mr) + : replace_character_parallel(input, targets, repls, stream, mr); +} + } // namespace detail // external API From 7ce40ed4d201b4aaa766cd3345b1c03fae959039 Mon Sep 17 00:00:00 2001 From: David Wendt Date: Wed, 8 Mar 2023 21:03:35 -0500 Subject: [PATCH 7/9] add comment about string lengths in gtest --- cpp/tests/strings/replace_tests.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cpp/tests/strings/replace_tests.cpp b/cpp/tests/strings/replace_tests.cpp index 3c85f4c8528..82d0da83f4c 100644 --- a/cpp/tests/strings/replace_tests.cpp +++ b/cpp/tests/strings/replace_tests.cpp @@ -347,6 +347,8 @@ TEST_F(StringsReplaceTest, ReplaceMulti) TEST_F(StringsReplaceTest, ReplaceMultiLong) { + // The length of the strings are to trigger the code path governed by the AVG_CHAR_BYTES_THRESHOLD + // setting in the multi.cu. auto input = cudf::test::strings_column_wrapper( {"This string needs to be very long to trigger the long-replace internal functions.", "01234567890123456789012345678901234567890123456789012345678901234567890123456789", From 6d2043fcdc6cd220af908fe63a6135f67ddcdefc Mon Sep 17 00:00:00 2001 From: David Wendt Date: Thu, 9 Mar 2023 07:41:02 -0500 Subject: [PATCH 8/9] fix replace-entire-string case --- cpp/src/strings/replace/multi.cu | 14 ++--- cpp/tests/strings/replace_tests.cpp | 93 +++++++++++++++++++++++++---- 2 files changed, 86 insertions(+), 21 deletions(-) diff --git a/cpp/src/strings/replace/multi.cu b/cpp/src/strings/replace/multi.cu index 1deb8e55017..1168ddd8613 100644 --- a/cpp/src/strings/replace/multi.cu +++ b/cpp/src/strings/replace/multi.cu @@ -191,17 +191,11 @@ struct replace_multi_parallel_fn { { if (!is_valid(idx)) { return 0; } - auto const d_output = d_all_strings + d_offsets[idx]; - auto const d_output_count = d_offsets[idx + 1] - d_offsets[idx]; + auto const d_output = d_all_strings + d_offsets[idx]; + auto const d_str = get_string(idx); + auto const d_str_end = d_str.data() + d_str.size_bytes(); + auto const base_ptr = get_base_ptr(); - auto const d_str = get_string(idx); - if (d_output_count == 1) { - d_output[0] = string_index_pair{d_str.data(), d_str.size_bytes()}; - return d_str.size_bytes(); - } - - auto const d_str_end = d_str.data() + d_str.size_bytes(); - auto const base_ptr = get_base_ptr(); auto const targets_positions = cudf::device_span( d_positions + d_targets_offsets[idx], d_targets_offsets[idx + 1] - d_targets_offsets[idx]); diff --git a/cpp/tests/strings/replace_tests.cpp b/cpp/tests/strings/replace_tests.cpp index 82d0da83f4c..85185b2deab 100644 --- a/cpp/tests/strings/replace_tests.cpp +++ b/cpp/tests/strings/replace_tests.cpp @@ -350,16 +350,31 @@ TEST_F(StringsReplaceTest, ReplaceMultiLong) // The length of the strings are to trigger the code path governed by the AVG_CHAR_BYTES_THRESHOLD // setting in the multi.cu. auto input = cudf::test::strings_column_wrapper( - {"This string needs to be very long to trigger the long-replace internal functions.", - "01234567890123456789012345678901234567890123456789012345678901234567890123456789", - "01234567890123456789012345678901234567890123456789012345678901234567890123456789", + {"This string needs to be very long to trigger the long-replace internal functions. " + "This string needs to be very long to trigger the long-replace internal functions. " + "This string needs to be very long to trigger the long-replace internal functions. " + "This string needs to be very long to trigger the long-replace internal functions.", + "012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012" + "345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345" + "678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678" + "901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901" + "2345678901234567890123456789", + "012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012" + "345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345" + "678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678" + "901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901" + "2345678901234567890123456789", + "Test string for overlap check: bananaápple bananá ápplebananá banápple ápple bananá " + "Test string for overlap check: bananaápple bananá ápplebananá banápple ápple bananá " + "Test string for overlap check: bananaápple bananá ápplebananá banápple ápple bananá " + "Test string for overlap check: bananaápple bananá ápplebananá banápple ápple bananá " "Test string for overlap check: bananaápple bananá ápplebananá banápple ápple bananá", "", ""}, {1, 1, 1, 1, 0, 1}); auto strings_view = cudf::strings_column_view(input); - cudf::test::strings_column_wrapper targets({"78901", "bananá", "ápple", "78"}); + auto targets = cudf::test::strings_column_wrapper({"78901", "bananá", "ápple", "78"}); auto targets_view = cudf::strings_column_view(targets); { @@ -369,9 +384,20 @@ TEST_F(StringsReplaceTest, ReplaceMultiLong) auto results = cudf::strings::replace(strings_view, targets_view, repls_view); cudf::test::strings_column_wrapper expected( - {"This string needs to be very long to trigger the long-replace internal functions.", - "0123456x23456x23456x23456x23456x23456x23456x23456$$9", - "0123456x23456x23456x23456x23456x23456x23456x23456$$9", + {"This string needs to be very long to trigger the long-replace internal functions. " + "This string needs to be very long to trigger the long-replace internal functions. " + "This string needs to be very long to trigger the long-replace internal functions. " + "This string needs to be very long to trigger the long-replace internal functions.", + "0123456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456" + "x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x" + "23456x23456x23456x23456x23456x23456x23456x23456x23456x23456$$9", + "0123456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456" + "x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x" + "23456x23456x23456x23456x23456x23456x23456x23456x23456x23456$$9", + "Test string for overlap check: bananaavocado PEAR avocadoPEAR banavocado avocado PEAR " + "Test string for overlap check: bananaavocado PEAR avocadoPEAR banavocado avocado PEAR " + "Test string for overlap check: bananaavocado PEAR avocadoPEAR banavocado avocado PEAR " + "Test string for overlap check: bananaavocado PEAR avocadoPEAR banavocado avocado PEAR " "Test string for overlap check: bananaavocado PEAR avocadoPEAR banavocado avocado PEAR", "", ""}, @@ -386,10 +412,55 @@ TEST_F(StringsReplaceTest, ReplaceMultiLong) auto results = cudf::strings::replace(strings_view, targets_view, repls_view); cudf::test::strings_column_wrapper expected( - {"This string needs to be very long to trigger the long-replace internal functions.", - "0123456*23456*23456*23456*23456*23456*23456*23456*9", - "0123456*23456*23456*23456*23456*23456*23456*23456*9", - "Test string for overlap check: banana* * ** ban* * *", + {"This string needs to be very long to trigger the long-replace internal functions. " + "This string needs to be very long to trigger the long-replace internal functions. " + "This string needs to be very long to trigger the long-replace internal functions. " + "This string needs to be very long to trigger the long-replace internal functions.", + "0123456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*" + "23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*" + "23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*9", + "0123456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*" + "23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*" + "23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*9", + "Test string for overlap check: banana* * ** ban* * * Test string for overlap check: " + "banana* * ** ban* * * Test string for overlap check: banana* * ** ban* * * Test string for " + "overlap check: banana* * ** ban* * * Test string for overlap check: banana* * ** ban* * *", + "", + ""}, + {1, 1, 1, 1, 0, 1}); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); + } + + { + targets = + cudf::test::strings_column_wrapper({"01234567890123456789012345678901234567890123456789012345" + "6789012345678901234567890123456789012" + "34567890123456789012345678901234567890123456789012345678" + "9012345678901234567890123456789012345" + "67890123456789012345678901234567890123456789012345678901" + "2345678901234567890123456789012345678" + "90123456789012345678901234567890123456789012345678901234" + "5678901234567890123456789012345678901" + "2345678901234567890123456789", + "78"}); + targets_view = cudf::strings_column_view(targets); + auto repls = cudf::test::strings_column_wrapper({""}); + auto repls_view = cudf::strings_column_view(repls); + + auto results = cudf::strings::replace(strings_view, targets_view, repls_view); + + cudf::test::strings_column_wrapper expected( + {"This string needs to be very long to trigger the long-replace internal functions. " + "This string needs to be very long to trigger the long-replace internal functions. " + "This string needs to be very long to trigger the long-replace internal functions. " + "This string needs to be very long to trigger the long-replace internal functions.", + "", + "", + "Test string for overlap check: bananaápple bananá ápplebananá banápple ápple bananá " + "Test string for overlap check: bananaápple bananá ápplebananá banápple ápple bananá " + "Test string for overlap check: bananaápple bananá ápplebananá banápple ápple bananá " + "Test string for overlap check: bananaápple bananá ápplebananá banápple ápple bananá " + "Test string for overlap check: bananaápple bananá ápplebananá banápple ápple bananá", "", ""}, {1, 1, 1, 1, 0, 1}); From 242a56e8a7250361aae70660c903505a9d1b5ee9 Mon Sep 17 00:00:00 2001 From: David Wendt Date: Thu, 16 Mar 2023 08:37:57 -0400 Subject: [PATCH 9/9] remove unneeded const --- cpp/src/strings/replace/multi.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/strings/replace/multi.cu b/cpp/src/strings/replace/multi.cu index 1168ddd8613..92ace4e7bc7 100644 --- a/cpp/src/strings/replace/multi.cu +++ b/cpp/src/strings/replace/multi.cu @@ -230,8 +230,8 @@ struct replace_multi_parallel_fn { } replace_multi_parallel_fn(column_device_view const& d_strings, - device_span const d_targets, - device_span const d_replacements) + device_span d_targets, + device_span d_replacements) : d_strings(d_strings), d_targets{d_targets}, d_replacements{d_replacements} { }