Skip to content

Commit

Permalink
Fix multi-replace target count logic for large strings (#15807)
Browse files Browse the repository at this point in the history
Replaces `thrust::count_if` with raw kernel counter to handle large strings (int64 offsets) and > 2GB strings columns.

Authors:
  - David Wendt (https://github.com/davidwendt)

Approvers:
  - Bradley Dice (https://github.com/bdice)
  - Nghia Truong (https://github.com/ttnghia)

URL: #15807
  • Loading branch information
davidwendt authored May 31, 2024
1 parent 476db9f commit dec0354
Showing 1 changed file with 37 additions and 12 deletions.
49 changes: 37 additions & 12 deletions cpp/src/strings/replace/multi.cu
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,17 @@
#include <cudf/strings/string_view.cuh>
#include <cudf/strings/strings_column_view.hpp>
#include <cudf/utilities/default_stream.hpp>
#include <cudf/utilities/span.hpp>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_uvector.hpp>
#include <rmm/resource_ref.hpp>

#include <cuda/functional>
#include <thrust/binary_search.h>
#include <thrust/copy.h>
#include <thrust/count.h>
#include <thrust/distance.h>
#include <thrust/execution_policy.h>
#include <thrust/for_each.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/optional.h>
#include <thrust/scan.h>
#include <thrust/transform.h>

namespace cudf {
Expand Down Expand Up @@ -262,6 +256,38 @@ struct replace_multi_parallel_fn {
device_span<string_view const> d_replacements;
};

constexpr int64_t block_size = 512; // number of threads per block
constexpr size_type bytes_per_thread = 4; // bytes processed per thread

/**
* @brief Count the number of targets in a strings column
*
* @param fn Functor containing has_target() function
* @param chars_bytes Number of bytes in the strings column
* @param d_output Result of the count
*/
CUDF_KERNEL void count_targets(replace_multi_parallel_fn fn, int64_t chars_bytes, int64_t* d_output)
{
auto const idx = cudf::detail::grid_1d::global_thread_id();
auto const byte_idx = static_cast<int64_t>(idx) * bytes_per_thread;
auto const lane_idx = static_cast<cudf::size_type>(threadIdx.x);

using block_reduce = cub::BlockReduce<int64_t, block_size>;
__shared__ typename block_reduce::TempStorage temp_storage;

int64_t count = 0;
// each thread processes multiple bytes
for (auto i = byte_idx; (i < (byte_idx + bytes_per_thread)) && (i < chars_bytes); ++i) {
count += fn.has_target(i, chars_bytes);
}
auto const total = block_reduce(temp_storage).Reduce(count, cub::Sum());

if ((lane_idx == 0) && (total > 0)) {
cuda::atomic_ref<int64_t, cuda::thread_scope_block> ref{*d_output};
ref.fetch_add(total, cuda::std::memory_order_relaxed);
}
}

/**
* @brief Used by the copy-if function to produce target_pair objects
*
Expand Down Expand Up @@ -308,12 +334,11 @@ std::unique_ptr<column> replace_character_parallel(strings_column_view const& in

// Count the number of targets in the entire column.
// Note this may over-count in the case where a target spans adjacent strings.
auto target_count = thrust::count_if(
rmm::exec_policy_nosync(stream),
thrust::make_counting_iterator<int64_t>(0),
thrust::make_counting_iterator<int64_t>(chars_bytes),
[fn, chars_bytes] __device__(int64_t idx) { return fn.has_target(idx, chars_bytes); });

rmm::device_scalar<int64_t> d_count(0, stream);
auto const num_blocks = util::div_rounding_up_safe(
util::div_rounding_up_safe(chars_bytes, static_cast<int64_t>(bytes_per_thread)), block_size);
count_targets<<<num_blocks, block_size, 0, stream.value()>>>(fn, chars_bytes, d_count.data());
auto target_count = d_count.value(stream);
// Create a vector of every target position in the chars column.
// These may also include overlapping targets which will be resolved later.
auto targets_positions = rmm::device_uvector<int64_t>(target_count, stream);
Expand Down

0 comments on commit dec0354

Please sign in to comment.