diff --git a/cpp/include/raft/core/math.hpp b/cpp/include/raft/core/math.hpp index 56a8d78926..809b2948e7 100644 --- a/cpp/include/raft/core/math.hpp +++ b/cpp/include/raft/core/math.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -49,12 +49,42 @@ RAFT_INLINE_FUNCTION auto abs(T x) template constexpr RAFT_INLINE_FUNCTION auto abs(T x) -> std::enable_if_t && !std::is_same_v && +#if defined(_RAFT_HAS_CUDA) + !std::is_same_v<__half, T> && !std::is_same_v && +#endif !std::is_same_v && !std::is_same_v && !std::is_same_v, T> { return x < T{0} ? -x : x; } +#if defined(_RAFT_HAS_CUDA) +template +RAFT_DEVICE_INLINE_FUNCTION typename std::enable_if_t, __half> abs(T x) +{ +#if (__CUDA_ARCH__ >= 530) + return ::__habs(x); +#else + // Fail during template instantiation if the compute capability doesn't support this operation + static_assert(sizeof(T) != sizeof(T), "__half is only supported on __CUDA_ARCH__ >= 530"); + return T{}; +#endif +} + +template +RAFT_DEVICE_INLINE_FUNCTION typename std::enable_if_t, nv_bfloat16> +abs(T x) +{ +#if (__CUDA_ARCH__ >= 800) + return ::__habs(x); +#else + // Fail during template instantiation if the compute capability doesn't support this operation + static_assert(sizeof(T) != sizeof(T), "nv_bfloat16 is only supported on __CUDA_ARCH__ >= 800"); + return T{}; +#endif +} +#endif +/** @} */ /** Inverse cosine */ template diff --git a/cpp/include/raft/util/cuda_utils.cuh b/cpp/include/raft/util/cuda_utils.cuh index e718ca3545..bf46e069e4 100644 --- a/cpp/include/raft/util/cuda_utils.cuh +++ b/cpp/include/raft/util/cuda_utils.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * Copyright (c) 2018-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,15 +16,12 @@ #pragma once +#include +#include #include #include #include -#if defined(_RAFT_HAS_CUDA) -#include -#include -#endif - #include #include #include @@ -278,17 +275,53 @@ template <> * @{ */ template -inline __device__ T myInf(); -template <> -inline __device__ float myInf() +RAFT_DEVICE_INLINE_FUNCTION typename std::enable_if_t, float> myInf() { return CUDART_INF_F; } -template <> -inline __device__ double myInf() +template +RAFT_DEVICE_INLINE_FUNCTION typename std::enable_if_t, double> myInf() { return CUDART_INF; } +// Half/Bfloat constants only defined after CUDA 12.2 +#if __CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 2) +template +RAFT_DEVICE_INLINE_FUNCTION typename std::enable_if_t, __half> myInf() +{ +#if (__CUDA_ARCH__ >= 530) + return __ushort_as_half((unsigned short)0x7C00U); +#else + // Fail during template instantiation if the compute capability doesn't support this operation + static_assert(sizeof(T) != sizeof(T), "__half is only supported on __CUDA_ARCH__ >= 530"); + return T{}; +#endif +} +template +RAFT_DEVICE_INLINE_FUNCTION typename std::enable_if_t, nv_bfloat16> +myInf() +{ +#if (__CUDA_ARCH__ >= 800) + return __ushort_as_bfloat16((unsigned short)0x7F80U); +#else + // Fail during template instantiation if the compute capability doesn't support this operation + static_assert(sizeof(T) != sizeof(T), "nv_bfloat16 is only supported on __CUDA_ARCH__ >= 800"); + return T{}; +#endif +} +#else +template +RAFT_DEVICE_INLINE_FUNCTION typename std::enable_if_t, __half> myInf() +{ + return CUDART_INF_FP16; +} +template +RAFT_DEVICE_INLINE_FUNCTION typename std::enable_if_t, nv_bfloat16> +myInf() +{ + return CUDART_INF_BF16; +} +#endif /** @} */ /**