Skip to content

Commit

Permalink
Use thread_index_type in binary-ops jit kernel.cu (#17420)
Browse files Browse the repository at this point in the history
Follow on to #17354 to prevent overflow in jit kernel binary-ops.
This uses the `thread_index_type` directly since the `detail/utilities/cuda.cuh` cannot be included in the jit'd kernel source.

Authors:
  - David Wendt (https://github.com/davidwendt)

Approvers:
  - Yunsong Wang (https://github.com/PointKernel)
  - Muhammad Haseeb (https://github.com/mhaseeb123)

URL: #17420
  • Loading branch information
davidwendt authored Nov 26, 2024
1 parent d10eae7 commit e7022fb
Showing 1 changed file with 6 additions and 16 deletions.
22 changes: 6 additions & 16 deletions cpp/src/binaryop/jit/kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,10 @@ CUDF_KERNEL void kernel_v_v(cudf::size_type size,
TypeLhs* lhs_data,
TypeRhs* rhs_data)
{
int tid = threadIdx.x;
int blkid = blockIdx.x;
int blksz = blockDim.x;
int gridsz = gridDim.x;
auto const start = threadIdx.x + static_cast<cudf::thread_index_type>(blockIdx.x) * blockDim.x;
auto const step = static_cast<cudf::thread_index_type>(blockDim.x) * gridDim.x;

int start = tid + blkid * blksz;
int step = blksz * gridsz;

for (cudf::size_type i = start; i < size; i += step) {
for (auto i = start; i < size; i += step) {
out_data[i] = TypeOpe::template operate<TypeOut, TypeLhs, TypeRhs>(lhs_data[i], rhs_data[i]);
}
}
Expand All @@ -75,15 +70,10 @@ CUDF_KERNEL void kernel_v_v_with_validity(cudf::size_type size,
cudf::bitmask_type const* rhs_mask,
cudf::size_type rhs_offset)
{
int tid = threadIdx.x;
int blkid = blockIdx.x;
int blksz = blockDim.x;
int gridsz = gridDim.x;

int start = tid + blkid * blksz;
int step = blksz * gridsz;
auto const start = threadIdx.x + static_cast<cudf::thread_index_type>(blockIdx.x) * blockDim.x;
auto const step = static_cast<cudf::thread_index_type>(blockDim.x) * gridDim.x;

for (cudf::size_type i = start; i < size; i += step) {
for (auto i = start; i < size; i += step) {
bool output_valid = false;
out_data[i] = TypeOpe::template operate<TypeOut, TypeLhs, TypeRhs>(
lhs_data[i],
Expand Down

0 comments on commit e7022fb

Please sign in to comment.