From 8e412b4b20f140e4f86ffbe7544df084b1a5731e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Akif=20=C3=87=C3=96RD=C3=9CK?=
Date: Thu, 18 May 2023 03:20:12 +0200
Subject: [PATCH] Add generic reduction functions and separate reductions/warp_primitives (#1470)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This PR adds a bunch of new device reduction functions, such as:
- Generic device reductions that take a reduction operator as an argument.
- Ranked reductions that return the index/rank of the reduced value.
- Weighted random reduction, a probabilistic reduction that uses conditional probability.
- Binary reduction to reduce binary values more efficiently.

Tests are implemented for all device reduction operations.

This PR also separates the warp primitives out into `warp_primitives.cuh`.
All reduction functions are moved to `reduction.cuh`.

Authors:
  - Akif ÇÖRDÜK (https://github.com/akifcorduk)
  - Corey J. Nolet (https://github.com/cjnolet)
  - Tamas Bela Feher (https://github.com/tfeher)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)

URL: https://github.com/rapidsai/raft/pull/1470
---
 cpp/include/raft/random/device/sample.cuh     | 104 ++++++
 cpp/include/raft/util/cuda_utils.cuh          | 312 +-----------------
 cpp/include/raft/util/device_loads_stores.cuh |  57 ++++
 cpp/include/raft/util/pow2_utils.cuh          |  11 +-
 cpp/include/raft/util/reduction.cuh           | 202 ++++++++++++
 cpp/include/raft/util/warp_primitives.cuh     | 259 +++++++++++++++
 cpp/test/CMakeLists.txt                       |  12 +-
 cpp/test/util/reduction.cu                    | 196 +++++++++++
 8 files changed, 841 insertions(+), 312 deletions(-)
 create mode 100644 cpp/include/raft/random/device/sample.cuh
 create mode 100644 cpp/include/raft/util/reduction.cuh
 create mode 100644 cpp/include/raft/util/warp_primitives.cuh
 create mode 100644 cpp/test/util/reduction.cu

diff --git a/cpp/include/raft/random/device/sample.cuh b/cpp/include/raft/random/device/sample.cuh
new file mode 100644
index 0000000000..f08db3e0a2
--- /dev/null
+++ b/cpp/include/raft/random/device/sample.cuh
@@ -0,0 +1,104 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include
+
+#include
+#include
+#include
+#include
+
+namespace raft::random::device {
+
+/**
+ * @brief Warp-level random sampling of an index.
+ * It selects an index with the given discrete probability
+ * distribution (represented by the weights of each index).
+ * @param rng random number generator; must have a next_u32() function
+ * @param weight weight of the rank/index
+ * @param idx index to be used as rank
+ * @note only thread 0 of the warp will contain the valid reduced result
+ */
+template <typename rng_t, typename T, typename i_t>
+DI void warp_random_sample(rng_t& rng, T& weight, i_t& idx)
+{
+  // Todo(#1491): benchmark whether a scan and then selecting within the ranges is more efficient.
+  static_assert(std::is_integral<T>::value, "The type T must be an integral type.");
+#pragma unroll
+  for (i_t offset = raft::WarpSize / 2; offset > 0; offset /= 2) {
+    T tmp_weight = shfl(weight, laneId() + offset);
+    i_t tmp_idx  = shfl(idx, laneId() + offset);
+    T sum        = (tmp_weight + weight);
+    weight       = sum;
+    if (sum != 0) {
+      i_t rnd_number = (rng.next_u32() % sum);
+      if (rnd_number < tmp_weight) { idx = tmp_idx; }
+    }
+  }
+}
+
+/**
+ * @brief 1-D block-level random sampling of an index.
+ * It selects an index with the given discrete probability
+ * distribution (represented by the weights of each index).
+ *
+ * Let w_i be the weight stored on thread i. We calculate the cumulative distribution function
+ * F_i = sum_{k=0..i} w_k.
+ * Sequentially, we could select one of the elements with the desired probability using the
+ * following method. We can consider that each element has a subinterval assigned: [F_{i-1}, F_i).
+ * We generate a uniform random number in the [0, F_i) range and check which subinterval it falls
+ * into. We return the idx corresponding to the selected subinterval.
+ * In parallel, we do a tree reduction and make a selection at every step when we combine two
+ * values.
+ * @param rng random number generator; must have a next_u32() function
+ * @param shbuf shared memory region needed for storing intermediate results. It
+ *              must be at least of size: `(sizeof(T) + sizeof(i_t)) * WarpSize`
+ * @param weight weight of the rank/index
+ * @param idx index to be used as rank
+ * @return the sampled index; only thread 0 will contain the valid reduced result
+ */
+template <typename rng_t, typename T, typename i_t>
+DI i_t block_random_sample(rng_t rng, T* shbuf, T weight = 1, i_t idx = threadIdx.x)
+{
+  T* values    = shbuf;
+  i_t* indices = (i_t*)&shbuf[WarpSize];
+  i_t wid      = threadIdx.x / WarpSize;
+  i_t nWarps   = (blockDim.x + WarpSize - 1) / WarpSize;
+  warp_random_sample(rng, weight, idx);  // Each warp performs partial reduction
+  i_t lane = laneId();
+  if (lane == 0) {
+    values[wid]  = weight;  // Write reduced weight to shared memory
+    indices[wid] = idx;     // Write reduced index to shared memory
+  }
+
+  __syncthreads();  // Wait for all partial reductions
+
+  // read from shared memory only if that warp existed
+  if (lane < nWarps) {
+    weight = values[lane];
+    idx    = indices[lane];
+  } else {
+    weight = 0;
+    idx    = -1;
+  }
+  __syncthreads();
+  if (wid == 0) warp_random_sample(rng, weight, idx);
+  return idx;
+}
+
+}  // namespace raft::random::device
\ No newline at end of file
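[Editor's usage sketch, not part of the patch: the kernel name and launch shape are illustrative, the `PCGenerator` seeding mirrors the test file at the end of this PR, and the include paths are assumptions based on the file paths in the diff headers.]

```cpp
#include <raft/random/device/sample.cuh>
#include <raft/random/rng_device.cuh>  // assumed location of raft::random::PCGenerator

// Each thread contributes one weighted candidate index; after the call,
// thread 0 holds an index i drawn with probability weights[i] / sum(weights).
__global__ void sample_index_kernel(const int* weights, int* out)
{
  // shbuf must be at least (sizeof(T) + sizeof(i_t)) * WarpSize bytes:
  // here, WarpSize ints for the weights plus WarpSize ints for the indices.
  __shared__ int shbuf[2 * raft::WarpSize];
  raft::random::PCGenerator rng(1234, threadIdx.x, 0);  // per-thread subsequence
  int weight  = weights[threadIdx.x];
  int idx     = threadIdx.x;
  int sampled = raft::random::device::block_random_sample(rng, shbuf, weight, idx);
  if (threadIdx.x == 0) { *out = sampled; }
}
```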
diff --git a/cpp/include/raft/util/cuda_utils.cuh b/cpp/include/raft/util/cuda_utils.cuh
index 687a6b4651..0523dcc81c 100644
--- a/cpp/include/raft/util/cuda_utils.cuh
+++ b/cpp/include/raft/util/cuda_utils.cuh
@@ -23,7 +23,10 @@
 #include
 #include
 #include
+// For backward compatibility, we include the following headers. They contain
+// functionality that was previously contained in cuda_utils.cuh
 #include
+#include
 
 namespace raft {
 
@@ -523,238 +526,6 @@ DI double maxPrim(double x, double y)
 }
 /** @} */
 
-/** apply a warp-wide fence (useful from Volta+ archs) */
-DI void warpFence()
-{
-#if __CUDA_ARCH__ >= 700
-  __syncwarp();
-#endif
-}
-
-/** warp-wide any boolean aggregator */
-DI bool any(bool inFlag, uint32_t mask = 0xffffffffu)
-{
-#if CUDART_VERSION >= 9000
-  inFlag = __any_sync(mask, inFlag);
-#else
-  inFlag = __any(inFlag);
-#endif
-  return inFlag;
-}
-
-/** warp-wide all boolean aggregator */
-DI bool all(bool inFlag, uint32_t mask = 0xffffffffu)
-{
-#if CUDART_VERSION >= 9000
-  inFlag = __all_sync(mask, inFlag);
-#else
-  inFlag = __all(inFlag);
-#endif
-  return inFlag;
-}
-
-/** For every thread in the warp, set the corresponding bit to the thread's flag value. */
-DI uint32_t ballot(bool inFlag, uint32_t mask = 0xffffffffu)
-{
-#if CUDART_VERSION >= 9000
-  return __ballot_sync(mask, inFlag);
-#else
-  return __ballot(inFlag);
-#endif
-}
-
-/** True CUDA alignment of a type (adapted from CUB) */
-template <typename T>
-struct cuda_alignment {
-  struct Pad {
-    T val;
-    char byte;
-  };
-
-  static constexpr int bytes = sizeof(Pad) - sizeof(T);
-};
-
-template <typename LargeT, typename UnitT>
-struct is_multiple {
-  static constexpr int large_align_bytes = cuda_alignment<LargeT>::bytes;
-  static constexpr int unit_align_bytes  = cuda_alignment<UnitT>::bytes;
-  static constexpr bool value =
-    (sizeof(LargeT) % sizeof(UnitT) == 0) && (large_align_bytes % unit_align_bytes == 0);
-};
-
-template <typename LargeT, typename UnitT>
-inline constexpr bool is_multiple_v = is_multiple<LargeT, UnitT>::value;
-
-template <typename T>
-struct is_shuffleable {
-  static constexpr bool value =
-    std::is_same_v<T, int> || std::is_same_v<T, unsigned int> || std::is_same_v<T, long> ||
-    std::is_same_v<T, unsigned long> || std::is_same_v<T, long long> ||
-    std::is_same_v<T, unsigned long long> || std::is_same_v<T, float> || std::is_same_v<T, double>;
-};
-
-template <typename T>
-inline constexpr bool is_shuffleable_v = is_shuffleable<T>::value;
-
-/**
- * @brief Shuffle the data inside a warp
- * @tparam T the data type
- * @param val value to be shuffled
- * @param srcLane lane from where to shuffle
- * @param width lane width
- * @param mask mask of participating threads (Volta+)
- * @return the shuffled data
- */
-template <typename T>
-DI std::enable_if_t<is_shuffleable_v<T>, T> shfl(T val,
-                                                 int srcLane,
-                                                 int width     = WarpSize,
-                                                 uint32_t mask = 0xffffffffu)
-{
-#if CUDART_VERSION >= 9000
-  return __shfl_sync(mask, val, srcLane, width);
-#else
-  return __shfl(val, srcLane, width);
-#endif
-}
-
-/// Overload of shfl for data types not supported by the CUDA intrinsics
-template <typename T>
-DI std::enable_if_t<!is_shuffleable_v<T>, T> shfl(T val,
-                                                  int srcLane,
-                                                  int width     = WarpSize,
-                                                  uint32_t mask = 0xffffffffu)
-{
-  using UnitT = std::conditional_t<
-    is_multiple_v<T, unsigned int>,
-    unsigned int,
-    std::conditional_t<is_multiple_v<T, unsigned short>, unsigned short, unsigned char>>;
-
-  constexpr int n_words = sizeof(T) / sizeof(UnitT);
-
-  T output;
-  UnitT* output_alias = reinterpret_cast<UnitT*>(&output);
-  UnitT* input_alias  = reinterpret_cast<UnitT*>(&val);
-
-  unsigned int shuffle_word;
-  shuffle_word    = shfl((unsigned int)input_alias[0], srcLane, width, mask);
-  output_alias[0] = shuffle_word;
-
-#pragma unroll
-  for (int i = 1; i < n_words; ++i) {
-    shuffle_word    = shfl((unsigned int)input_alias[i], srcLane, width, mask);
-    output_alias[i] = shuffle_word;
-  }
-
-  return output;
-}
-
-/**
- * @brief Shuffle the data inside a warp from lower lane IDs
- * @tparam T the data type
- * @param val value to be shuffled
- * @param delta lower lane ID delta from where to shuffle
- * @param width lane width
- * @param mask mask of participating threads (Volta+)
- * @return the shuffled data
- */
-template <typename T>
-DI std::enable_if_t<is_shuffleable_v<T>, T> shfl_up(T val,
-                                                    int delta,
-                                                    int width     = WarpSize,
-                                                    uint32_t mask = 0xffffffffu)
-{
-#if CUDART_VERSION >= 9000
-  return __shfl_up_sync(mask, val, delta, width);
-#else
-  return __shfl_up(val, delta, width);
-#endif
-}
-
-/// Overload of shfl_up for data types not supported by the CUDA intrinsics
-template <typename T>
-DI std::enable_if_t<!is_shuffleable_v<T>, T> shfl_up(T val,
-                                                     int delta,
-                                                     int width     = WarpSize,
-                                                     uint32_t mask = 0xffffffffu)
-{
-  using UnitT = std::conditional_t<
-    is_multiple_v<T, unsigned int>,
-    unsigned int,
-    std::conditional_t<is_multiple_v<T, unsigned short>, unsigned short, unsigned char>>;
-
-  constexpr int n_words = sizeof(T) / sizeof(UnitT);
-
-  T output;
-  UnitT* output_alias = reinterpret_cast<UnitT*>(&output);
-  UnitT* input_alias  = reinterpret_cast<UnitT*>(&val);
-
-  unsigned int shuffle_word;
-  shuffle_word    = shfl_up((unsigned int)input_alias[0], delta, width, mask);
-  output_alias[0] = shuffle_word;
-
-#pragma unroll
-  for (int i = 1; i < n_words; ++i) {
-    shuffle_word    = shfl_up((unsigned int)input_alias[i], delta, width, mask);
-    output_alias[i] = shuffle_word;
-  }
-
-  return output;
-}
-
-/**
- * @brief Shuffle the data inside a warp
- * @tparam T the data type
- * @param val value to be shuffled
- * @param laneMask mask to be applied in order to perform xor shuffle
- * @param width lane width
- * @param mask mask of participating threads (Volta+)
- * @return the shuffled data
- */
-template <typename T>
-DI std::enable_if_t<is_shuffleable_v<T>, T> shfl_xor(T val,
-                                                     int laneMask,
-                                                     int width     = WarpSize,
-                                                     uint32_t mask = 0xffffffffu)
-{
-#if CUDART_VERSION >= 9000
-  return __shfl_xor_sync(mask, val, laneMask, width);
-#else
-  return __shfl_xor(val, laneMask, width);
-#endif
-}
-
-/// Overload of shfl_xor for data types not supported by the CUDA intrinsics
-template <typename T>
-DI std::enable_if_t<!is_shuffleable_v<T>, T> shfl_xor(T val,
-                                                      int laneMask,
-                                                      int width     = WarpSize,
-                                                      uint32_t mask = 0xffffffffu)
-{
-  using UnitT = std::conditional_t<
-    is_multiple_v<T, unsigned int>,
-    unsigned int,
-    std::conditional_t<is_multiple_v<T, unsigned short>, unsigned short, unsigned char>>;
-
-  constexpr int n_words = sizeof(T) / sizeof(UnitT);
-
-  T output;
-  UnitT* output_alias = reinterpret_cast<UnitT*>(&output);
-  UnitT* input_alias  = reinterpret_cast<UnitT*>(&val);
-
-  unsigned int shuffle_word;
-  shuffle_word    = shfl_xor((unsigned int)input_alias[0], laneMask, width, mask);
-  output_alias[0] = shuffle_word;
-
-#pragma unroll
-  for (int i = 1; i < n_words; ++i) {
-    shuffle_word    = shfl_xor((unsigned int)input_alias[i], laneMask, width, mask);
-    output_alias[i] = shuffle_word;
-  }
-
-  return output;
-}
-
 /**
  * @brief Four-way byte dot product-accumulate.
  * @tparam T Four-byte integer: int or unsigned int
@@ -816,83 +587,6 @@ DI auto dp4a(unsigned int a, unsigned int b, unsigned int c) -> unsigned int
 #endif
 }
 
-/**
- * @brief Logical-warp-level reduction
- * @tparam logicalWarpSize Logical warp size (2, 4, 8, 16 or 32)
- * @tparam T Value type to be reduced
- * @tparam ReduceLambda Reduction operation type
- * @param val input value
- * @param reduce_op Reduction operation
- * @return Reduction result. All lanes will have the valid result.
- */
-template <int logicalWarpSize, typename T, typename ReduceLambda>
-DI T logicalWarpReduce(T val, ReduceLambda reduce_op)
-{
-#pragma unroll
-  for (int i = logicalWarpSize / 2; i > 0; i >>= 1) {
-    T tmp = shfl_xor(val, i);
-    val   = reduce_op(val, tmp);
-  }
-  return val;
-}
-
-/**
- * @brief Warp-level reduction
- * @tparam T Value type to be reduced
- * @tparam ReduceLambda Reduction operation type
- * @param val input value
- * @param reduce_op Reduction operation
- * @return Reduction result. All lanes will have the valid result.
- * @note Why not cub? Because cub doesn't seem to allow working with arbitrary
- *       number of warps in a block. All threads in the warp must enter this
- *       function together
- */
-template <typename T, typename ReduceLambda>
-DI T warpReduce(T val, ReduceLambda reduce_op)
-{
-  return logicalWarpReduce<WarpSize>(val, reduce_op);
-}
-
-/**
- * @brief Warp-level sum reduction
- * @tparam T Value type to be reduced
- * @param val input value
- * @return Reduction result. All lanes will have the valid result.
- * @note Why not cub? Because cub doesn't seem to allow working with arbitrary
- *       number of warps in a block. All threads in the warp must enter this
- *       function together
- */
-template <typename T>
-DI T warpReduce(T val)
-{
-  return warpReduce(val, raft::add_op{});
-}
-
-/**
- * @brief 1-D block-level sum reduction
- * @param val input value
- * @param smem shared memory region needed for storing intermediate results. It
- *             must alteast be of size: `sizeof(T) * nWarps`
- * @return only the thread0 will contain valid reduced result
- * @note Why not cub? Because cub doesn't seem to allow working with arbitrary
- *       number of warps in a block. All threads in the block must enter this
- *       function together
- * @todo Expand this to support arbitrary reduction ops
- */
-template <typename T>
-DI T blockReduce(T val, char* smem)
-{
-  auto* sTemp = reinterpret_cast<T*>(smem);
-  int nWarps  = (blockDim.x + WarpSize - 1) / WarpSize;
-  int lid     = laneId();
-  int wid     = threadIdx.x / WarpSize;
-  val         = warpReduce(val);
-  if (lid == 0) sTemp[wid] = val;
-  __syncthreads();
-  val = lid < nWarps ? sTemp[lid] : T(0);
-  return warpReduce(val);
-}
-
 /**
  * @brief Simple utility function to determine whether user_stream or one of the
  * internal streams should be used.
diff --git a/cpp/include/raft/util/device_loads_stores.cuh b/cpp/include/raft/util/device_loads_stores.cuh
index c9bda26b81..e3d54c51f5 100644
--- a/cpp/include/raft/util/device_loads_stores.cuh
+++ b/cpp/include/raft/util/device_loads_stores.cuh
@@ -17,6 +17,7 @@
 #pragma once
 
 #include   // uintX_t
+#include
 #include   // DI
 
 namespace raft {
@@ -534,6 +535,62 @@ DI void ldg(int8_t (&x)[1], const int8_t* const& addr)
   x[0] = x_int;
 }
 
+/**
+ * @brief Executes a 1D block strided copy
+ * @param dst destination pointer
+ * @param src source pointer
+ * @param size number of items to copy
+ */
+template <typename T>
+DI void block_copy(T* dst, const T* src, const size_t size)
+{
+  for (auto i = threadIdx.x; i < size; i += blockDim.x) {
+    dst[i] = src[i];
+  }
+}
+
+/**
+ * @brief Executes a 1D block strided copy
+ * @param dst span of destination pointer
+ * @param src span of source pointer
+ * @param size number of items to copy
+ */
+template <typename T>
+DI void block_copy(raft::device_span<T> dst,
+                   const raft::device_span<const T> src,
+                   const size_t size)
+{
+  assert(src.size() >= size);
+  assert(dst.size() >= size);
+  block_copy(dst.data(), src.data(), size);
+}
+
+/**
+ * @brief Executes a 1D block strided copy
+ * @param dst span of destination pointer
+ * @param src span of source pointer
+ * @param size number of items to copy
+ */
+template <typename T>
+DI void block_copy(raft::device_span<T> dst, const raft::device_span<T> src, const size_t size)
+{
+  assert(src.size() >= size);
+  assert(dst.size() >= size);
+  block_copy(dst.data(), src.data(), size);
+}
+
+/**
+ * @brief Executes a 1D block strided copy
+ * @param dst span of destination pointer
+ * @param src span of source pointer
+ */
+template <typename T>
+DI void block_copy(raft::device_span<T> dst, const raft::device_span<T> src)
+{
+  assert(dst.size() >= src.size());
+  block_copy(dst, src, src.size());
+}
+
 /** @} */
 
 }  // namespace raft
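[Editor's usage sketch for the new `block_copy` helpers, not from the PR; the kernel name is illustrative and the `device_span` header path is an assumption.]

```cpp
#include <raft/core/device_span.hpp>  // assumed location of raft::device_span
#include <raft/util/device_loads_stores.cuh>

// All threads of a block cooperate on one strided copy; thread i copies
// elements i, i + blockDim.x, i + 2 * blockDim.x, ...
__global__ void copy_kernel(raft::device_span<float> dst, raft::device_span<float> src)
{
  raft::block_copy(dst, src);  // copies src.size() items; dst must be large enough
  __syncthreads();             // make the copied data visible to the whole block
}
```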
diff --git a/cpp/include/raft/util/pow2_utils.cuh b/cpp/include/raft/util/pow2_utils.cuh
index 3b42682816..68b35837b6 100644
--- a/cpp/include/raft/util/pow2_utils.cuh
+++ b/cpp/include/raft/util/pow2_utils.cuh
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2022, NVIDIA CORPORATION.
+ * Copyright (c) 2021-2023, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -81,6 +81,15 @@ struct Pow2 {
     return x >> I(Log2);
   }
 
+  /**
+   * Rounds up the value to the next power of two.
+   */
+  template <typename I>
+  Pow2_FUNC_QUALIFIER Pow2_WHEN_INTEGRAL(I) round_up_pow2(I val) noexcept
+  {
+    return 1 << (log2(val) + 1);
+  }
+
   /**
    * x modulo Value operation (remainder of the `div(x)`)
    * (same as `x % Value` in Python).
diff --git a/cpp/include/raft/util/reduction.cuh b/cpp/include/raft/util/reduction.cuh
new file mode 100644
index 0000000000..74c57b4ca2
--- /dev/null
+++ b/cpp/include/raft/util/reduction.cuh
@@ -0,0 +1,202 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include
+
+#include
+#include
+#include
+#include
+
+namespace raft {
+
+/**
+ * @brief Logical-warp-level reduction
+ * @tparam logicalWarpSize Logical warp size (2, 4, 8, 16 or 32)
+ * @tparam T Value type to be reduced
+ * @tparam ReduceLambda Reduction operation type
+ * @param val input value
+ * @param reduce_op Reduction operation
+ * @return Reduction result. All lanes will have the valid result.
+ */
+template <int logicalWarpSize, typename T, typename ReduceLambda>
+DI T logicalWarpReduce(T val, ReduceLambda reduce_op)
+{
+#pragma unroll
+  for (int i = logicalWarpSize / 2; i > 0; i >>= 1) {
+    T tmp = shfl_xor(val, i);
+    val   = reduce_op(val, tmp);
+  }
+  return val;
+}
+
+/**
+ * @brief Warp-level reduction
+ * @tparam T Value type to be reduced
+ * @tparam ReduceLambda Reduction operation type
+ * @param val input value
+ * @param reduce_op Reduction operation
+ * @return Reduction result. All lanes will have the valid result.
+ * @note Why not cub? Because cub doesn't seem to allow working with arbitrary
+ *       number of warps in a block. All threads in the warp must enter this
+ *       function together
+ */
+template <typename T, typename ReduceLambda>
+DI T warpReduce(T val, ReduceLambda reduce_op)
+{
+  return logicalWarpReduce<WarpSize>(val, reduce_op);
+}
+
+/**
+ * @brief Warp-level sum reduction
+ * @tparam T Value type to be reduced
+ * @param val input value
+ * @return Reduction result. All lanes will have the valid result.
+ * @note Why not cub? Because cub doesn't seem to allow working with arbitrary
+ *       number of warps in a block. All threads in the warp must enter this
+ *       function together
+ */
+template <typename T>
+DI T warpReduce(T val)
+{
+  return warpReduce(val, raft::add_op{});
+}
+
+/**
+ * @brief 1-D block-level reduction
+ * @param val input value
+ * @param smem shared memory region needed for storing intermediate results. It
+ *             must be at least of size: `sizeof(T) * nWarps`
+ * @param reduce_op a binary reduction operation
+ * @return only thread 0 will contain the valid reduced result
+ * @note Why not cub? Because cub doesn't seem to allow working with arbitrary
+ *       number of warps in a block. All threads in the block must enter this
+ *       function together. cub also uses too many registers
+ */
+template <typename T, typename ReduceLambda = raft::add_op>
+DI T blockReduce(T val, char* smem, ReduceLambda reduce_op = raft::add_op{})
+{
+  auto* sTemp = reinterpret_cast<T*>(smem);
+  int nWarps  = (blockDim.x + WarpSize - 1) / WarpSize;
+  int lid     = laneId();
+  int wid     = threadIdx.x / WarpSize;
+  val         = warpReduce(val, reduce_op);
+  if (lid == 0) sTemp[wid] = val;
+  __syncthreads();
+  val = lid < nWarps ? sTemp[lid] : T(0);
+  return warpReduce(val, reduce_op);
+}
+
+/**
+ * @brief 1-D warp-level ranked reduction which returns the value and rank.
+ * Thread 0 will have the valid result and rank (idx).
+ * @param val input value
+ * @param idx index to be used as rank
+ * @param reduce_op a binary reduction operation
+ * @note only thread 0 will contain the valid reduced result
+ */
+template <typename T, typename i_t, typename ReduceLambda = raft::min_op>
+DI void warpRankedReduce(T& val, i_t& idx, ReduceLambda reduce_op = raft::min_op{})
+{
+#pragma unroll
+  for (i_t offset = WarpSize / 2; offset > 0; offset /= 2) {
+    T tmpVal   = shfl(val, laneId() + offset);
+    i_t tmpIdx = shfl(idx, laneId() + offset);
+    if (reduce_op(tmpVal, val) == tmpVal) {
+      val = tmpVal;
+      idx = tmpIdx;
+    }
+  }
+}
+
+/**
+ * @brief 1-D block-level ranked reduction which returns the value and rank.
+ * Thread 0 will have the valid result and rank (idx).
+ * @param val input value
+ * @param shbuf shared memory region needed for storing intermediate results. It
+ *              must be at least of size: `(sizeof(T) + sizeof(i_t)) * WarpSize`
+ * @param idx index to be used as rank
+ * @param reduce_op binary min or max operation
+ * @return only thread 0 will contain the valid reduced result
+ */
+template <typename T, typename i_t, typename ReduceLambda = raft::min_op>
+DI std::pair<T, i_t> blockRankedReduce(T val,
+                                       T* shbuf,
+                                       i_t idx                = threadIdx.x,
+                                       ReduceLambda reduce_op = raft::min_op{})
+{
+  T* values    = shbuf;
+  i_t* indices = (i_t*)&shbuf[WarpSize];
+  i_t wid      = threadIdx.x / WarpSize;
+  i_t nWarps   = (blockDim.x + WarpSize - 1) / WarpSize;
+  warpRankedReduce(val, idx, reduce_op);  // Each warp performs partial reduction
+  i_t lane = laneId();
+  if (lane == 0) {
+    values[wid]  = val;  // Write reduced value to shared memory
+    indices[wid] = idx;  // Write reduced index to shared memory
+  }
+
+  __syncthreads();  // Wait for all partial reductions
+
+  // read from shared memory only if that warp existed
+  if (lane < nWarps) {
+    val = values[lane];
+    idx = indices[lane];
+  } else {
+    // get the min if it is a max op, get the max if it is a min op
+    val = reduce_op(std::numeric_limits<T>::min(), std::numeric_limits<T>::max()) ==
+              std::numeric_limits<T>::min()
+            ? std::numeric_limits<T>::max()
+            : std::numeric_limits<T>::min();
+    idx = -1;
+  }
+  __syncthreads();
+  if (wid == 0) warpRankedReduce(val, idx, reduce_op);
+  return std::pair<T, i_t>{val, idx};
+}
+
+/**
+ * @brief Executes a 1D binary block reduce
+ * @param val binary value to be reduced across the thread block
+ * @param shmem memory needed for the reduction. It should have at least
+ *              `blockDim.x / WarpSize` elements
+ * @return only thread 0 will contain the valid reduced result
+ */
+template <int BLOCK_SIZE, typename i_t>
+DI i_t binaryBlockReduce(i_t val, i_t* shmem)
+{
+  static_assert(BLOCK_SIZE <= 1024);
+  assert(val == 0 || val == 1);
+  const uint32_t mask    = __ballot_sync(~0, val);
+  const uint32_t n_items = __popc(mask);
+
+  // Each first thread of the warp
+  if (threadIdx.x % WarpSize == 0) { shmem[threadIdx.x / WarpSize] = n_items; }
+  __syncthreads();
+
+  val = (threadIdx.x < BLOCK_SIZE / WarpSize) ? shmem[threadIdx.x] : 0;
+
+  if (threadIdx.x < WarpSize) {
+    return warpReduce(val);
+  }
+  // Only the first warp gets the result
+  else {
+    return -1;
+  }
+}
+
+}  // namespace raft
\ No newline at end of file
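[Editor's usage sketch for the generic and ranked block reductions, not from the PR; the kernel name is illustrative and the shared-memory sizing follows the doc comments above.]

```cpp
#include <raft/core/operators.hpp>  // assumed location of raft::max_op
#include <raft/util/reduction.cuh>

// Block-wide max and argmax of one value per thread; results are valid on thread 0.
__global__ void max_and_argmax_kernel(const float* in, float* out_val, int* out_idx)
{
  // blockReduce needs sizeof(T) * nWarps bytes; blockRankedReduce needs
  // (sizeof(T) + sizeof(i_t)) * WarpSize bytes. Size the buffer for the latter.
  __shared__ float shbuf[2 * raft::WarpSize];
  float v = in[threadIdx.x];

  float block_max = raft::blockReduce(v, (char*)shbuf, raft::max_op{});
  __syncthreads();  // the buffer is reused by the ranked reduction below

  auto [best_val, best_idx] = raft::blockRankedReduce(v, shbuf, (int)threadIdx.x, raft::max_op{});
  if (threadIdx.x == 0) {
    *out_val = block_max;  // equals best_val on thread 0
    *out_idx = best_idx;
  }
}
```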
diff --git a/cpp/include/raft/util/warp_primitives.cuh b/cpp/include/raft/util/warp_primitives.cuh
new file mode 100644
index 0000000000..94fddbe0f3
--- /dev/null
+++ b/cpp/include/raft/util/warp_primitives.cuh
@@ -0,0 +1,259 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include
+
+#include
+#include
+#include
+
+namespace raft {
+
+/** True CUDA alignment of a type (adapted from CUB) */
+template <typename T>
+struct cuda_alignment {
+  struct Pad {
+    T val;
+    char byte;
+  };
+
+  static constexpr int bytes = sizeof(Pad) - sizeof(T);
+};
+
+template <typename LargeT, typename UnitT>
+struct is_multiple {
+  static constexpr int large_align_bytes = cuda_alignment<LargeT>::bytes;
+  static constexpr int unit_align_bytes  = cuda_alignment<UnitT>::bytes;
+  static constexpr bool value =
+    (sizeof(LargeT) % sizeof(UnitT) == 0) && (large_align_bytes % unit_align_bytes == 0);
+};
+
+template <typename LargeT, typename UnitT>
+inline constexpr bool is_multiple_v = is_multiple<LargeT, UnitT>::value;
+
+/** apply a warp-wide fence (useful from Volta+ archs) */
+DI void warpFence()
+{
+#if __CUDA_ARCH__ >= 700
+  __syncwarp();
+#endif
+}
+
+/** warp-wide any boolean aggregator */
+DI bool any(bool inFlag, uint32_t mask = 0xffffffffu)
+{
+#if CUDART_VERSION >= 9000
+  inFlag = __any_sync(mask, inFlag);
+#else
+  inFlag = __any(inFlag);
+#endif
+  return inFlag;
+}
+
+/** warp-wide all boolean aggregator */
+DI bool all(bool inFlag, uint32_t mask = 0xffffffffu)
+{
+#if CUDART_VERSION >= 9000
+  inFlag = __all_sync(mask, inFlag);
+#else
+  inFlag = __all(inFlag);
+#endif
+  return inFlag;
+}
+
+/** For every thread in the warp, set the corresponding bit to the thread's flag value. */
+DI uint32_t ballot(bool inFlag, uint32_t mask = 0xffffffffu)
+{
+#if CUDART_VERSION >= 9000
+  return __ballot_sync(mask, inFlag);
+#else
+  return __ballot(inFlag);
+#endif
+}
+
+template <typename T>
+struct is_shuffleable {
+  static constexpr bool value =
+    std::is_same_v<T, int> || std::is_same_v<T, unsigned int> || std::is_same_v<T, long> ||
+    std::is_same_v<T, unsigned long> || std::is_same_v<T, long long> ||
+    std::is_same_v<T, unsigned long long> || std::is_same_v<T, float> || std::is_same_v<T, double>;
+};
+
+template <typename T>
+inline constexpr bool is_shuffleable_v = is_shuffleable<T>::value;
+
+/**
+ * @brief Shuffle the data inside a warp
+ * @tparam T the data type
+ * @param val value to be shuffled
+ * @param srcLane lane from where to shuffle
+ * @param width lane width
+ * @param mask mask of participating threads (Volta+)
+ * @return the shuffled data
+ */
+template <typename T>
+DI std::enable_if_t<is_shuffleable_v<T>, T> shfl(T val,
+                                                 int srcLane,
+                                                 int width     = WarpSize,
+                                                 uint32_t mask = 0xffffffffu)
+{
+#if CUDART_VERSION >= 9000
+  return __shfl_sync(mask, val, srcLane, width);
+#else
+  return __shfl(val, srcLane, width);
+#endif
+}
+
+/// Overload of shfl for data types not supported by the CUDA intrinsics
+template <typename T>
+DI std::enable_if_t<!is_shuffleable_v<T>, T> shfl(T val,
+                                                  int srcLane,
+                                                  int width     = WarpSize,
+                                                  uint32_t mask = 0xffffffffu)
+{
+  using UnitT = std::conditional_t<
+    is_multiple_v<T, unsigned int>,
+    unsigned int,
+    std::conditional_t<is_multiple_v<T, unsigned short>, unsigned short, unsigned char>>;
+
+  constexpr int n_words = sizeof(T) / sizeof(UnitT);
+
+  T output;
+  UnitT* output_alias = reinterpret_cast<UnitT*>(&output);
+  UnitT* input_alias  = reinterpret_cast<UnitT*>(&val);
+
+  unsigned int shuffle_word;
+  shuffle_word    = shfl((unsigned int)input_alias[0], srcLane, width, mask);
+  output_alias[0] = shuffle_word;
+
+#pragma unroll
+  for (int i = 1; i < n_words; ++i) {
+    shuffle_word    = shfl((unsigned int)input_alias[i], srcLane, width, mask);
+    output_alias[i] = shuffle_word;
+  }
+
+  return output;
+}
+
+/**
+ * @brief Shuffle the data inside a warp from lower lane IDs
+ * @tparam T the data type
+ * @param val value to be shuffled
+ * @param delta lower lane ID delta from where to shuffle
+ * @param width lane width
+ * @param mask mask of participating threads (Volta+)
+ * @return the shuffled data
+ */
+template <typename T>
+DI std::enable_if_t<is_shuffleable_v<T>, T> shfl_up(T val,
+                                                    int delta,
+                                                    int width     = WarpSize,
+                                                    uint32_t mask = 0xffffffffu)
+{
+#if CUDART_VERSION >= 9000
+  return __shfl_up_sync(mask, val, delta, width);
+#else
+  return __shfl_up(val, delta, width);
+#endif
+}
+
+/// Overload of shfl_up for data types not supported by the CUDA intrinsics
+template <typename T>
+DI std::enable_if_t<!is_shuffleable_v<T>, T> shfl_up(T val,
+                                                     int delta,
+                                                     int width     = WarpSize,
+                                                     uint32_t mask = 0xffffffffu)
+{
+  using UnitT = std::conditional_t<
+    is_multiple_v<T, unsigned int>,
+    unsigned int,
+    std::conditional_t<is_multiple_v<T, unsigned short>, unsigned short, unsigned char>>;
+
+  constexpr int n_words = sizeof(T) / sizeof(UnitT);
+
+  T output;
+  UnitT* output_alias = reinterpret_cast<UnitT*>(&output);
+  UnitT* input_alias  = reinterpret_cast<UnitT*>(&val);
+
+  unsigned int shuffle_word;
+  shuffle_word    = shfl_up((unsigned int)input_alias[0], delta, width, mask);
+  output_alias[0] = shuffle_word;
+
+#pragma unroll
+  for (int i = 1; i < n_words; ++i) {
+    shuffle_word    = shfl_up((unsigned int)input_alias[i], delta, width, mask);
+    output_alias[i] = shuffle_word;
+  }
+
+  return output;
+}
+
+/**
+ * @brief Shuffle the data inside a warp
+ * @tparam T the data type
+ * @param val value to be shuffled
+ * @param laneMask mask to be applied in order to perform xor shuffle
+ * @param width lane width
+ * @param mask mask of participating threads (Volta+)
+ * @return the shuffled data
+ */
+template <typename T>
+DI std::enable_if_t<is_shuffleable_v<T>, T> shfl_xor(T val,
+                                                     int laneMask,
+                                                     int width     = WarpSize,
+                                                     uint32_t mask = 0xffffffffu)
+{
+#if CUDART_VERSION >= 9000
+  return __shfl_xor_sync(mask, val, laneMask, width);
+#else
+  return __shfl_xor(val, laneMask, width);
+#endif
+}
+
+/// Overload of shfl_xor for data types not supported by the CUDA intrinsics
+template <typename T>
+DI std::enable_if_t<!is_shuffleable_v<T>, T> shfl_xor(T val,
+                                                      int laneMask,
+                                                      int width     = WarpSize,
+                                                      uint32_t mask = 0xffffffffu)
+{
+  using UnitT = std::conditional_t<
+    is_multiple_v<T, unsigned int>,
+    unsigned int,
+    std::conditional_t<is_multiple_v<T, unsigned short>, unsigned short, unsigned char>>;
+
+  constexpr int n_words = sizeof(T) / sizeof(UnitT);
+
+  T output;
+  UnitT* output_alias = reinterpret_cast<UnitT*>(&output);
+  UnitT* input_alias  = reinterpret_cast<UnitT*>(&val);
+
+  unsigned int shuffle_word;
+  shuffle_word    = shfl_xor((unsigned int)input_alias[0], laneMask, width, mask);
+  output_alias[0] = shuffle_word;
+
+#pragma unroll
+  for (int i = 1; i < n_words; ++i) {
+    shuffle_word    = shfl_xor((unsigned int)input_alias[i], laneMask, width, mask);
+    output_alias[i] = shuffle_word;
+  }
+
+  return output;
+}
+
+}  // namespace raft
\ No newline at end of file
diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt
index 88ad7772c2..98ce8ac5bd 100644
--- a/cpp/test/CMakeLists.txt
+++ b/cpp/test/CMakeLists.txt
@@ -372,7 +372,15 @@ if(BUILD_TESTS)
   )
 
   ConfigureTest(
-    NAME UTILS_TEST PATH test/core/seive.cu test/util/bitonic_sort.cu test/util/cudart_utils.cpp
-    test/util/device_atomics.cu test/util/integer_utils.cpp test/util/pow2_utils.cu
+    NAME
+    UTILS_TEST
+    PATH
+    test/core/seive.cu
+    test/util/bitonic_sort.cu
+    test/util/cudart_utils.cpp
+    test/util/device_atomics.cu
+    test/util/integer_utils.cpp
+    test/util/pow2_utils.cu
+    test/util/reduction.cu
   )
 endif()
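[Editor's usage sketch for `binaryBlockReduce`, not from the PR; the kernel name is illustrative. `TPB` must equal the launch block size, and the shared buffer holds one counter per warp, mirroring the test kernel below.]

```cpp
#include <raft/util/reduction.cuh>

// Counts the threads of the block whose flag is 1; the count is valid in the first warp only.
template <int TPB>
__global__ void count_flags_kernel(const int* in, int* out)
{
  __shared__ int shbuf[TPB / raft::WarpSize];
  int flag  = in[threadIdx.x] != 0 ? 1 : 0;  // binaryBlockReduce requires a 0/1 input
  int count = raft::binaryBlockReduce<TPB>(flag, shbuf);
  // Threads outside the first warp receive -1 instead of the count.
  if (threadIdx.x == 0) { *out = count; }
}
```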
diff --git a/cpp/test/util/reduction.cu b/cpp/test/util/reduction.cu
new file mode 100644
index 0000000000..17deaf99eb
--- /dev/null
+++ b/cpp/test/util/reduction.cu
@@ -0,0 +1,196 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "../test_utils.cuh"
+
+#include
+#include
+#include
+
+#include
+#include
+
+#include
+
+#include
+#include
+
+namespace raft::util {
+
+constexpr int max_warps_per_block = 32;
+
+template <typename ReduceLambda>
+__global__ void test_reduction_kernel(const int* input, int* reduction_res, ReduceLambda reduce_op)
+{
+  assert(gridDim.x == 1);
+  __shared__ int red_buf[max_warps_per_block];
+  int th_val = input[threadIdx.x];
+  th_val     = raft::blockReduce(th_val, (char*)red_buf, reduce_op);
+  if (threadIdx.x == 0) { reduction_res[0] = th_val; }
+}
+
+template <typename ReduceLambda>
+__global__ void test_ranked_reduction_kernel(const int* input,
+                                             int* reduction_res,
+                                             int* out_rank,
+                                             ReduceLambda reduce_op)
+{
+  assert(gridDim.x == 1);
+  __shared__ int red_buf[2 * max_warps_per_block];
+  int th_val  = input[threadIdx.x];
+  int th_rank = threadIdx.x;
+  auto result = raft::blockRankedReduce(th_val, red_buf, th_rank, reduce_op);
+  if (threadIdx.x == 0) {
+    reduction_res[0] = result.first;
+    out_rank[0]      = result.second;
+  }
+}
+
+__global__ void test_block_random_sample_kernel(const int* input, int* reduction_res)
+{
+  assert(gridDim.x == 1);
+  __shared__ int red_buf[2 * max_warps_per_block];
+  raft::random::PCGenerator thread_rng(1234, threadIdx.x, 0);
+  int th_val  = input[threadIdx.x];
+  int th_rank = threadIdx.x;
+  int result  = raft::random::device::block_random_sample(thread_rng, red_buf, th_val, th_rank);
+  if (threadIdx.x == 0) { reduction_res[0] = result; }
+}
+
+template <int TPB>
+__global__ void test_binary_reduction_kernel(const int* input, int* reduction_res)
+{
+  assert(gridDim.x == 1);
+  __shared__ int shared[TPB / WarpSize];
+  int th_val = input[threadIdx.x];
+  int result = raft::binaryBlockReduce<TPB>(th_val, shared);
+  if (threadIdx.x == 0) { reduction_res[0] = result; }
+}
+
+struct reduction_launch {
+  template <typename ReduceLambda>
+  static void run(const rmm::device_uvector<int>& arr_d,
+                  int ref_val,
+                  ReduceLambda reduce_op,
+                  rmm::cuda_stream_view stream)
+  {
+    rmm::device_scalar<int> ref_d(stream);
+    const int block_dim = 64;
+    const int grid_dim  = 1;
+    test_reduction_kernel<<<grid_dim, block_dim, 0, stream.value()>>>(
+      arr_d.data(), ref_d.data(), reduce_op);
+    stream.synchronize();
+    RAFT_CUDA_TRY(cudaPeekAtLastError());
+    ASSERT_EQ(ref_d.value(stream), ref_val);
+  }
+
+  template <typename ReduceLambda>
+  static void run_ranked(const rmm::device_uvector<int>& arr_d,
+                         int ref_val,
+                         int rank_ref_val,
+                         ReduceLambda reduce_op,
+                         rmm::cuda_stream_view stream)
+  {
+    rmm::device_scalar<int> ref_d(stream);
+    rmm::device_scalar<int> rank_d(stream);
+    const int block_dim = 64;
+    const int grid_dim  = 1;
+    test_ranked_reduction_kernel<<<grid_dim, block_dim, 0, stream.value()>>>(
+      arr_d.data(), ref_d.data(), rank_d.data(), reduce_op);
+    stream.synchronize();
+    RAFT_CUDA_TRY(cudaPeekAtLastError());
+    ASSERT_EQ(ref_d.value(stream), ref_val);
+    ASSERT_EQ(rank_d.value(stream), rank_ref_val);
+  }
+
+  static void run_random_sample(const rmm::device_uvector<int>& arr_d,
+                                int ref_val,
+                                rmm::cuda_stream_view stream)
+  {
+    rmm::device_scalar<int> ref_d(stream);
+    const int block_dim = 64;
+    const int grid_dim  = 1;
+    test_block_random_sample_kernel<<<grid_dim, block_dim, 0, stream.value()>>>(arr_d.data(),
+                                                                                ref_d.data());
+    stream.synchronize();
+    RAFT_CUDA_TRY(cudaPeekAtLastError());
+    ASSERT_EQ(ref_d.value(stream), ref_val);
+  }
+
+  static void run_binary(const rmm::device_uvector<int>& arr_d,
+                         int ref_val,
+                         rmm::cuda_stream_view stream)
+  {
+    rmm::device_scalar<int> ref_d(stream);
+    constexpr int block_dim = 64;
+    const int grid_dim      = 1;
+    test_binary_reduction_kernel<block_dim>
+      <<<grid_dim, block_dim, 0, stream.value()>>>(arr_d.data(), ref_d.data());
+    stream.synchronize();
+    RAFT_CUDA_TRY(cudaPeekAtLastError());
+    ASSERT_EQ(ref_d.value(stream), ref_val);
+  }
+};
+
+template <typename T>
+class ReductionTest : public testing::TestWithParam<std::vector<T>> {  // NOLINT
+ protected:
+  const std::vector<T> input;    // NOLINT
+  rmm::cuda_stream_view stream;  // NOLINT
+  rmm::device_uvector<T> arr_d;  // NOLINT
+
+ public:
+  explicit ReductionTest()
+    : input(testing::TestWithParam<std::vector<T>>::GetParam()),
+      stream(rmm::cuda_stream_default),
+      arr_d(input.size(), stream)
+  {
+    update_device(arr_d.data(), input.data(), input.size(), stream);
+  }
+
+  void run_reduction()
+  {
+    // calculate the results
+    reduction_launch::run(arr_d, 0, raft::min_op{}, stream);
+    reduction_launch::run(arr_d, 5, raft::max_op{}, stream);
+    reduction_launch::run(arr_d, 158, raft::add_op{}, stream);
+    reduction_launch::run_ranked(arr_d, 5, 15, raft::max_op{}, stream);
+    reduction_launch::run_ranked(arr_d, 0, 26, raft::min_op{}, stream);
+    // value 15 is for the current state of PCGenerator. Adjust this if the rng changes.
+    reduction_launch::run_random_sample(arr_d, 15, stream);
+  }
+
+  void run_binary_reduction() { reduction_launch::run_binary(arr_d, 24, stream); }
+};
+
+const std::vector<int> test_vector{1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 5, 1, 2, 3, 4, 1, 2,
+                                   3, 4, 1, 2, 0, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4,
+                                   1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4};
+const std::vector<int> binary_test_vector{
+  1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0,
+  1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0};
+auto reduction_input        = ::testing::Values(test_vector);
+auto binary_reduction_input = ::testing::Values(binary_test_vector);
+
+using ReductionTestInt       = ReductionTest<int>;  // NOLINT
+using BinaryReductionTestInt = ReductionTest<int>;  // NOLINT
+TEST_P(ReductionTestInt, REDUCTIONS) { run_reduction(); }
+INSTANTIATE_TEST_CASE_P(ReductionTest, ReductionTestInt, reduction_input);    // NOLINT
+TEST_P(BinaryReductionTestInt, BINARY_REDUCTION) { run_binary_reduction(); }  // NOLINT
+INSTANTIATE_TEST_CASE_P(BinaryReductionTest,
+                        BinaryReductionTestInt,
+                        binary_reduction_input);  // NOLINT
+
+}  // namespace raft::util
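Editor's note on the conditional-probability scheme (a worked example, not from the PR): when `warp_random_sample` combines two partial results with weights w_a (held by the calling lane, index a) and w_b (shuffled in, index b), it draws `rnd_number` uniformly from [0, w_a + w_b) and keeps index b exactly when `rnd_number < w_b`, that is, with probability w_b / (w_a + w_b). By induction over the shuffle tree, each subtree's surviving index i is distributed as w_i divided by that subtree's weight sum, so after the final merge thread 0 holds index i with probability w_i / sum_k w_k, which is the discrete distribution described in `sample.cuh`. For example, with weights {1, 3}, a single merge keeps the second index with probability 3/4.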