Skip to content

Commit

Permalink
add half/bfloat support to myInf and abs (#1592)
Browse files Browse the repository at this point in the history
This PR adds support for __half and nv_bfloat16 to abs and myInf

Authors:
  - Nicolas Blin (https://github.com/Kh4ster)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1592
  • Loading branch information
Kh4ster authored Jan 24, 2024
1 parent 71fce1c commit dbb5e66
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 12 deletions.
32 changes: 31 additions & 1 deletion cpp/include/raft/core/math.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -49,12 +49,42 @@ RAFT_INLINE_FUNCTION auto abs(T x)
// Generic absolute value: yields -x when x compares less than T{0}, and x otherwise.
//
// This overload is removed from overload resolution for float, double, int, long int and
// long long int, and additionally for __half and nv_bfloat16 when _RAFT_HAS_CUDA is defined.
template <typename T>
constexpr RAFT_INLINE_FUNCTION auto abs(T x)
  -> std::enable_if_t<!std::is_same_v<float, T> && !std::is_same_v<double, T> &&
#if defined(_RAFT_HAS_CUDA)
                        !std::is_same_v<__half, T> && !std::is_same_v<nv_bfloat16, T> &&
#endif
                        !std::is_same_v<int, T> && !std::is_same_v<long int, T> &&
                        !std::is_same_v<long long int, T>,
                      T>
{
  // Negative input: flip the sign.
  if (x < T{0}) { return -x; }
  // Zero or positive input is returned unchanged.
  return x;
}
#if defined(_RAFT_HAS_CUDA)
// Absolute value of a __half, computed with the ::__habs intrinsic.
// Participates in overload resolution only when T is exactly __half.
template <typename T>
RAFT_DEVICE_INLINE_FUNCTION typename std::enable_if_t<std::is_same_v<T, __half>, __half> abs(T x)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
  const __half magnitude = ::__habs(x);
  return magnitude;
#else
  // __CUDA_ARCH__ is undefined or below 530. The assertion's condition depends on T, so it is
  // only evaluated (and only fails) if this overload is actually instantiated.
  static_assert(sizeof(T) != sizeof(T), "__half is only supported on __CUDA_ARCH__ >= 530");
  return T{};
#endif
}

// Absolute value of an nv_bfloat16, computed with the ::__habs intrinsic.
// Participates in overload resolution only when T is exactly nv_bfloat16.
template <typename T>
RAFT_DEVICE_INLINE_FUNCTION typename std::enable_if_t<std::is_same_v<T, nv_bfloat16>, nv_bfloat16>
abs(T x)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
  const nv_bfloat16 magnitude = ::__habs(x);
  return magnitude;
#else
  // __CUDA_ARCH__ is undefined or below 800. The assertion's condition depends on T, so it is
  // only evaluated (and only fails) if this overload is actually instantiated.
  static_assert(sizeof(T) != sizeof(T), "nv_bfloat16 is only supported on __CUDA_ARCH__ >= 800");
  return T{};
#endif
}
#endif
/** @} */

/** Inverse cosine */
template <typename T>
Expand Down
55 changes: 44 additions & 11 deletions cpp/include/raft/util/cuda_utils.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2023, NVIDIA CORPORATION.
* Copyright (c) 2018-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,15 +16,12 @@

#pragma once

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <math_constants.h>
#include <stdint.h>
#include <type_traits>

#if defined(_RAFT_HAS_CUDA)
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#endif

#include <raft/core/cudart_utils.hpp>
#include <raft/core/math.hpp>
#include <raft/core/operators.hpp>
Expand Down Expand Up @@ -278,17 +275,53 @@ template <>
* @{
*/
// myInf for float: the body returns the CUDART_INF_F constant.
// NOTE(review): the header lines below appear to interleave two forms of this function (a primary
// template declaration plus a `template <>` explicit specialization, and an enable_if-constrained
// overload selected only when T is float). This looks like diff residue with the +/- markers
// stripped; confirm which form is current before editing.
template <typename T>
inline __device__ T myInf();
template <>
inline __device__ float myInf<float>()
RAFT_DEVICE_INLINE_FUNCTION typename std::enable_if_t<std::is_same_v<T, float>, float> myInf()
{
return CUDART_INF_F;
}
// myInf for double: the body returns the CUDART_INF constant.
// NOTE(review): the header lines below appear to interleave a `template <>` explicit
// specialization with an enable_if-constrained overload selected only when T is double. This
// looks like diff residue with the +/- markers stripped; confirm which form is current.
template <>
inline __device__ double myInf<double>()
template <typename T>
RAFT_DEVICE_INLINE_FUNCTION typename std::enable_if_t<std::is_same_v<T, double>, double> myInf()
{
return CUDART_INF;
}
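// Half/Bfloat infinity constants (CUDART_INF_FP16 / CUDART_INF_BF16) are only used from CUDA 12.2
// onward; older toolkits build the value from its bit pattern instead.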
#if __CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 2)
// myInf for __half, built from a raw bit pattern.
// Participates in overload resolution only when T is exactly __half.
template <typename T>
RAFT_DEVICE_INLINE_FUNCTION typename std::enable_if_t<std::is_same_v<T, __half>, __half> myInf()
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
  // 0x7C00 = 0b0111'1100'0000'0000: sign bit clear, all five exponent bits set, zero mantissa.
  constexpr unsigned short inf_bits = 0x7C00U;
  return __ushort_as_half(inf_bits);
#else
  // __CUDA_ARCH__ is undefined or below 530. The assertion's condition depends on T, so it is
  // only evaluated (and only fails) if this overload is actually instantiated.
  static_assert(sizeof(T) != sizeof(T), "__half is only supported on __CUDA_ARCH__ >= 530");
  return T{};
#endif
}
// myInf for nv_bfloat16, built from a raw bit pattern.
// Participates in overload resolution only when T is exactly nv_bfloat16.
template <typename T>
RAFT_DEVICE_INLINE_FUNCTION typename std::enable_if_t<std::is_same_v<T, nv_bfloat16>, nv_bfloat16>
myInf()
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
  // 0x7F80 = 0b0111'1111'1000'0000: sign bit clear, all eight exponent bits set, zero mantissa.
  constexpr unsigned short inf_bits = 0x7F80U;
  return __ushort_as_bfloat16(inf_bits);
#else
  // __CUDA_ARCH__ is undefined or below 800. The assertion's condition depends on T, so it is
  // only evaluated (and only fails) if this overload is actually instantiated.
  static_assert(sizeof(T) != sizeof(T), "nv_bfloat16 is only supported on __CUDA_ARCH__ >= 800");
  return T{};
#endif
}
#else
// myInf for __half: participates in overload resolution only when T is exactly __half and
// returns the CUDART_INF_FP16 constant. This variant sits in the #else branch of the
// CUDA-version check, i.e. it is compiled when the compiler version is 12.2 or newer.
template <typename T>
RAFT_DEVICE_INLINE_FUNCTION typename std::enable_if_t<std::is_same_v<T, __half>, __half> myInf()
{
return CUDART_INF_FP16;
}
// myInf for nv_bfloat16: participates in overload resolution only when T is exactly nv_bfloat16
// and returns the CUDART_INF_BF16 constant. This variant sits in the #else branch of the
// CUDA-version check, i.e. it is compiled when the compiler version is 12.2 or newer.
template <typename T>
RAFT_DEVICE_INLINE_FUNCTION typename std::enable_if_t<std::is_same_v<T, nv_bfloat16>, nv_bfloat16>
myInf()
{
return CUDART_INF_BF16;
}
#endif
/** @} */

/**
Expand Down

0 comments on commit dbb5e66

Please sign in to comment.