From e7022fbc22eda538783e67f32d35ea8ea0798be8 Mon Sep 17 00:00:00 2001 From: David Wendt <45795991+davidwendt@users.noreply.github.com> Date: Tue, 26 Nov 2024 10:01:53 -0500 Subject: [PATCH] Use thread_index_type in binary-ops jit kernel.cu (#17420) Follow on to #17354 to prevent overflow in jit kernel binary-ops. This uses the `thread_index_type` directly since the `detail/utilities/cuda.cuh` cannot be included in the jit'd kernel source. Authors: - David Wendt (https://github.com/davidwendt) Approvers: - Yunsong Wang (https://github.com/PointKernel) - Muhammad Haseeb (https://github.com/mhaseeb123) URL: https://github.com/rapidsai/cudf/pull/17420 --- cpp/src/binaryop/jit/kernel.cu | 22 ++++++---------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/cpp/src/binaryop/jit/kernel.cu b/cpp/src/binaryop/jit/kernel.cu index 985fc87521c..1133e9ac22e 100644 --- a/cpp/src/binaryop/jit/kernel.cu +++ b/cpp/src/binaryop/jit/kernel.cu @@ -51,15 +51,10 @@ CUDF_KERNEL void kernel_v_v(cudf::size_type size, TypeLhs* lhs_data, TypeRhs* rhs_data) { - int tid = threadIdx.x; - int blkid = blockIdx.x; - int blksz = blockDim.x; - int gridsz = gridDim.x; + auto const start = threadIdx.x + static_cast<cudf::thread_index_type>(blockIdx.x) * blockDim.x; + auto const step = static_cast<cudf::thread_index_type>(blockDim.x) * gridDim.x; - int start = tid + blkid * blksz; - int step = blksz * gridsz; - - for (cudf::size_type i = start; i < size; i += step) { + for (auto i = start; i < size; i += step) { out_data[i] = TypeOpe::template operate<TypeOut, TypeLhs, TypeRhs>(lhs_data[i], rhs_data[i]); } } @@ -75,15 +70,10 @@ CUDF_KERNEL void kernel_v_v_with_validity(cudf::size_type size, cudf::bitmask_type const* rhs_mask, cudf::size_type rhs_offset) { - int tid = threadIdx.x; - int blkid = blockIdx.x; - int blksz = blockDim.x; - int gridsz = gridDim.x; - - int start = tid + blkid * blksz; - int step = blksz * gridsz; + auto const start = threadIdx.x + static_cast<cudf::thread_index_type>(blockIdx.x) * blockDim.x; + auto const step = static_cast<cudf::thread_index_type>(blockDim.x) * gridDim.x; - for 
(cudf::size_type i = start; i < size; i += step) { + for (auto i = start; i < size; i += step) { bool output_valid = false; out_data[i] = TypeOpe::template operate<TypeOut, TypeLhs, TypeRhs>( lhs_data[i],