From 093bcc94ccf156a7e39339a7c4bb7e86543187de Mon Sep 17 00:00:00 2001 From: David Wendt <45795991+davidwendt@users.noreply.github.com> Date: Tue, 16 Jul 2024 20:16:07 -0400 Subject: [PATCH] Update cudf::detail::grid_1d to use thread_index_type (#16276) Updates the `cudf::detail::grid_1d` to use `thread_index_type` instead of `int` and `size_type` for the number threads and blocks. This has become important for launching kernels with more threads than max `size_type` total bytes for warp-per-row and thread-per-byte algorithms. Authors: - David Wendt (https://github.com/davidwendt) Approvers: - Bradley Dice (https://github.com/bdice) - Vyas Ramasubramani (https://github.com/vyasr) - Nghia Truong (https://github.com/ttnghia) URL: https://github.com/rapidsai/cudf/pull/16276 --- cpp/include/cudf/detail/utilities/cuda.cuh | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/cpp/include/cudf/detail/utilities/cuda.cuh b/cpp/include/cudf/detail/utilities/cuda.cuh index f1775c6d6d7..5007af7f9f1 100644 --- a/cpp/include/cudf/detail/utilities/cuda.cuh +++ b/cpp/include/cudf/detail/utilities/cuda.cuh @@ -41,8 +41,8 @@ static constexpr size_type warp_size{32}; */ class grid_1d { public: - int const num_threads_per_block; - int const num_blocks; + thread_index_type const num_threads_per_block; + thread_index_type const num_blocks; /** * @param overall_num_elements The number of elements the kernel needs to * handle/process, in its main, one-dimensional/linear input (e.g. one or more @@ -55,9 +55,9 @@ class grid_1d { * than a single element; this affects the number of threads the grid must * contain */ - grid_1d(cudf::size_type overall_num_elements, - cudf::size_type num_threads_per_block, - cudf::size_type elements_per_thread = 1) + grid_1d(thread_index_type overall_num_elements, + thread_index_type num_threads_per_block, + thread_index_type elements_per_thread = 1) : num_threads_per_block(num_threads_per_block), num_blocks(util::div_rounding_up_safe(overall_num_elements, elements_per_thread * num_threads_per_block))