Skip to content

Commit

Permalink
Fix target counting in strings char-parallel replace (#16017)
Browse files Browse the repository at this point in the history
Replace `thrust::count_if` call across int64 characters to use a custom kernel instead.

Authors:
  - David Wendt (https://github.com/davidwendt)

Approvers:
  - Vukasin Milovanovic (https://github.com/vuule)
  - Srinivas Yadav (https://github.com/srinivasyadav18)

URL: #16017
  • Loading branch information
davidwendt authored Jun 17, 2024
1 parent bcdfe91 commit 56e8442
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 7 deletions.
37 changes: 33 additions & 4 deletions cpp/src/strings/replace/replace.cu
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,31 @@ struct replace_parallel_chars_fn {
cudf::size_type maxrepl;
};

template <int64_t block_size, size_type bytes_per_thread>
CUDF_KERNEL void count_targets_kernel(replace_parallel_chars_fn fn,
int64_t chars_bytes,
int64_t* d_output)
{
auto const idx = cudf::detail::grid_1d::global_thread_id();
auto const byte_idx = static_cast<int64_t>(idx) * bytes_per_thread;
auto const lane_idx = static_cast<cudf::size_type>(threadIdx.x);

using block_reduce = cub::BlockReduce<int64_t, block_size>;
__shared__ typename block_reduce::TempStorage temp_storage;

int64_t count = 0;
// each thread processes multiple bytes
for (auto i = byte_idx; (i < (byte_idx + bytes_per_thread)) && (i < chars_bytes); ++i) {
count += fn.has_target(i);
}
auto const total = block_reduce(temp_storage).Reduce(count, cub::Sum());

if ((lane_idx == 0) && (total > 0)) {
cuda::atomic_ref<int64_t, cuda::thread_scope_device> ref{*d_output};
ref.fetch_add(total, cuda::std::memory_order_relaxed);
}
}

std::unique_ptr<column> replace_character_parallel(strings_column_view const& input,
string_view const& d_target,
string_view const& d_replacement,
Expand All @@ -260,10 +285,14 @@ std::unique_ptr<column> replace_character_parallel(strings_column_view const& in

// Count the number of targets in the entire column.
// Note this may over-count in the case where a target spans adjacent strings.
auto target_count = thrust::count_if(rmm::exec_policy_nosync(stream),
thrust::make_counting_iterator<int64_t>(0),
thrust::make_counting_iterator<int64_t>(chars_bytes),
[fn] __device__(int64_t idx) { return fn.has_target(idx); });
rmm::device_scalar<int64_t> d_target_count(0, stream);
constexpr int64_t block_size = 512;
constexpr size_type bytes_per_thread = 4;
auto const num_blocks = util::div_rounding_up_safe(
util::div_rounding_up_safe(chars_bytes, static_cast<int64_t>(bytes_per_thread)), block_size);
count_targets_kernel<block_size, bytes_per_thread>
<<<num_blocks, block_size, 0, stream.value()>>>(fn, chars_bytes, d_target_count.data());
auto target_count = d_target_count.value(stream);

// Create a vector of every target position in the chars column.
// These may also include overlapping targets which will be resolved later.
Expand Down
2 changes: 1 addition & 1 deletion cpp/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -571,9 +571,9 @@ ConfigureTest(
large_strings/concatenate_tests.cpp
large_strings/case_tests.cpp
large_strings/large_strings_fixture.cpp
large_strings/many_strings_tests.cpp
large_strings/merge_tests.cpp
large_strings/parquet_tests.cpp
large_strings/replace_tests.cpp
large_strings/reshape_tests.cpp
large_strings/split_strings_tests.cpp
GPUS 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,17 @@
#include <cudf/concatenate.hpp>
#include <cudf/copying.hpp>
#include <cudf/scalar/scalar.hpp>
#include <cudf/strings/combine.hpp>
#include <cudf/strings/replace.hpp>
#include <cudf/strings/strings_column_view.hpp>
#include <cudf/table/table_view.hpp>

#include <limits>
#include <vector>

struct StringsManyTest : public cudf::test::StringsLargeTest {};
struct ReplaceTest : public cudf::test::StringsLargeTest {};

TEST_F(StringsManyTest, Replace)
TEST_F(ReplaceTest, ReplaceLong)
{
auto const expected = this->very_long_column();
auto const view = cudf::column_view(expected);
Expand Down Expand Up @@ -65,3 +66,22 @@ TEST_F(StringsManyTest, Replace)
CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(c, expected);
}
}

TEST_F(ReplaceTest, ReplaceWide)
{
auto const expected = this->long_column();
auto const view = cudf::column_view(expected);
auto const multiplier = 10;
auto const separator = cudf::string_scalar("|");
auto const input = cudf::strings::concatenate(
cudf::table_view(std::vector<cudf::column_view>(multiplier, view)), separator);

auto const input_view = cudf::strings_column_view(input->view());
auto const target = cudf::string_scalar("3"); // fake the actual replace;
auto const repl = cudf::string_scalar("3"); // logic still builds the output
auto result = cudf::strings::replace(input_view, target, repl);

auto sv = cudf::strings_column_view(result->view());
EXPECT_EQ(sv.offsets().type(), cudf::data_type{cudf::type_id::INT64});
CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(input->view(), result->view());
}

0 comments on commit 56e8442

Please sign in to comment.