From ad0c1c1aa36726627b6c4a67b00a368ce9e89217 Mon Sep 17 00:00:00 2001 From: Nicolas Blin <31096601+Kh4ster@users.noreply.github.com> Date: Sat, 10 Jun 2023 05:38:10 +0200 Subject: [PATCH] This PR adds support to __half and nb_bfloat16 to myAtomicReduce (#1585) Authors: - Nicolas Blin (https://github.com/Kh4ster) Approvers: - Louis Sugy (https://github.com/Nyrio) - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1585 --- cpp/include/raft/util/cuda_utils.cuh | 34 ++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/cpp/include/raft/util/cuda_utils.cuh b/cpp/include/raft/util/cuda_utils.cuh index 0523dcc81c..e718ca3545 100644 --- a/cpp/include/raft/util/cuda_utils.cuh +++ b/cpp/include/raft/util/cuda_utils.cuh @@ -20,6 +20,11 @@ #include #include +#if defined(_RAFT_HAS_CUDA) +#include +#include +#endif + #include #include #include @@ -79,6 +84,35 @@ DI void myAtomicReduce(float* address, float val, ReduceLambda op) } while (assumed != old); } +// Needed for atomicCas on ushort +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) +template +DI void myAtomicReduce(__half* address, __half val, ReduceLambda op) +{ + unsigned short int* address_as_uint = (unsigned short int*)address; + unsigned short int old = *address_as_uint, assumed; + do { + assumed = old; + old = atomicCAS(address_as_uint, assumed, __half_as_ushort(op(val, __ushort_as_half(assumed)))); + } while (assumed != old); +} +#endif + +// Needed for nv_bfloat16 support +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) +template +DI void myAtomicReduce(nv_bfloat16* address, nv_bfloat16 val, ReduceLambda op) +{ + unsigned short int* address_as_uint = (unsigned short int*)address; + unsigned short int old = *address_as_uint, assumed; + do { + assumed = old; + old = atomicCAS( + address_as_uint, assumed, __bfloat16_as_ushort(op(val, __ushort_as_bfloat16(assumed)))); + } while (assumed != old); +} +#endif + template DI void myAtomicReduce(int* address, int val, ReduceLambda op) {