diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 078de27f0ea..58a43c1def1 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -587,7 +587,9 @@ add_library( src/strings/replace/multi.cu src/strings/replace/multi_re.cu src/strings/replace/replace.cu + src/strings/replace/replace_nulls.cu src/strings/replace/replace_re.cu + src/strings/replace/replace_slice.cu src/strings/reverse.cu src/strings/scan/scan_inclusive.cu src/strings/search/findall.cu diff --git a/cpp/include/cudf/strings/detail/replace.hpp b/cpp/include/cudf/strings/detail/replace.hpp index aa6fb2feb3d..28027291b28 100644 --- a/cpp/include/cudf/strings/detail/replace.hpp +++ b/cpp/include/cudf/strings/detail/replace.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,23 +26,10 @@ namespace cudf { namespace strings { namespace detail { -/** - * @brief The type of algorithm to use for a replace operation. - */ -enum class replace_algorithm { - AUTO, ///< Automatically choose the algorithm based on heuristics - ROW_PARALLEL, ///< Row-level parallelism - CHAR_PARALLEL ///< Character-level parallelism -}; - /** * @copydoc cudf::strings::replace(strings_column_view const&, string_scalar const&, - * string_scalar const&, int32_t, rmm::mr::device_memory_resource*) - * - * @tparam alg Replacement algorithm to use - * @param[in] stream CUDA stream used for device memory operations and kernel launches. + * string_scalar const&, int32_t, rmm::cuda_stream_view, rmm::mr::device_memory_resource*) */ -template std::unique_ptr replace(strings_column_view const& strings, string_scalar const& target, string_scalar const& repl, @@ -50,24 +37,9 @@ std::unique_ptr replace(strings_column_view const& strings, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr); -/** - * @copydoc cudf::strings::replace_slice(strings_column_view const&, string_scalar const&, - * size_type. size_type, rmm::mr::device_memory_resource*) - * - * @param[in] stream CUDA stream used for device memory operations and kernel launches. - */ -std::unique_ptr replace_slice(strings_column_view const& strings, - string_scalar const& repl, - size_type start, - size_type stop, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr); - /** * @copydoc cudf::strings::replace(strings_column_view const&, strings_column_view const&, - * strings_column_view const&, rmm::mr::device_memory_resource*) - * - * @param[in] stream CUDA stream used for device memory operations and kernel launches. + * strings_column_view const&, rmm::cuda_stream_view, rmm::mr::device_memory_resource*) */ std::unique_ptr replace(strings_column_view const& strings, strings_column_view const& targets, @@ -98,6 +70,17 @@ std::unique_ptr replace_nulls(strings_column_view const& strings, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr); +/** + * @copydoc cudf::strings::replace_slice(strings_column_view const&, string_scalar const&, + * size_type, size_type, rmm::cuda_stream_view, rmm::mr::device_memory_resource*) + */ +std::unique_ptr replace_slice(strings_column_view const& strings, + string_scalar const& repl, + size_type start, + size_type stop, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr); + } // namespace detail } // namespace strings } // namespace cudf diff --git a/cpp/src/strings/replace/replace.cu b/cpp/src/strings/replace/replace.cu index d68ec84f68c..2d255e57686 100644 --- a/cpp/src/strings/replace/replace.cu +++ b/cpp/src/strings/replace/replace.cu @@ -542,17 +542,12 @@ std::unique_ptr replace_row_parallel(strings_column_view const& strings, } // namespace -/** - * @copydoc cudf::strings::detail::replace(strings_column_view const&, string_scalar const&, - * string_scalar const&, int32_t, rmm::cuda_stream_view, rmm::mr::device_memory_resource*) - */ -template <> -std::unique_ptr replace(strings_column_view const& strings, - string_scalar const& target, - string_scalar const& repl, - int32_t maxrepl, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) +std::unique_ptr replace(strings_column_view const& strings, + string_scalar const& target, + string_scalar const& repl, + int32_t maxrepl, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) { if (strings.is_empty()) return make_empty_column(type_id::STRING); if (maxrepl == 0) return std::make_unique(strings.parent(), stream, mr); @@ -584,168 +579,6 @@ std::unique_ptr replace(strings_column_view con strings, chars_start, chars_end, d_target, d_repl, maxrepl, stream, mr); } -template <> -std::unique_ptr replace( - strings_column_view const& strings, - string_scalar const& target, - string_scalar const& repl, - int32_t maxrepl, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) -{ - if (strings.is_empty()) return make_empty_column(type_id::STRING); - if (maxrepl == 0) return std::make_unique(strings.parent(), stream, mr); - CUDF_EXPECTS(repl.is_valid(stream), "Parameter repl must be valid."); - CUDF_EXPECTS(target.is_valid(stream), "Parameter target must be valid."); - CUDF_EXPECTS(target.size() > 0, "Parameter target must not be empty string."); - - string_view d_target(target.data(), target.size()); - string_view d_repl(repl.data(), repl.size()); - - // determine range of characters in the base column - auto const strings_count = strings.size(); - auto const offset_count = strings_count + 1; - auto const d_offsets = strings.offsets_begin(); - size_type chars_start = (strings.offset() == 0) ? 0 - : cudf::detail::get_value( - strings.offsets(), strings.offset(), stream); - size_type chars_end = (offset_count == strings.offsets().size()) - ? strings.chars_size(stream) - : cudf::detail::get_value( - strings.offsets(), strings.offset() + strings_count, stream); - return replace_char_parallel( - strings, chars_start, chars_end, d_target, d_repl, maxrepl, stream, mr); -} - -template <> -std::unique_ptr replace( - strings_column_view const& strings, - string_scalar const& target, - string_scalar const& repl, - int32_t maxrepl, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) -{ - if (strings.is_empty()) return make_empty_column(type_id::STRING); - if (maxrepl == 0) return std::make_unique(strings.parent(), stream, mr); - CUDF_EXPECTS(repl.is_valid(stream), "Parameter repl must be valid."); - CUDF_EXPECTS(target.is_valid(stream), "Parameter target must be valid."); - CUDF_EXPECTS(target.size() > 0, "Parameter target must not be empty string."); - - string_view d_target(target.data(), target.size()); - string_view d_repl(repl.data(), repl.size()); - return replace_row_parallel(strings, d_target, d_repl, maxrepl, stream, mr); -} - -namespace { -/** - * @brief Function logic for the replace_slice API. - * - * This will perform a replace_slice operation on each string. - */ -struct replace_slice_fn { - column_device_view const d_strings; - string_view const d_repl; - size_type const start; - size_type const stop; - int32_t* d_offsets{}; - char* d_chars{}; - - __device__ void operator()(size_type idx) - { - if (d_strings.is_null(idx)) { - if (!d_chars) d_offsets[idx] = 0; - return; - } - auto const d_str = d_strings.element(idx); - auto const length = d_str.length(); - char const* in_ptr = d_str.data(); - auto const begin = d_str.byte_offset(((start < 0) || (start > length) ? length : start)); - auto const end = d_str.byte_offset(((stop < 0) || (stop > length) ? length : stop)); - - if (d_chars) { - char* out_ptr = d_chars + d_offsets[idx]; - - out_ptr = copy_and_increment(out_ptr, in_ptr, begin); // copy beginning - out_ptr = copy_string(out_ptr, d_repl); // insert replacement - out_ptr = copy_and_increment(out_ptr, // copy end - in_ptr + end, - d_str.size_bytes() - end); - } else { - d_offsets[idx] = d_str.size_bytes() + d_repl.size_bytes() - (end - begin); - } - } -}; - -} // namespace - -std::unique_ptr replace_slice(strings_column_view const& strings, - string_scalar const& repl, - size_type start, - size_type stop, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) -{ - if (strings.is_empty()) return make_empty_column(type_id::STRING); - CUDF_EXPECTS(repl.is_valid(stream), "Parameter repl must be valid."); - if (stop > 0) CUDF_EXPECTS(start <= stop, "Parameter start must be less than or equal to stop."); - - string_view d_repl(repl.data(), repl.size()); - - auto d_strings = column_device_view::create(strings.parent(), stream); - - // this utility calls the given functor to build the offsets and chars columns - auto [offsets_column, chars_column] = cudf::strings::detail::make_strings_children( - replace_slice_fn{*d_strings, d_repl, start, stop}, strings.size(), stream, mr); - - return make_strings_column(strings.size(), - std::move(offsets_column), - std::move(chars_column->release().data.release()[0]), - strings.null_count(), - cudf::detail::copy_bitmask(strings.parent(), stream, mr)); -} - -std::unique_ptr replace_nulls(strings_column_view const& strings, - string_scalar const& repl, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) -{ - size_type strings_count = strings.size(); - if (strings_count == 0) return make_empty_column(type_id::STRING); - CUDF_EXPECTS(repl.is_valid(stream), "Parameter repl must be valid."); - - string_view d_repl(repl.data(), repl.size()); - - auto strings_column = column_device_view::create(strings.parent(), stream); - auto d_strings = *strings_column; - - // build offsets column - auto offsets_transformer_itr = thrust::make_transform_iterator( - thrust::make_counting_iterator(0), - cuda::proclaim_return_type([d_strings, d_repl] __device__(size_type idx) { - return d_strings.is_null(idx) ? d_repl.size_bytes() - : d_strings.element(idx).size_bytes(); - })); - auto [offsets_column, bytes] = cudf::detail::make_offsets_child_column( - offsets_transformer_itr, offsets_transformer_itr + strings_count, stream, mr); - auto d_offsets = offsets_column->view().data(); - - // build chars column - rmm::device_uvector chars(bytes, stream, mr); - auto d_chars = chars.data(); - thrust::for_each_n(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - strings_count, - [d_strings, d_repl, d_offsets, d_chars] __device__(size_type idx) { - string_view d_str = d_repl; - if (!d_strings.is_null(idx)) d_str = d_strings.element(idx); - memcpy(d_chars + d_offsets[idx], d_str.data(), d_str.size_bytes()); - }); - - return make_strings_column( - strings_count, std::move(offsets_column), chars.release(), 0, rmm::device_buffer{}); -} - } // namespace detail // external API @@ -761,16 +594,5 @@ std::unique_ptr replace(strings_column_view const& strings, return detail::replace(strings, target, repl, maxrepl, stream, mr); } -std::unique_ptr replace_slice(strings_column_view const& strings, - string_scalar const& repl, - size_type start, - size_type stop, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) -{ - CUDF_FUNC_RANGE(); - return detail::replace_slice(strings, repl, start, stop, stream, mr); -} - } // namespace strings } // namespace cudf diff --git a/cpp/src/strings/replace/replace_nulls.cu b/cpp/src/strings/replace/replace_nulls.cu new file mode 100644 index 00000000000..26fb1c7819f --- /dev/null +++ b/cpp/src/strings/replace/replace_nulls.cu @@ -0,0 +1,81 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include + +namespace cudf { +namespace strings { +namespace detail { + +std::unique_ptr replace_nulls(strings_column_view const& strings, + string_scalar const& repl, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + size_type strings_count = strings.size(); + if (strings_count == 0) return make_empty_column(type_id::STRING); + CUDF_EXPECTS(repl.is_valid(stream), "Parameter repl must be valid."); + + string_view d_repl(repl.data(), repl.size()); + + auto strings_column = column_device_view::create(strings.parent(), stream); + auto d_strings = *strings_column; + + // build offsets column + auto offsets_transformer_itr = cudf::detail::make_counting_transform_iterator( + 0, cuda::proclaim_return_type([d_strings, d_repl] __device__(size_type idx) { + return d_strings.is_null(idx) ? d_repl.size_bytes() + : d_strings.element(idx).size_bytes(); + })); + auto [offsets_column, bytes] = cudf::strings::detail::make_offsets_child_column( + offsets_transformer_itr, offsets_transformer_itr + strings_count, stream, mr); + auto d_offsets = offsets_column->view().data(); + + // build chars column + rmm::device_uvector chars(bytes, stream, mr); + auto d_chars = chars.data(); + thrust::for_each_n(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + strings_count, + [d_strings, d_repl, d_offsets, d_chars] __device__(size_type idx) { + string_view d_str = d_repl; + if (!d_strings.is_null(idx)) d_str = d_strings.element(idx); + memcpy(d_chars + d_offsets[idx], d_str.data(), d_str.size_bytes()); + }); + + return make_strings_column( + strings_count, std::move(offsets_column), chars.release(), 0, rmm::device_buffer{}); +} + +} // namespace detail +} // namespace strings +} // namespace cudf diff --git a/cpp/src/strings/replace/replace_slice.cu b/cpp/src/strings/replace/replace_slice.cu new file mode 100644 index 00000000000..4321f78d2d5 --- /dev/null +++ b/cpp/src/strings/replace/replace_slice.cu @@ -0,0 +1,117 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include + +namespace cudf { +namespace strings { +namespace detail { +namespace { +/** + * @brief Function logic for the replace_slice API. + * + * This will perform a replace_slice operation on each string. + */ +struct replace_slice_fn { + column_device_view const d_strings; + string_view const d_repl; + size_type const start; + size_type const stop; + size_type* d_offsets{}; + char* d_chars{}; + + __device__ void operator()(size_type idx) + { + if (d_strings.is_null(idx)) { + if (!d_chars) d_offsets[idx] = 0; + return; + } + auto const d_str = d_strings.element(idx); + auto const length = d_str.length(); + char const* in_ptr = d_str.data(); + auto const begin = d_str.byte_offset(((start < 0) || (start > length) ? length : start)); + auto const end = d_str.byte_offset(((stop < 0) || (stop > length) ? length : stop)); + + if (d_chars) { + char* out_ptr = d_chars + d_offsets[idx]; + + out_ptr = copy_and_increment(out_ptr, in_ptr, begin); // copy beginning + out_ptr = copy_string(out_ptr, d_repl); // insert replacement + out_ptr = copy_and_increment(out_ptr, // copy end + in_ptr + end, + d_str.size_bytes() - end); + } else { + d_offsets[idx] = d_str.size_bytes() + d_repl.size_bytes() - (end - begin); + } + } +}; + +} // namespace + +std::unique_ptr replace_slice(strings_column_view const& strings, + string_scalar const& repl, + size_type start, + size_type stop, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + if (strings.is_empty()) return make_empty_column(type_id::STRING); + CUDF_EXPECTS(repl.is_valid(stream), "Parameter repl must be valid."); + if (stop > 0) CUDF_EXPECTS(start <= stop, "Parameter start must be less than or equal to stop."); + + string_view d_repl(repl.data(), repl.size()); + + auto d_strings = column_device_view::create(strings.parent(), stream); + + // this utility calls the given functor to build the offsets and chars columns + auto [offsets_column, chars_column] = cudf::strings::detail::make_strings_children( + replace_slice_fn{*d_strings, d_repl, start, stop}, strings.size(), stream, mr); + + return make_strings_column(strings.size(), + std::move(offsets_column), + std::move(chars_column->release().data.release()[0]), + strings.null_count(), + cudf::detail::copy_bitmask(strings.parent(), stream, mr)); +} +} // namespace detail + +std::unique_ptr replace_slice(strings_column_view const& strings, + string_scalar const& repl, + size_type start, + size_type stop, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + CUDF_FUNC_RANGE(); + return detail::replace_slice(strings, repl, start, stop, stream, mr); +} + +} // namespace strings +} // namespace cudf diff --git a/cpp/tests/strings/replace_tests.cpp b/cpp/tests/strings/replace_tests.cpp index f04bb832f09..726d9f95c7d 100644 --- a/cpp/tests/strings/replace_tests.cpp +++ b/cpp/tests/strings/replace_tests.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,17 +20,12 @@ #include #include -#include +#include #include #include -#include -#include - #include -using algorithm = cudf::strings::detail::replace_algorithm; - struct StringsReplaceTest : public cudf::test::BaseFixture { cudf::test::strings_column_wrapper build_corpus() { @@ -47,6 +42,13 @@ struct StringsReplaceTest : public cudf::test::BaseFixture { h_strings.end(), thrust::make_transform_iterator(h_strings.begin(), [](auto str) { return str != nullptr; })); } + + std::unique_ptr build_large(cudf::column_view const& first, + cudf::column_view const& remaining) + { + return cudf::strings::concatenate(cudf::table_view( + {first, remaining, remaining, remaining, remaining, remaining, remaining, remaining})); + } }; TEST_F(StringsReplaceTest, Replace) @@ -64,26 +66,23 @@ TEST_F(StringsReplaceTest, Replace) cudf::test::strings_column_wrapper expected( h_expected.begin(), h_expected.end(), cudf::test::iterators::nulls_from_nullptrs(h_expected)); - auto stream = cudf::get_default_stream(); - auto mr = rmm::mr::get_current_device_resource(); + auto target = cudf::string_scalar("the "); + auto replacement = cudf::string_scalar("++++ "); - auto results = - cudf::strings::replace(strings_view, cudf::string_scalar("the "), cudf::string_scalar("++++ ")); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); - results = cudf::strings::detail::replace( - strings_view, cudf::string_scalar("the "), cudf::string_scalar("++++ "), -1, stream, mr); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); - results = cudf::strings::detail::replace( - strings_view, cudf::string_scalar("the "), cudf::string_scalar("++++ "), -1, stream, mr); + auto results = cudf::strings::replace(strings_view, target, replacement); CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); + + auto input_large = build_large(input, input); + strings_view = cudf::strings_column_view(input_large->view()); + auto expected_large = build_large(expected, expected); + results = cudf::strings::replace(strings_view, target, replacement); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, *expected_large); } TEST_F(StringsReplaceTest, ReplaceReplLimit) { auto input = build_corpus(); auto strings_view = cudf::strings_column_view(input); - auto stream = cudf::get_default_stream(); - auto mr = rmm::mr::get_current_device_resource(); // only remove the first occurrence of 'the ' std::vector h_expected{"quick brown fox jumps over the lazy dog", @@ -95,15 +94,16 @@ TEST_F(StringsReplaceTest, ReplaceReplLimit) nullptr}; cudf::test::strings_column_wrapper expected( h_expected.begin(), h_expected.end(), cudf::test::iterators::nulls_from_nullptrs(h_expected)); - auto results = - cudf::strings::replace(strings_view, cudf::string_scalar("the "), cudf::string_scalar(""), 1); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); - results = cudf::strings::detail::replace( - strings_view, cudf::string_scalar("the "), cudf::string_scalar(""), 1, stream, mr); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); - results = cudf::strings::detail::replace( - strings_view, cudf::string_scalar("the "), cudf::string_scalar(""), 1, stream, mr); + auto target = cudf::string_scalar("the "); + auto replacement = cudf::string_scalar(""); + auto results = cudf::strings::replace(strings_view, target, replacement, 1); CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); + + auto input_large = build_large(input, input); + strings_view = cudf::strings_column_view(input_large->view()); + auto expected_large = build_large(expected, input); + results = cudf::strings::replace(strings_view, target, replacement, 1); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, *expected_large); } TEST_F(StringsReplaceTest, ReplaceReplLimitInputSliced) @@ -119,22 +119,28 @@ TEST_F(StringsReplaceTest, ReplaceReplLimitInputSliced) nullptr}; cudf::test::strings_column_wrapper expected( h_expected.begin(), h_expected.end(), cudf::test::iterators::nulls_from_nullptrs(h_expected)); - auto stream = cudf::get_default_stream(); - auto mr = rmm::mr::get_current_device_resource(); std::vector slice_indices{0, 2, 2, 3, 3, 7}; auto sliced_strings = cudf::slice(input, slice_indices); auto sliced_expected = cudf::slice(expected, slice_indices); + + auto input_large = build_large(input, input); + auto expected_large = build_large(expected, input); + + auto sliced_large = cudf::slice(input_large->view(), slice_indices); + auto sliced_expected_large = cudf::slice(expected_large->view(), slice_indices); + + auto target = cudf::string_scalar(" "); + auto replacement = cudf::string_scalar("--"); + for (size_t i = 0; i < sliced_strings.size(); ++i) { auto strings_view = cudf::strings_column_view(sliced_strings[i]); - auto results = - cudf::strings::replace(strings_view, cudf::string_scalar(" "), cudf::string_scalar("--"), 2); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, sliced_expected[i]); - results = cudf::strings::detail::replace( - strings_view, cudf::string_scalar(" "), cudf::string_scalar("--"), 2, stream, mr); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, sliced_expected[i]); - results = cudf::strings::detail::replace( - strings_view, cudf::string_scalar(" "), cudf::string_scalar("--"), 2, stream, mr); + auto results = cudf::strings::replace(strings_view, target, replacement, 2); CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, sliced_expected[i]); + + strings_view = cudf::strings_column_view(sliced_large[i]); + results = + cudf::strings::replace(strings_view, cudf::string_scalar(" "), cudf::string_scalar("--"), 2); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, sliced_expected_large[i]); } } @@ -158,68 +164,56 @@ TEST_F(StringsReplaceTest, ReplaceTargetOverlap) cudf::test::strings_column_wrapper expected( h_expected.begin(), h_expected.end(), cudf::test::iterators::nulls_from_nullptrs(h_expected)); - auto stream = cudf::get_default_stream(); - auto mr = rmm::mr::get_current_device_resource(); + auto target = cudf::string_scalar("+++"); + auto replacement = cudf::string_scalar("plus "); - auto results = - cudf::strings::replace(strings_view, cudf::string_scalar("+++"), cudf::string_scalar("plus ")); + auto results = cudf::strings::replace(strings_view, target, replacement); CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); - results = cudf::strings::detail::replace( - strings_view, cudf::string_scalar("+++"), cudf::string_scalar("plus "), -1, stream, mr); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); - results = cudf::strings::detail::replace( - strings_view, cudf::string_scalar("+++"), cudf::string_scalar("plus "), -1, stream, mr); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); + auto input_large = build_large(input->view(), input->view()); + strings_view = cudf::strings_column_view(input_large->view()); + auto expected_large = build_large(expected, expected); + + results = cudf::strings::replace(strings_view, target, replacement); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, *expected_large); } TEST_F(StringsReplaceTest, ReplaceTargetOverlapsStrings) { auto input = build_corpus(); auto strings_view = cudf::strings_column_view(input); - auto stream = cudf::get_default_stream(); - auto mr = rmm::mr::get_current_device_resource(); // replace all occurrences of 'dogthe' with '+' + auto target = cudf::string_scalar("dogthe"); + auto replacement = cudf::string_scalar("+"); + // should not replace anything unless it incorrectly matches across a string boundary - auto results = - cudf::strings::replace(strings_view, cudf::string_scalar("dogthe"), cudf::string_scalar("+")); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, input); - results = cudf::strings::detail::replace( - strings_view, cudf::string_scalar("dogthe"), cudf::string_scalar("+"), -1, stream, mr); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, input); - results = cudf::strings::detail::replace( - strings_view, cudf::string_scalar("dogthe"), cudf::string_scalar("+"), -1, stream, mr); + auto results = cudf::strings::replace(strings_view, target, replacement); CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, input); + + auto input_large = cudf::strings::concatenate( + cudf::table_view({input, input, input, input, input, input, input, input}), + cudf::string_scalar(" ")); + strings_view = cudf::strings_column_view(input_large->view()); + results = cudf::strings::replace(strings_view, target, replacement); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, *input_large); } -TEST_F(StringsReplaceTest, ReplaceNullInput) +TEST_F(StringsReplaceTest, ReplaceAllNullInput) { std::vector h_null_strings(128); auto input = cudf::test::strings_column_wrapper( h_null_strings.begin(), h_null_strings.end(), thrust::make_constant_iterator(false)); auto strings_view = cudf::strings_column_view(input); - auto stream = cudf::get_default_stream(); - auto mr = rmm::mr::get_current_device_resource(); - // replace all occurrences of '+' with '' - // should not replace anything as input is all null auto results = cudf::strings::replace(strings_view, cudf::string_scalar("+"), cudf::string_scalar("")); CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, input); - results = cudf::strings::detail::replace( - strings_view, cudf::string_scalar("+"), cudf::string_scalar(""), -1, stream, mr); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, input); - results = cudf::strings::detail::replace( - strings_view, cudf::string_scalar("+"), cudf::string_scalar(""), -1, stream, mr); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, input); } TEST_F(StringsReplaceTest, ReplaceEndOfString) { auto input = build_corpus(); auto strings_view = cudf::strings_column_view(input); - auto stream = cudf::get_default_stream(); - auto mr = rmm::mr::get_current_device_resource(); // replace all occurrences of 'in' with ' ' std::vector h_expected{"the quick brown fox jumps over the lazy dog", @@ -233,39 +227,56 @@ TEST_F(StringsReplaceTest, ReplaceEndOfString) cudf::test::strings_column_wrapper expected( h_expected.begin(), h_expected.end(), cudf::test::iterators::nulls_from_nullptrs(h_expected)); - auto results = - cudf::strings::replace(strings_view, cudf::string_scalar("in"), cudf::string_scalar(" ")); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); + auto target = cudf::string_scalar("in"); + auto replacement = cudf::string_scalar(" "); - results = cudf::strings::detail::replace( - strings_view, cudf::string_scalar("in"), cudf::string_scalar(" "), -1, stream, mr); + auto results = cudf::strings::replace(strings_view, target, replacement); CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); - results = cudf::strings::detail::replace( - strings_view, cudf::string_scalar("in"), cudf::string_scalar(" "), -1, stream, mr); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); + auto input_large = build_large(input, input); + strings_view = cudf::strings_column_view(input_large->view()); + auto expected_large = build_large(expected, expected); + results = cudf::strings::replace(strings_view, target, replacement); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, *expected_large); } TEST_F(StringsReplaceTest, ReplaceAdjacentMultiByteTarget) { - auto input = cudf::test::strings_column_wrapper({"ééééééé", "eéeéeée", "eeeeeee"}); + auto input = cudf::test::strings_column_wrapper({"ééééééééééééééééééééé", + "eéeéeéeeéeéeéeeéeéeée", + "eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee"}); auto strings_view = cudf::strings_column_view(input); // replace all occurrences of 'é' with 'e' - cudf::test::strings_column_wrapper expected({"eeeeeee", "eeeeeee", "eeeeeee"}); + cudf::test::strings_column_wrapper expected({"eeeeeeeeeeeeeeeeeeeee", + "eeeeeeeeeeeeeeeeeeeee", + "eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee"}); - auto stream = cudf::get_default_stream(); - auto mr = rmm::mr::get_current_device_resource(); + auto target = cudf::string_scalar("é"); + auto replacement = cudf::string_scalar("e"); - auto target = cudf::string_scalar("é", true, stream); - auto repl = cudf::string_scalar("e", true, stream); - auto results = cudf::strings::replace(strings_view, target, repl); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); - results = cudf::strings::detail::replace( - strings_view, target, repl, -1, stream, mr); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); - results = cudf::strings::detail::replace( - strings_view, target, repl, -1, stream, mr); + auto results = cudf::strings::replace(strings_view, target, replacement); CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); + + auto input_large = build_large(input, input); + strings_view = cudf::strings_column_view(input_large->view()); + auto expected_large = build_large(expected, expected); + results = cudf::strings::replace(strings_view, target, replacement); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, *expected_large); +} + +TEST_F(StringsReplaceTest, ReplaceErrors) +{ + auto input = cudf::test::strings_column_wrapper({"this column intentionally left blank"}); + + auto target = cudf::string_scalar(" "); + auto replacement = cudf::string_scalar("_"); + auto null_input = cudf::string_scalar("", false); + auto empty_input = cudf::string_scalar(""); + auto sv = cudf::strings_column_view(input); + + EXPECT_THROW(cudf::strings::replace(sv, target, null_input), cudf::logic_error); + EXPECT_THROW(cudf::strings::replace(sv, null_input, replacement), cudf::logic_error); + EXPECT_THROW(cudf::strings::replace(sv, empty_input, replacement), cudf::logic_error); } TEST_F(StringsReplaceTest, ReplaceSlice) @@ -369,22 +380,30 @@ TEST_F(StringsReplaceTest, ReplaceMulti) TEST_F(StringsReplaceTest, ReplaceMultiLong) { - // The length of the strings are to trigger the code path governed by the AVG_CHAR_BYTES_THRESHOLD - // setting in the multi.cu. + // The length of the strings are to trigger the code path governed by the + // AVG_CHAR_BYTES_THRESHOLD setting in the multi.cu. auto input = cudf::test::strings_column_wrapper( {"This string needs to be very long to trigger the long-replace internal functions. " "This string needs to be very long to trigger the long-replace internal functions. " "This string needs to be very long to trigger the long-replace internal functions. " "This string needs to be very long to trigger the long-replace internal functions.", - "012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012" - "345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345" - "678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678" - "901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901" + "0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890" + "12" + "3456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123" + "45" + "6789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456" + "78" + "9012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789" + "01" "2345678901234567890123456789", - "012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012" - "345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345" - "678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678" - "901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901" + "0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890" + "12" + "3456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123" + "45" + "6789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456" + "78" + "9012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789" + "01" "2345678901234567890123456789", "Test string for overlap check: bananaápple bananá ápplebananá banápple ápple bananá " "Test string for overlap check: bananaápple bananá ápplebananá banápple ápple bananá " @@ -410,11 +429,15 @@ TEST_F(StringsReplaceTest, ReplaceMultiLong) "This string needs to be very long to trigger the long-replace internal functions. " "This string needs to be very long to trigger the long-replace internal functions. " "This string needs to be very long to trigger the long-replace internal functions.", - "0123456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456" - "x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x" + "0123456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x234" + "56" + "x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x2345" + "6x" "23456x23456x23456x23456x23456x23456x23456x23456x23456x23456$$9", - "0123456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456" - "x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x" + "0123456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x234" + "56" + "x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x23456x2345" + "6x" "23456x23456x23456x23456x23456x23456x23456x23456x23456x23456$$9", "Test string for overlap check: bananaavocado PEAR avocadoPEAR banavocado avocado PEAR " "Test string for overlap check: bananaavocado PEAR avocadoPEAR banavocado avocado PEAR " @@ -445,8 +468,10 @@ TEST_F(StringsReplaceTest, ReplaceMultiLong) "23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*" "23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*23456*9", "Test string for overlap check: banana* * ** ban* * * Test string for overlap check: " - "banana* * ** ban* * * Test string for overlap check: banana* * ** ban* * * Test string for " - "overlap check: banana* * ** ban* * * Test string for overlap check: banana* * ** ban* * *", + "banana* * ** ban* * * Test string for overlap check: banana* * ** ban* * * Test string " + "for " + "overlap check: banana* * ** ban* * * Test string for overlap check: banana* * ** ban* * " + "*", "", ""}, {1, 1, 1, 1, 0, 1});