diff --git a/cpp/src/merge/merge.cu b/cpp/src/merge/merge.cu index 5c54bb5661c..c0765b48205 100644 --- a/cpp/src/merge/merge.cu +++ b/cpp/src/merge/merge.cu @@ -78,11 +78,14 @@ __global__ void materialize_merged_bitmask_kernel( size_type const num_destination_rows, index_type const* const __restrict__ merged_indices) { - size_type destination_row = threadIdx.x + blockIdx.x * blockDim.x; + auto const stride = detail::grid_1d::grid_stride(); - auto active_threads = __ballot_sync(0xffff'ffffu, destination_row < num_destination_rows); + auto tid = detail::grid_1d::global_thread_id(); - while (destination_row < num_destination_rows) { + auto active_threads = __ballot_sync(0xffff'ffffu, tid < num_destination_rows); + + while (tid < num_destination_rows) { + auto const destination_row = static_cast(tid); auto const [src_side, src_row] = merged_indices[destination_row]; bool const from_left{src_side == side::LEFT}; bool source_bit_is_valid{true}; @@ -99,8 +102,8 @@ __global__ void materialize_merged_bitmask_kernel( // Only one thread writes output if (0 == threadIdx.x % warpSize) { out_validity[word_index(destination_row)] = result_mask; } - destination_row += blockDim.x * gridDim.x; - active_threads = __ballot_sync(active_threads, destination_row < num_destination_rows); + tid += stride; + active_threads = __ballot_sync(active_threads, tid < num_destination_rows); } }