Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix multi-replace target count logic for large strings #15807

Merged
merged 1 commit into from
May 31, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 37 additions & 12 deletions cpp/src/strings/replace/multi.cu
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,17 @@
#include <cudf/strings/string_view.cuh>
#include <cudf/strings/strings_column_view.hpp>
#include <cudf/utilities/default_stream.hpp>
#include <cudf/utilities/span.hpp>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_uvector.hpp>
#include <rmm/resource_ref.hpp>

#include <cuda/functional>
#include <thrust/binary_search.h>
#include <thrust/copy.h>
#include <thrust/count.h>
#include <thrust/distance.h>
#include <thrust/execution_policy.h>
#include <thrust/for_each.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/optional.h>
#include <thrust/scan.h>
#include <thrust/transform.h>

namespace cudf {
Expand Down Expand Up @@ -262,6 +256,38 @@ struct replace_multi_parallel_fn {
device_span<string_view const> d_replacements;
};

constexpr int64_t block_size = 512; // number of threads per block
constexpr size_type bytes_per_thread = 4; // bytes processed per thread

/**
* @brief Count the number of targets in a strings column
*
* @param fn Functor containing has_target() function
* @param chars_bytes Number of bytes in the strings column
* @param d_output Result of the count
*/
CUDF_KERNEL void count_targets(replace_multi_parallel_fn fn, int64_t chars_bytes, int64_t* d_output)
{
auto const idx = cudf::detail::grid_1d::global_thread_id();
auto const byte_idx = static_cast<int64_t>(idx) * bytes_per_thread;
auto const lane_idx = static_cast<cudf::size_type>(threadIdx.x);

using block_reduce = cub::BlockReduce<int64_t, block_size>;
__shared__ typename block_reduce::TempStorage temp_storage;

int64_t count = 0;
// each thread processes multiple bytes
for (auto i = byte_idx; (i < (byte_idx + bytes_per_thread)) && (i < chars_bytes); ++i) {
count += fn.has_target(i, chars_bytes);
}
auto const total = block_reduce(temp_storage).Reduce(count, cub::Sum());

if ((lane_idx == 0) && (total > 0)) {
cuda::atomic_ref<int64_t, cuda::thread_scope_block> ref{*d_output};
ref.fetch_add(total, cuda::std::memory_order_relaxed);
}
}

/**
* @brief Used by the copy-if function to produce target_pair objects
*
Expand Down Expand Up @@ -308,12 +334,11 @@ std::unique_ptr<column> replace_character_parallel(strings_column_view const& in

// Count the number of targets in the entire column.
// Note this may over-count in the case where a target spans adjacent strings.
auto target_count = thrust::count_if(
rmm::exec_policy_nosync(stream),
thrust::make_counting_iterator<int64_t>(0),
thrust::make_counting_iterator<int64_t>(chars_bytes),
[fn, chars_bytes] __device__(int64_t idx) { return fn.has_target(idx, chars_bytes); });

rmm::device_scalar<int64_t> d_count(0, stream);
auto const num_blocks = util::div_rounding_up_safe(
util::div_rounding_up_safe(chars_bytes, static_cast<int64_t>(bytes_per_thread)), block_size);
count_targets<<<num_blocks, block_size, 0, stream.value()>>>(fn, chars_bytes, d_count.data());
auto target_count = d_count.value(stream);
// Create a vector of every target position in the chars column.
// These may also include overlapping targets which will be resolved later.
auto targets_positions = rmm::device_uvector<int64_t>(target_count, stream);
Expand Down
Loading