From 6f6e521257dce5732eea7b6b9d56243f8b0a69cc Mon Sep 17 00:00:00 2001 From: David Wendt <45795991+davidwendt@users.noreply.github.com> Date: Thu, 22 Feb 2024 08:58:35 -0500 Subject: [PATCH] Split out strings/replace.cu and rework its gtests (#15054) Splitting out changes in PR #14824 to make it easier to review. The changes here simply move `replace_slice()` and `replace_nulls()` from `replace.cu` into their own source files. The detail functions have been simplified removing the template argument that was only needed for unit tests. The gtests were reworked to force calling either row-parallel or character-parallel based on the data input instead of being executed directly. This simplified the internal logic which had duplicate parameter checking. The `cudf::strings::detail::replace_nulls()` is also fixed to use the appropriate `make_offsets_child_column` utitlity. The PR #14824 changes will add large strings support to `cudf::strings::replace()`. Authors: - David Wendt (https://github.com/davidwendt) Approvers: - Robert Maynard (https://github.com/robertmaynard) - Nghia Truong (https://github.com/ttnghia) - https://github.com/nvdbaranec URL: https://github.com/rapidsai/cudf/pull/15054 --- cpp/CMakeLists.txt | 2 + cpp/include/cudf/strings/detail/replace.hpp | 45 ++-- cpp/src/strings/replace/replace.cu | 190 +--------------- cpp/src/strings/replace/replace_nulls.cu | 81 +++++++ cpp/src/strings/replace/replace_slice.cu | 117 ++++++++++ cpp/tests/strings/replace_tests.cpp | 239 +++++++++++--------- 6 files changed, 352 insertions(+), 322 deletions(-) create mode 100644 cpp/src/strings/replace/replace_nulls.cu create mode 100644 cpp/src/strings/replace/replace_slice.cu diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 078de27f0ea..58a43c1def1 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -587,7 +587,9 @@ add_library( src/strings/replace/multi.cu src/strings/replace/multi_re.cu src/strings/replace/replace.cu + src/strings/replace/replace_nulls.cu src/strings/replace/replace_re.cu + src/strings/replace/replace_slice.cu src/strings/reverse.cu src/strings/scan/scan_inclusive.cu src/strings/search/findall.cu diff --git a/cpp/include/cudf/strings/detail/replace.hpp b/cpp/include/cudf/strings/detail/replace.hpp index aa6fb2feb3d..28027291b28 100644 --- a/cpp/include/cudf/strings/detail/replace.hpp +++ b/cpp/include/cudf/strings/detail/replace.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,23 +26,10 @@ namespace cudf { namespace strings { namespace detail { -/** - * @brief The type of algorithm to use for a replace operation. - */ -enum class replace_algorithm { - AUTO, ///< Automatically choose the algorithm based on heuristics - ROW_PARALLEL, ///< Row-level parallelism - CHAR_PARALLEL ///< Character-level parallelism -}; - /** * @copydoc cudf::strings::replace(strings_column_view const&, string_scalar const&, - * string_scalar const&, int32_t, rmm::mr::device_memory_resource*) - * - * @tparam alg Replacement algorithm to use - * @param[in] stream CUDA stream used for device memory operations and kernel launches. + * string_scalar const&, int32_t, rmm::cuda_stream_view, rmm::mr::device_memory_resource*) */ -template std::unique_ptr replace(strings_column_view const& strings, string_scalar const& target, string_scalar const& repl, @@ -50,24 +37,9 @@ std::unique_ptr replace(strings_column_view const& strings, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr); -/** - * @copydoc cudf::strings::replace_slice(strings_column_view const&, string_scalar const&, - * size_type. size_type, rmm::mr::device_memory_resource*) - * - * @param[in] stream CUDA stream used for device memory operations and kernel launches. - */ -std::unique_ptr replace_slice(strings_column_view const& strings, - string_scalar const& repl, - size_type start, - size_type stop, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr); - /** * @copydoc cudf::strings::replace(strings_column_view const&, strings_column_view const&, - * strings_column_view const&, rmm::mr::device_memory_resource*) - * - * @param[in] stream CUDA stream used for device memory operations and kernel launches. + * strings_column_view const&, rmm::cuda_stream_view, rmm::mr::device_memory_resource*) */ std::unique_ptr replace(strings_column_view const& strings, strings_column_view const& targets, @@ -98,6 +70,17 @@ std::unique_ptr replace_nulls(strings_column_view const& strings, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr); +/** + * @copydoc cudf::strings::replace_slice(strings_column_view const&, string_scalar const&, + * size_type, size_type, rmm::cuda_stream_view, rmm::mr::device_memory_resource*) + */ +std::unique_ptr replace_slice(strings_column_view const& strings, + string_scalar const& repl, + size_type start, + size_type stop, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr); + } // namespace detail } // namespace strings } // namespace cudf diff --git a/cpp/src/strings/replace/replace.cu b/cpp/src/strings/replace/replace.cu index d68ec84f68c..2d255e57686 100644 --- a/cpp/src/strings/replace/replace.cu +++ b/cpp/src/strings/replace/replace.cu @@ -542,17 +542,12 @@ std::unique_ptr replace_row_parallel(strings_column_view const& strings, } // namespace -/** - * @copydoc cudf::strings::detail::replace(strings_column_view const&, string_scalar const&, - * string_scalar const&, int32_t, rmm::cuda_stream_view, rmm::mr::device_memory_resource*) - */ -template <> -std::unique_ptr replace(strings_column_view const& strings, - string_scalar const& target, - string_scalar const& repl, - int32_t maxrepl, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) +std::unique_ptr replace(strings_column_view const& strings, + string_scalar const& target, + string_scalar const& repl, + int32_t maxrepl, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) { if (strings.is_empty()) return make_empty_column(type_id::STRING); if (maxrepl == 0) return std::make_unique(strings.parent(), stream, mr); @@ -584,168 +579,6 @@ std::unique_ptr replace(strings_column_view con strings, chars_start, chars_end, d_target, d_repl, maxrepl, stream, mr); } -template <> -std::unique_ptr replace( - strings_column_view const& strings, - string_scalar const& target, - string_scalar const& repl, - int32_t maxrepl, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) -{ - if (strings.is_empty()) return make_empty_column(type_id::STRING); - if (maxrepl == 0) return std::make_unique(strings.parent(), stream, mr); - CUDF_EXPECTS(repl.is_valid(stream), "Parameter repl must be valid."); - CUDF_EXPECTS(target.is_valid(stream), "Parameter target must be valid."); - CUDF_EXPECTS(target.size() > 0, "Parameter target must not be empty string."); - - string_view d_target(target.data(), target.size()); - string_view d_repl(repl.data(), repl.size()); - - // determine range of characters in the base column - auto const strings_count = strings.size(); - auto const offset_count = strings_count + 1; - auto const d_offsets = strings.offsets_begin(); - size_type chars_start = (strings.offset() == 0) ? 0 - : cudf::detail::get_value( - strings.offsets(), strings.offset(), stream); - size_type chars_end = (offset_count == strings.offsets().size()) - ? strings.chars_size(stream) - : cudf::detail::get_value( - strings.offsets(), strings.offset() + strings_count, stream); - return replace_char_parallel( - strings, chars_start, chars_end, d_target, d_repl, maxrepl, stream, mr); -} - -template <> -std::unique_ptr replace( - strings_column_view const& strings, - string_scalar const& target, - string_scalar const& repl, - int32_t maxrepl, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) -{ - if (strings.is_empty()) return make_empty_column(type_id::STRING); - if (maxrepl == 0) return std::make_unique(strings.parent(), stream, mr); - CUDF_EXPECTS(repl.is_valid(stream), "Parameter repl must be valid."); - CUDF_EXPECTS(target.is_valid(stream), "Parameter target must be valid."); - CUDF_EXPECTS(target.size() > 0, "Parameter target must not be empty string."); - - string_view d_target(target.data(), target.size()); - string_view d_repl(repl.data(), repl.size()); - return replace_row_parallel(strings, d_target, d_repl, maxrepl, stream, mr); -} - -namespace { -/** - * @brief Function logic for the replace_slice API. - * - * This will perform a replace_slice operation on each string. - */ -struct replace_slice_fn { - column_device_view const d_strings; - string_view const d_repl; - size_type const start; - size_type const stop; - int32_t* d_offsets{}; - char* d_chars{}; - - __device__ void operator()(size_type idx) - { - if (d_strings.is_null(idx)) { - if (!d_chars) d_offsets[idx] = 0; - return; - } - auto const d_str = d_strings.element(idx); - auto const length = d_str.length(); - char const* in_ptr = d_str.data(); - auto const begin = d_str.byte_offset(((start < 0) || (start > length) ? length : start)); - auto const end = d_str.byte_offset(((stop < 0) || (stop > length) ? length : stop)); - - if (d_chars) { - char* out_ptr = d_chars + d_offsets[idx]; - - out_ptr = copy_and_increment(out_ptr, in_ptr, begin); // copy beginning - out_ptr = copy_string(out_ptr, d_repl); // insert replacement - out_ptr = copy_and_increment(out_ptr, // copy end - in_ptr + end, - d_str.size_bytes() - end); - } else { - d_offsets[idx] = d_str.size_bytes() + d_repl.size_bytes() - (end - begin); - } - } -}; - -} // namespace - -std::unique_ptr replace_slice(strings_column_view const& strings, - string_scalar const& repl, - size_type start, - size_type stop, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) -{ - if (strings.is_empty()) return make_empty_column(type_id::STRING); - CUDF_EXPECTS(repl.is_valid(stream), "Parameter repl must be valid."); - if (stop > 0) CUDF_EXPECTS(start <= stop, "Parameter start must be less than or equal to stop."); - - string_view d_repl(repl.data(), repl.size()); - - auto d_strings = column_device_view::create(strings.parent(), stream); - - // this utility calls the given functor to build the offsets and chars columns - auto [offsets_column, chars_column] = cudf::strings::detail::make_strings_children( - replace_slice_fn{*d_strings, d_repl, start, stop}, strings.size(), stream, mr); - - return make_strings_column(strings.size(), - std::move(offsets_column), - std::move(chars_column->release().data.release()[0]), - strings.null_count(), - cudf::detail::copy_bitmask(strings.parent(), stream, mr)); -} - -std::unique_ptr replace_nulls(strings_column_view const& strings, - string_scalar const& repl, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) -{ - size_type strings_count = strings.size(); - if (strings_count == 0) return make_empty_column(type_id::STRING); - CUDF_EXPECTS(repl.is_valid(stream), "Parameter repl must be valid."); - - string_view d_repl(repl.data(), repl.size()); - - auto strings_column = column_device_view::create(strings.parent(), stream); - auto d_strings = *strings_column; - - // build offsets column - auto offsets_transformer_itr = thrust::make_transform_iterator( - thrust::make_counting_iterator(0), - cuda::proclaim_return_type([d_strings, d_repl] __device__(size_type idx) { - return d_strings.is_null(idx) ? d_repl.size_bytes() - : d_strings.element(idx).size_bytes(); - })); - auto [offsets_column, bytes] = cudf::detail::make_offsets_child_column( - offsets_transformer_itr, offsets_transformer_itr + strings_count, stream, mr); - auto d_offsets = offsets_column->view().data(); - - // build chars column - rmm::device_uvector chars(bytes, stream, mr); - auto d_chars = chars.data(); - thrust::for_each_n(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - strings_count, - [d_strings, d_repl, d_offsets, d_chars] __device__(size_type idx) { - string_view d_str = d_repl; - if (!d_strings.is_null(idx)) d_str = d_strings.element(idx); - memcpy(d_chars + d_offsets[idx], d_str.data(), d_str.size_bytes()); - }); - - return make_strings_column( - strings_count, std::move(offsets_column), chars.release(), 0, rmm::device_buffer{}); -} - } // namespace detail // external API @@ -761,16 +594,5 @@ std::unique_ptr replace(strings_column_view const& strings, return detail::replace(strings, target, repl, maxrepl, stream, mr); } -std::unique_ptr replace_slice(strings_column_view const& strings, - string_scalar const& repl, - size_type start, - size_type stop, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) -{ - CUDF_FUNC_RANGE(); - return detail::replace_slice(strings, repl, start, stop, stream, mr); -} - } // namespace strings } // namespace cudf diff --git a/cpp/src/strings/replace/replace_nulls.cu b/cpp/src/strings/replace/replace_nulls.cu new file mode 100644 index 00000000000..26fb1c7819f --- /dev/null +++ b/cpp/src/strings/replace/replace_nulls.cu @@ -0,0 +1,81 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include + +namespace cudf { +namespace strings { +namespace detail { + +std::unique_ptr replace_nulls(strings_column_view const& strings, + string_scalar const& repl, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + size_type strings_count = strings.size(); + if (strings_count == 0) return make_empty_column(type_id::STRING); + CUDF_EXPECTS(repl.is_valid(stream), "Parameter repl must be valid."); + + string_view d_repl(repl.data(), repl.size()); + + auto strings_column = column_device_view::create(strings.parent(), stream); + auto d_strings = *strings_column; + + // build offsets column + auto offsets_transformer_itr = cudf::detail::make_counting_transform_iterator( + 0, cuda::proclaim_return_type([d_strings, d_repl] __device__(size_type idx) { + return d_strings.is_null(idx) ? d_repl.size_bytes() + : d_strings.element(idx).size_bytes(); + })); + auto [offsets_column, bytes] = cudf::strings::detail::make_offsets_child_column( + offsets_transformer_itr, offsets_transformer_itr + strings_count, stream, mr); + auto d_offsets = offsets_column->view().data(); + + // build chars column + rmm::device_uvector chars(bytes, stream, mr); + auto d_chars = chars.data(); + thrust::for_each_n(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + strings_count, + [d_strings, d_repl, d_offsets, d_chars] __device__(size_type idx) { + string_view d_str = d_repl; + if (!d_strings.is_null(idx)) d_str = d_strings.element(idx); + memcpy(d_chars + d_offsets[idx], d_str.data(), d_str.size_bytes()); + }); + + return make_strings_column( + strings_count, std::move(offsets_column), chars.release(), 0, rmm::device_buffer{}); +} + +} // namespace detail +} // namespace strings +} // namespace cudf diff --git a/cpp/src/strings/replace/replace_slice.cu b/cpp/src/strings/replace/replace_slice.cu new file mode 100644 index 00000000000..4321f78d2d5 --- /dev/null +++ b/cpp/src/strings/replace/replace_slice.cu @@ -0,0 +1,117 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include + +namespace cudf { +namespace strings { +namespace detail { +namespace { +/** + * @brief Function logic for the replace_slice API. + * + * This will perform a replace_slice operation on each string. + */ +struct replace_slice_fn { + column_device_view const d_strings; + string_view const d_repl; + size_type const start; + size_type const stop; + size_type* d_offsets{}; + char* d_chars{}; + + __device__ void operator()(size_type idx) + { + if (d_strings.is_null(idx)) { + if (!d_chars) d_offsets[idx] = 0; + return; + } + auto const d_str = d_strings.element(idx); + auto const length = d_str.length(); + char const* in_ptr = d_str.data(); + auto const begin = d_str.byte_offset(((start < 0) || (start > length) ? length : start)); + auto const end = d_str.byte_offset(((stop < 0) || (stop > length) ? length : stop)); + + if (d_chars) { + char* out_ptr = d_chars + d_offsets[idx]; + + out_ptr = copy_and_increment(out_ptr, in_ptr, begin); // copy beginning + out_ptr = copy_string(out_ptr, d_repl); // insert replacement + out_ptr = copy_and_increment(out_ptr, // copy end + in_ptr + end, + d_str.size_bytes() - end); + } else { + d_offsets[idx] = d_str.size_bytes() + d_repl.size_bytes() - (end - begin); + } + } +}; + +} // namespace + +std::unique_ptr replace_slice(strings_column_view const& strings, + string_scalar const& repl, + size_type start, + size_type stop, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + if (strings.is_empty()) return make_empty_column(type_id::STRING); + CUDF_EXPECTS(repl.is_valid(stream), "Parameter repl must be valid."); + if (stop > 0) CUDF_EXPECTS(start <= stop, "Parameter start must be less than or equal to stop."); + + string_view d_repl(repl.data(), repl.size()); + + auto d_strings = column_device_view::create(strings.parent(), stream); + + // this utility calls the given functor to build the offsets and chars columns + auto [offsets_column, chars_column] = cudf::strings::detail::make_strings_children( + replace_slice_fn{*d_strings, d_repl, start, stop}, strings.size(), stream, mr); + + return make_strings_column(strings.size(), + std::move(offsets_column), + std::move(chars_column->release().data.release()[0]), + strings.null_count(), + cudf::detail::copy_bitmask(strings.parent(), stream, mr)); +} +} // namespace detail + +std::unique_ptr replace_slice(strings_column_view const& strings, + string_scalar const& repl, + size_type start, + size_type stop, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + CUDF_FUNC_RANGE(); + return detail::replace_slice(strings, repl, start, stop, stream, mr); +} + +} // namespace strings +} // namespace cudf diff --git a/cpp/tests/strings/replace_tests.cpp b/cpp/tests/strings/replace_tests.cpp index f04bb832f09..726d9f95c7d 100644 --- a/cpp/tests/strings/replace_tests.cpp +++ b/cpp/tests/strings/replace_tests.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,17 +20,12 @@ #include #include -#include +#include #include #include -#include -#include - #include -using algorithm = cudf::strings::detail::replace_algorithm; - struct StringsReplaceTest : public cudf::test::BaseFixture { cudf::test::strings_column_wrapper build_corpus() { @@ -47,6 +42,13 @@ struct StringsReplaceTest : public cudf::test::BaseFixture { h_strings.end(), thrust::make_transform_iterator(h_strings.begin(), [](auto str) { return str != nullptr; })); } + + std::unique_ptr build_large(cudf::column_view const& first, + cudf::column_view const& remaining) + { + return cudf::strings::concatenate(cudf::table_view( + {first, remaining, remaining, remaining, remaining, remaining, remaining, remaining})); + } }; TEST_F(StringsReplaceTest, Replace) @@ -64,26 +66,23 @@ TEST_F(StringsReplaceTest, Replace) cudf::test::strings_column_wrapper expected( h_expected.begin(), h_expected.end(), cudf::test::iterators::nulls_from_nullptrs(h_expected)); - auto stream = cudf::get_default_stream(); - auto mr = rmm::mr::get_current_device_resource(); + auto target = cudf::string_scalar("the "); + auto replacement = cudf::string_scalar("++++ "); - auto results = - cudf::strings::replace(strings_view, cudf::string_scalar("the "), cudf::string_scalar("++++ ")); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); - results = cudf::strings::detail::replace( - strings_view, cudf::string_scalar("the "), cudf::string_scalar("++++ "), -1, stream, mr); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); - results = cudf::strings::detail::replace( - strings_view, cudf::string_scalar("the "), cudf::string_scalar("++++ "), -1, stream, mr); + auto results = cudf::strings::replace(strings_view, target, replacement); CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); + + auto input_large = build_large(input, input); + strings_view = cudf::strings_column_view(input_large->view()); + auto expected_large = build_large(expected, expected); + results = cudf::strings::replace(strings_view, target, replacement); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, *expected_large); } TEST_F(StringsReplaceTest, ReplaceReplLimit) { auto input = build_corpus(); auto strings_view = cudf::strings_column_view(input); - auto stream = cudf::get_default_stream(); - auto mr = rmm::mr::get_current_device_resource(); // only remove the first occurrence of 'the ' std::vector h_expected{"quick brown fox jumps over the lazy dog", @@ -95,15 +94,16 @@ TEST_F(StringsReplaceTest, ReplaceReplLimit) nullptr}; cudf::test::strings_column_wrapper expected( h_expected.begin(), h_expected.end(), cudf::test::iterators::nulls_from_nullptrs(h_expected)); - auto results = - cudf::strings::replace(strings_view, cudf::string_scalar("the "), cudf::string_scalar(""), 1); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); - results = cudf::strings::detail::replace( - strings_view, cudf::string_scalar("the "), cudf::string_scalar(""), 1, stream, mr); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); - results = cudf::strings::detail::replace( - strings_view, cudf::string_scalar("the "), cudf::string_scalar(""), 1, stream, mr); + auto target = cudf::string_scalar("the "); + auto replacement = cudf::string_scalar(""); + auto results = cudf::strings::replace(strings_view, target, replacement, 1); CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); + + auto input_large = build_large(input, input); + strings_view = cudf::strings_column_view(input_large->view()); + auto expected_large = build_large(expected, input); + results = cudf::strings::replace(strings_view, target, replacement, 1); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, *expected_large); } TEST_F(StringsReplaceTest, ReplaceReplLimitInputSliced) @@ -119,22 +119,28 @@ TEST_F(StringsReplaceTest, ReplaceReplLimitInputSliced) nullptr}; cudf::test::strings_column_wrapper expected( h_expected.begin(), h_expected.end(), cudf::test::iterators::nulls_from_nullptrs(h_expected)); - auto stream = cudf::get_default_stream(); - auto mr = rmm::mr::get_current_device_resource(); std::vector slice_indices{0, 2, 2, 3, 3, 7}; auto sliced_strings = cudf::slice(input, slice_indices); auto sliced_expected = cudf::slice(expected, slice_indices); + + auto input_large = build_large(input, input); + auto expected_large = build_large(expected, input); + + auto sliced_large = cudf::slice(input_large->view(), slice_indices); + auto sliced_expected_large = cudf::slice(expected_large->view(), slice_indices); + + auto target = cudf::string_scalar(" "); + auto replacement = cudf::string_scalar("--"); + for (size_t i = 0; i < sliced_strings.size(); ++i) { auto strings_view = cudf::strings_column_view(sliced_strings[i]); - auto results = - cudf::strings::replace(strings_view, cudf::string_scalar(" "), cudf::string_scalar("--"), 2); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, sliced_expected[i]); - results = cudf::strings::detail::replace( - strings_view, cudf::string_scalar(" "), cudf::string_scalar("--"), 2, stream, mr); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, sliced_expected[i]); - results = cudf::strings::detail::replace( - strings_view, cudf::string_scalar(" "), cudf::string_scalar("--"), 2, stream, mr); + auto results = cudf::strings::replace(strings_view, target, replacement, 2); CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, sliced_expected[i]); + + strings_view = cudf::strings_column_view(sliced_large[i]); + results = + cudf::strings::replace(strings_view, cudf::string_scalar(" "), cudf::string_scalar("--"), 2); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, sliced_expected_large[i]); } } @@ -158,68 +164,56 @@ TEST_F(StringsReplaceTest, ReplaceTargetOverlap) cudf::test::strings_column_wrapper expected( h_expected.begin(), h_expected.end(), cudf::test::iterators::nulls_from_nullptrs(h_expected)); - auto stream = cudf::get_default_stream(); - auto mr = rmm::mr::get_current_device_resource(); + auto target = cudf::string_scalar("+++"); + auto replacement = cudf::string_scalar("plus "); - auto results = - cudf::strings::replace(strings_view, cudf::string_scalar("+++"), cudf::string_scalar("plus ")); + auto results = cudf::strings::replace(strings_view, target, replacement); CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); - results = cudf::strings::detail::replace( - strings_view, cudf::string_scalar("+++"), cudf::string_scalar("plus "), -1, stream, mr); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); - results = cudf::strings::detail::replace( - strings_view, cudf::string_scalar("+++"), cudf::string_scalar("plus "), -1, stream, mr); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); + auto input_large = build_large(input->view(), input->view()); + strings_view = cudf::strings_column_view(input_large->view()); + auto expected_large = build_large(expected, expected); + + results = cudf::strings::replace(strings_view, target, replacement); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, *expected_large); } TEST_F(StringsReplaceTest, ReplaceTargetOverlapsStrings) { auto input = build_corpus(); auto strings_view = cudf::strings_column_view(input); - auto stream = cudf::get_default_stream(); - auto mr = rmm::mr::get_current_device_resource(); // replace all occurrences of 'dogthe' with '+' + auto target = cudf::string_scalar("dogthe"); + auto replacement = cudf::string_scalar("+"); + // should not replace anything unless it incorrectly matches across a string boundary - auto results = - cudf::strings::replace(strings_view, cudf::string_scalar("dogthe"), cudf::string_scalar("+")); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, input); - results = cudf::strings::detail::replace( - strings_view, cudf::string_scalar("dogthe"), cudf::string_scalar("+"), -1, stream, mr); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, input); - results = cudf::strings::detail::replace( - strings_view, cudf::string_scalar("dogthe"), cudf::string_scalar("+"), -1, stream, mr); + auto results = cudf::strings::replace(strings_view, target, replacement); CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, input); + + auto input_large = cudf::strings::concatenate( + cudf::table_view({input, input, input, input, input, input, input, input}), + cudf::string_scalar(" ")); + strings_view = cudf::strings_column_view(input_large->view()); + results = cudf::strings::replace(strings_view, target, replacement); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, *input_large); } -TEST_F(StringsReplaceTest, ReplaceNullInput) +TEST_F(StringsReplaceTest, ReplaceAllNullInput) { std::vector h_null_strings(128); auto input = cudf::test::strings_column_wrapper( h_null_strings.begin(), h_null_strings.end(), thrust::make_constant_iterator(false)); auto strings_view = cudf::strings_column_view(input); - auto stream = cudf::get_default_stream(); - auto mr = rmm::mr::get_current_device_resource(); - // replace all occurrences of '+' with '' - // should not replace anything as input is all null auto results = cudf::strings::replace(strings_view, cudf::string_scalar("+"), cudf::string_scalar("")); CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, input); - results = cudf::strings::detail::replace( - strings_view, cudf::string_scalar("+"), cudf::string_scalar(""), -1, stream, mr); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, input); - results = cudf::strings::detail::replace( - strings_view, cudf::string_scalar("+"), cudf::string_scalar(""), -1, stream, mr); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, input); } TEST_F(StringsReplaceTest, ReplaceEndOfString) { auto input = build_corpus(); auto strings_view = cudf::strings_column_view(input); - auto stream = cudf::get_default_stream(); - auto mr = rmm::mr::get_current_device_resource(); // replace all occurrences of 'in' with ' ' std::vector h_expected{"the quick brown fox jumps over the lazy dog", @@ -233,39 +227,56 @@ TEST_F(StringsReplaceTest, ReplaceEndOfString) cudf::test::strings_column_wrapper expected( h_expected.begin(), h_expected.end(), cudf::test::iterators::nulls_from_nullptrs(h_expected)); - auto results = - cudf::strings::replace(strings_view, cudf::string_scalar("in"), cudf::string_scalar(" ")); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); + auto target = cudf::string_scalar("in"); + auto replacement = cudf::string_scalar(" "); - results = cudf::strings::detail::replace( - strings_view, cudf::string_scalar("in"), cudf::string_scalar(" "), -1, stream, mr); + auto results = cudf::strings::replace(strings_view, target, replacement); CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); - results = cudf::strings::detail::replace( - strings_view, cudf::string_scalar("in"), cudf::string_scalar(" "), -1, stream, mr); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); + auto input_large = build_large(input, input); + strings_view = cudf::strings_column_view(input_large->view()); + auto expected_large = build_large(expected, expected); + results = cudf::strings::replace(strings_view, target, replacement); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, *expected_large); } TEST_F(StringsReplaceTest, ReplaceAdjacentMultiByteTarget) { - auto input = cudf::test::strings_column_wrapper({"ééééééé", "eéeéeée", "eeeeeee"}); + auto input = cudf::test::strings_column_wrapper({"ééééééééééééééééééééé", + "eéeéeéeeéeéeéeeéeéeée", + "eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee"}); auto strings_view = cudf::strings_column_view(input); // replace all occurrences of 'é' with 'e' - cudf::test::strings_column_wrapper expected({"eeeeeee", "eeeeeee", "eeeeeee"}); + cudf::test::strings_column_wrapper expected({"eeeeeeeeeeeeeeeeeeeee", + "eeeeeeeeeeeeeeeeeeeee", + "eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee"}); - auto stream = cudf::get_default_stream(); - auto mr = rmm::mr::get_current_device_resource(); + auto target = cudf::string_scalar("é"); + auto replacement = cudf::string_scalar("e"); - auto target = cudf::string_scalar("é", true, stream); - auto repl = cudf::string_scalar("e", true, stream); - auto results = cudf::strings::replace(strings_view, target, repl); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); - results = cudf::strings::detail::replace( - strings_view, target, repl, -1, stream, mr); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); - results = cudf::strings::detail::replace( - strings_view, target, repl, -1, stream, mr); + auto results = cudf::strings::replace(strings_view, target, replacement); CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); + + auto input_large = build_large(input, input); + strings_view = cudf::strings_column_view(input_large->view()); + auto expected_large = build_large(expected, expected); + results = cudf::strings::replace(strings_view, target, replacement); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, *expected_large); +} + +TEST_F(StringsReplaceTest, ReplaceErrors) +{ + auto input = cudf::test::strings_column_wrapper({"this column intentionally left blank"}); + + auto target = cudf::string_scalar(" "); + auto replacement = cudf::string_scalar("_"); + auto null_input = cudf::string_scalar("", false); + auto empty_input = cudf::string_scalar(""); + auto sv = cudf::strings_column_view(input); + + EXPECT_THROW(cudf::strings::replace(sv, target, null_input), cudf::logic_error); + EXPECT_THROW(cudf::strings::replace(sv, null_input, replacement), cudf::logic_error); + EXPECT_THROW(cudf::strings::replace(sv, empty_input, replacement), cudf::logic_error); } TEST_F(StringsReplaceTest, ReplaceSlice) @@ -369,22 +380,30 @@ TEST_F(StringsReplaceTest, ReplaceMulti) TEST_F(StringsReplaceTest, ReplaceMultiLong) { - // The length of the strings are to trigger the code path governed by the AVG_CHAR_BYTES_THRESHOLD - // setting in the multi.cu. + // The length of the strings are to trigger the code path governed by the + // AVG_CHAR_BYTES_THRESHOLD setting in the multi.cu. auto input = cudf::test::strings_column_wrapper( {"This string needs to be very long to trigger the long-replace internal functions. " "This string needs to be very long to trigger the long-replace internal functions. " "This string needs to be very long to trigger the long-replace internal functions. " "This string needs to be very long to trigger the long-replace internal functions.", - "012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012" - "345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345" - "678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678" - "901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901" + "0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890" + "12" + "3456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123" + "45" + "6789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456" + "78" + "9012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789" + "01" "2345678901234567890123456789", - "012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012" - "345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345" - "678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678" - "901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901" + "0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890" + "12" + "3456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123" + "45" + "6789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456" + "78" + "9012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789" + "01" "2345678901234567890123456789", "Test string for overlap check: bananaápple bananá ápplebananá banápple ápple bananá " "Test string for overlap check: bananaápple bananá ápplebananá banápple ápple bananá " @@ -410,11 +429,15 @@ TEST_F(StringsReplaceTest, ReplaceMultiLong) "This string needs to be very long to trigger the long-replace internal functions. " "This string needs to be very long to trigger the long-replace internal functions. " "This string needs to be very long to trigger the long-replace internal functions.", - "0123456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456" - "x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x" + "0123456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x234" + "56" + "x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x2345" + "6x" "23456x23456x23456x23456x23456x23456x23456x23456x23456x23456$$9", - "0123456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456" - "x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x" + "0123456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x234" + "56" + "x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x2345" + "6x" "23456x23456x23456x23456x23456x23456x23456x23456x23456x23456$$9", "Test string for overlap check: bananaavocado PEAR avocadoPEAR banavocado avocado PEAR " "Test string for overlap check: bananaavocado PEAR avocadoPEAR banavocado avocado PEAR " @@ -445,8 +468,10 @@ TEST_F(StringsReplaceTest, ReplaceMultiLong) "23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*" "23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*9", "Test string for overlap check: banana* * ** ban* * * Test string for overlap check: " - "banana* * ** ban* * * Test string for overlap check: banana* * ** ban* * * Test string for " - "overlap check: banana* * ** ban* * * Test string for overlap check: banana* * ** ban* * *", + "banana* * ** ban* * * Test string for overlap check: banana* * ** ban* * * Test string " + "for " + "overlap check: banana* * ** ban* * * Test string for overlap check: banana* * ** ban* * " + "*", "", ""}, {1, 1, 1, 1, 0, 1});