diff --git a/cpp/src/strings/replace/replace.cu b/cpp/src/strings/replace/replace.cu index 501e6d547e6..f7a3a3aea5c 100644 --- a/cpp/src/strings/replace/replace.cu +++ b/cpp/src/strings/replace/replace.cu @@ -238,6 +238,31 @@ struct replace_parallel_chars_fn { cudf::size_type maxrepl; }; +template +CUDF_KERNEL void count_targets_kernel(replace_parallel_chars_fn fn, + int64_t chars_bytes, + int64_t* d_output) +{ + auto const idx = cudf::detail::grid_1d::global_thread_id(); + auto const byte_idx = static_cast(idx) * bytes_per_thread; + auto const lane_idx = static_cast(threadIdx.x); + + using block_reduce = cub::BlockReduce; + __shared__ typename block_reduce::TempStorage temp_storage; + + int64_t count = 0; + // each thread processes multiple bytes + for (auto i = byte_idx; (i < (byte_idx + bytes_per_thread)) && (i < chars_bytes); ++i) { + count += fn.has_target(i); + } + auto const total = block_reduce(temp_storage).Reduce(count, cub::Sum()); + + if ((lane_idx == 0) && (total > 0)) { + cuda::atomic_ref ref{*d_output}; + ref.fetch_add(total, cuda::std::memory_order_relaxed); + } +} + std::unique_ptr replace_character_parallel(strings_column_view const& input, string_view const& d_target, string_view const& d_replacement, @@ -260,10 +285,14 @@ std::unique_ptr replace_character_parallel(strings_column_view const& in // Count the number of targets in the entire column. // Note this may over-count in the case where a target spans adjacent strings. - auto target_count = thrust::count_if(rmm::exec_policy_nosync(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(chars_bytes), - [fn] __device__(int64_t idx) { return fn.has_target(idx); }); + rmm::device_scalar d_target_count(0, stream); + constexpr int64_t block_size = 512; + constexpr size_type bytes_per_thread = 4; + auto const num_blocks = util::div_rounding_up_safe( + util::div_rounding_up_safe(chars_bytes, static_cast(bytes_per_thread)), block_size); + count_targets_kernel + <<>>(fn, chars_bytes, d_target_count.data()); + auto target_count = d_target_count.value(stream); // Create a vector of every target position in the chars column. // These may also include overlapping targets which will be resolved later. diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index b153c4984c5..329edbe4d36 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -571,9 +571,9 @@ ConfigureTest( large_strings/concatenate_tests.cpp large_strings/case_tests.cpp large_strings/large_strings_fixture.cpp - large_strings/many_strings_tests.cpp large_strings/merge_tests.cpp large_strings/parquet_tests.cpp + large_strings/replace_tests.cpp large_strings/reshape_tests.cpp large_strings/split_strings_tests.cpp GPUS 1 diff --git a/cpp/tests/large_strings/many_strings_tests.cpp b/cpp/tests/large_strings/replace_tests.cpp similarity index 72% rename from cpp/tests/large_strings/many_strings_tests.cpp rename to cpp/tests/large_strings/replace_tests.cpp index 73fbb21d014..aa65ec0c010 100644 --- a/cpp/tests/large_strings/many_strings_tests.cpp +++ b/cpp/tests/large_strings/replace_tests.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -28,9 +29,9 @@ #include #include -struct StringsManyTest : public cudf::test::StringsLargeTest {}; +struct ReplaceTest : public cudf::test::StringsLargeTest {}; -TEST_F(StringsManyTest, Replace) +TEST_F(ReplaceTest, ReplaceLong) { auto const expected = this->very_long_column(); auto const view = cudf::column_view(expected); @@ -65,3 +66,22 @@ TEST_F(StringsManyTest, Replace) CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(c, expected); } } + +TEST_F(ReplaceTest, ReplaceWide) +{ + auto const expected = this->long_column(); + auto const view = cudf::column_view(expected); + auto const multiplier = 10; + auto const separator = cudf::string_scalar("|"); + auto const input = cudf::strings::concatenate( + cudf::table_view(std::vector(multiplier, view)), separator); + + auto const input_view = cudf::strings_column_view(input->view()); + auto const target = cudf::string_scalar("3"); // fake the actual replace; + auto const repl = cudf::string_scalar("3"); // logic still builds the output + auto result = cudf::strings::replace(input_view, target, repl); + + auto sv = cudf::strings_column_view(result->view()); + EXPECT_EQ(sv.offsets().type(), cudf::data_type{cudf::type_id::INT64}); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(input->view(), result->view()); +}