From 1d7a77be153c09b007410d6dc8538705fbfd73ab Mon Sep 17 00:00:00 2001 From: Divye Gala Date: Wed, 6 Sep 2023 12:13:27 -0400 Subject: [PATCH] Use `cudf::thread_index_type` in `merge.cu` (#13972) This PR uses `cudf::thread_index_type` to avoid overflows. Authors: - Divye Gala (https://github.com/divyegala) Approvers: - Bradley Dice (https://github.com/bdice) - Vyas Ramasubramani (https://github.com/vyasr) URL: https://github.com/rapidsai/cudf/pull/13972 --- cpp/src/merge/merge.cu | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/cpp/src/merge/merge.cu b/cpp/src/merge/merge.cu index 5c54bb5661c..c0765b48205 100644 --- a/cpp/src/merge/merge.cu +++ b/cpp/src/merge/merge.cu @@ -78,11 +78,14 @@ __global__ void materialize_merged_bitmask_kernel( size_type const num_destination_rows, index_type const* const __restrict__ merged_indices) { - size_type destination_row = threadIdx.x + blockIdx.x * blockDim.x; + auto const stride = detail::grid_1d::grid_stride(); - auto active_threads = __ballot_sync(0xffff'ffffu, destination_row < num_destination_rows); + auto tid = detail::grid_1d::global_thread_id(); - while (destination_row < num_destination_rows) { + auto active_threads = __ballot_sync(0xffff'ffffu, tid < num_destination_rows); + + while (tid < num_destination_rows) { + auto const destination_row = static_cast<size_type>(tid); auto const [src_side, src_row] = merged_indices[destination_row]; bool const from_left{src_side == side::LEFT}; bool source_bit_is_valid{true}; @@ -99,8 +102,8 @@ __global__ void materialize_merged_bitmask_kernel( // Only one thread writes output if (0 == threadIdx.x % warpSize) { out_validity[word_index(destination_row)] = result_mask; } - destination_row += blockDim.x * gridDim.x; - active_threads = __ballot_sync(active_threads, destination_row < num_destination_rows); + tid += stride; + active_threads = __ballot_sync(active_threads, tid < num_destination_rows); } }