diff --git a/cpp/include/raft/util/cuda_utils.cuh b/cpp/include/raft/util/cuda_utils.cuh index 0523dcc81c..e718ca3545 100644 --- a/cpp/include/raft/util/cuda_utils.cuh +++ b/cpp/include/raft/util/cuda_utils.cuh @@ -20,6 +20,11 @@ #include #include +#if defined(_RAFT_HAS_CUDA) +#include +#include +#endif + #include #include #include @@ -79,6 +84,35 @@ DI void myAtomicReduce(float* address, float val, ReduceLambda op) } while (assumed != old); } +// Needed for atomicCas on ushort +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) +template +DI void myAtomicReduce(__half* address, __half val, ReduceLambda op) +{ + unsigned short int* address_as_uint = (unsigned short int*)address; + unsigned short int old = *address_as_uint, assumed; + do { + assumed = old; + old = atomicCAS(address_as_uint, assumed, __half_as_ushort(op(val, __ushort_as_half(assumed)))); + } while (assumed != old); +} +#endif + +// Needed for nv_bfloat16 support +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) +template +DI void myAtomicReduce(nv_bfloat16* address, nv_bfloat16 val, ReduceLambda op) +{ + unsigned short int* address_as_uint = (unsigned short int*)address; + unsigned short int old = *address_as_uint, assumed; + do { + assumed = old; + old = atomicCAS( + address_as_uint, assumed, __bfloat16_as_ushort(op(val, __ushort_as_bfloat16(assumed)))); + } while (assumed != old); +} +#endif + template DI void myAtomicReduce(int* address, int val, ReduceLambda op) {