Skip to content

Commit

Permalink
Specialize cuda::std::numeric_limits for FP8 types (#3478) (#3492)
Browse files Browse the repository at this point in the history
Co-authored-by: David Bayer <[email protected]>
  • Loading branch information
bernhardmgruber and davebayer authored Jan 30, 2025
1 parent 863b25f commit 3a594e3
Show file tree
Hide file tree
Showing 34 changed files with 627 additions and 150 deletions.
192 changes: 192 additions & 0 deletions libcudacxx/include/cuda/std/limits
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,21 @@
#include <cuda/std/climits>
#include <cuda/std/version>

#if defined(_LIBCUDACXX_HAS_NVFP16)
# include <cuda_fp16.h>
#endif // _LIBCUDACXX_HAS_NVFP16

#if defined(_LIBCUDACXX_HAS_NVBF16)
_CCCL_DIAG_PUSH
_CCCL_DIAG_SUPPRESS_CLANG("-Wunused-function")
# include <cuda_bf16.h>
_CCCL_DIAG_POP
#endif // _LIBCUDACXX_HAS_NVBF16

#if _CCCL_HAS_NVFP8()
# include <cuda_fp8.h>
#endif // _CCCL_HAS_NVFP8()

_CCCL_PUSH_MACROS

_LIBCUDACXX_BEGIN_NAMESPACE_STD
Expand Down Expand Up @@ -744,6 +759,183 @@ public:
};
#endif // _LIBCUDACXX_HAS_NVBF16

#if _CCCL_HAS_NVFP8()
# if defined(_CCCL_BUILTIN_BIT_CAST) || _CCCL_STD_VER >= 2014
# define _LIBCUDACXX_CONSTEXPR_FP8_LIMITS constexpr
# else // ^^^ _CCCL_BUILTIN_BIT_CAST || _CCCL_STD_VER >= 2014 ^^^ // vvv !_CCCL_BUILTIN_BIT_CAST && _CCCL_STD_VER <
// 2014 vvv
# define _LIBCUDACXX_CONSTEXPR_FP8_LIMITS
# endif // ^^^ !_CCCL_BUILTIN_BIT_CAST && _CCCL_STD_VER < 2014 ^^^

template <>
class __numeric_limits_impl<__nv_fp8_e4m3, __numeric_limits_type::__floating_point>
{
// Reinterpret a raw 8-bit storage pattern as an __nv_fp8_e4m3. Prefers the
// bit_cast builtin; the fallback writes the storage member directly, which is
// only usable in a constant expression with C++14 relaxed constexpr — hence
// _LIBCUDACXX_CONSTEXPR_FP8_LIMITS expands to `constexpr` only when either
// the builtin or C++14 is available (see the macro definition above).
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_CONSTEXPR_FP8_LIMITS __nv_fp8_e4m3 __make_value(__nv_fp8_storage_t __val)
{
# if defined(_CCCL_BUILTIN_BIT_CAST)
return _CUDA_VSTD::bit_cast<__nv_fp8_e4m3>(__val);
# else // ^^^ _CCCL_BUILTIN_BIT_CAST ^^^ // vvv !_CCCL_BUILTIN_BIT_CAST vvv
__nv_fp8_e4m3 __ret{};
__ret.__x = __val;
return __ret;
# endif // ^^^ !_CCCL_BUILTIN_BIT_CAST ^^^
}

public:
using type = __nv_fp8_e4m3;

static constexpr bool is_specialized = true;

static constexpr bool is_signed = true;
// NOTE(review): by the convention used for the other floating-point
// specializations, `digits` counts radix-2 mantissa digits including the
// implicit leading bit; E4M3 has 3 explicit mantissa bits, so digits = 4
// would match epsilon() == 2^-3 (0x20) below. Confirm whether the explicit
// bit count (3) is intended here.
static constexpr int digits = 3;
static constexpr int digits10 = 0;
static constexpr int max_digits10 = 2;
// Smallest positive normal value: storage 0x08 (exponent field 1, mantissa 0
// — presumably 2^-6 for E4M3; TODO confirm against min_exponent below).
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_CONSTEXPR_FP8_LIMITS type min() noexcept
{
return __make_value(static_cast<__nv_fp8_storage_t>(0x08u));
}
// Largest finite value: storage 0x7e (448 per the NVIDIA/OCP E4M3 encoding
// — 0x7f is the NaN pattern, see quiet_NaN()).
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_CONSTEXPR_FP8_LIMITS type max() noexcept
{
return __make_value(static_cast<__nv_fp8_storage_t>(0x7eu));
}
// Most negative finite value: max() with the sign bit set (0x7e | 0x80).
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_CONSTEXPR_FP8_LIMITS type lowest() noexcept
{
return __make_value(static_cast<__nv_fp8_storage_t>(0xfeu));
}

static constexpr bool is_integer = false;
static constexpr bool is_exact = false;
static constexpr int radix = __FLT_RADIX__;
// Machine epsilon: storage 0x20 (presumably 2^-3 = 0.125 in E4M3 — the gap
// between 1 and the next representable value; TODO confirm).
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_CONSTEXPR_FP8_LIMITS type epsilon() noexcept
{
return __make_value(static_cast<__nv_fp8_storage_t>(0x20u));
}
// Maximum rounding error for round-to-nearest: storage 0x30 (presumably 0.5).
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_CONSTEXPR_FP8_LIMITS type round_error() noexcept
{
return __make_value(static_cast<__nv_fp8_storage_t>(0x30u));
}

static constexpr int min_exponent = -6;
static constexpr int min_exponent10 = -2;
static constexpr int max_exponent = 8;
static constexpr int max_exponent10 = 2;

// E4M3 spends the would-be infinity encodings on extra finite range: it has
// a single NaN pattern per sign and no infinities (reflected below).
static constexpr bool has_infinity = false;
static constexpr bool has_quiet_NaN = true;
static constexpr bool has_signaling_NaN = false;
static constexpr float_denorm_style has_denorm = denorm_present;
static constexpr bool has_denorm_loss = false;
// No infinity representation exists (has_infinity == false); returns a
// value-initialized (zero-pattern) object, mirroring the primary template's
// behavior for types without the feature.
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_CONSTEXPR_FP8_LIMITS type infinity() noexcept
{
return type{};
}
// The unique E4M3 NaN pattern: storage 0x7f (all exponent and mantissa bits
// set — presumably per the NVIDIA FP8 spec; TODO confirm).
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_CONSTEXPR_FP8_LIMITS type quiet_NaN() noexcept
{
return __make_value(static_cast<__nv_fp8_storage_t>(0x7fu));
}
// No signaling NaN exists (has_signaling_NaN == false); returns zero like
// infinity() above.
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_CONSTEXPR_FP8_LIMITS type signaling_NaN() noexcept
{
return type{};
}
// Smallest positive subnormal: storage 0x01 (least-significant mantissa bit
// only).
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_CONSTEXPR_FP8_LIMITS type denorm_min() noexcept
{
return __make_value(static_cast<__nv_fp8_storage_t>(0x01u));
}

// Not IEC 559/IEEE-754 conformant: no infinities and a non-standard NaN
// scheme (see has_infinity/has_signaling_NaN above).
static constexpr bool is_iec559 = false;
static constexpr bool is_bounded = true;
static constexpr bool is_modulo = false;

static constexpr bool traps = false;
static constexpr bool tinyness_before = false;
static constexpr float_round_style round_style = round_to_nearest;
};

template <>
class __numeric_limits_impl<__nv_fp8_e5m2, __numeric_limits_type::__floating_point>
{
// Reinterpret a raw 8-bit storage pattern as an __nv_fp8_e5m2. Prefers the
// bit_cast builtin; the member-write fallback requires C++14 relaxed
// constexpr, which is what gates _LIBCUDACXX_CONSTEXPR_FP8_LIMITS.
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_CONSTEXPR_FP8_LIMITS __nv_fp8_e5m2 __make_value(__nv_fp8_storage_t __val)
{
# if defined(_CCCL_BUILTIN_BIT_CAST)
return _CUDA_VSTD::bit_cast<__nv_fp8_e5m2>(__val);
# else // ^^^ _CCCL_BUILTIN_BIT_CAST ^^^ // vvv !_CCCL_BUILTIN_BIT_CAST vvv
__nv_fp8_e5m2 __ret{};
__ret.__x = __val;
return __ret;
# endif // ^^^ !_CCCL_BUILTIN_BIT_CAST ^^^
}

public:
using type = __nv_fp8_e5m2;

static constexpr bool is_specialized = true;

static constexpr bool is_signed = true;
// NOTE(review): E5M2 has 2 explicit mantissa bits; with the implicit leading
// bit included (the convention used for float/double/__half) digits would be
// 3, matching epsilon() == 2^-2 (0x34) below. Confirm the intended
// convention, consistent with the E4M3 specialization above.
static constexpr int digits = 2;
static constexpr int digits10 = 0;
static constexpr int max_digits10 = 2;
// Smallest positive normal value: storage 0x04 (exponent field 1, mantissa 0
// — presumably 2^-14 for E5M2; TODO confirm against min_exponent below).
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_CONSTEXPR_FP8_LIMITS type min() noexcept
{
return __make_value(static_cast<__nv_fp8_storage_t>(0x04u));
}
// Largest finite value: storage 0x7b (presumably 57344 in E5M2 — 0x7c is
// infinity, see infinity() below).
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_CONSTEXPR_FP8_LIMITS type max() noexcept
{
return __make_value(static_cast<__nv_fp8_storage_t>(0x7bu));
}
// Most negative finite value: max() with the sign bit set (0x7b | 0x80).
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_CONSTEXPR_FP8_LIMITS type lowest() noexcept
{
return __make_value(static_cast<__nv_fp8_storage_t>(0xfbu));
}

static constexpr bool is_integer = false;
static constexpr bool is_exact = false;
static constexpr int radix = __FLT_RADIX__;
// Machine epsilon: storage 0x34 (presumably 2^-2 = 0.25 in E5M2).
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_CONSTEXPR_FP8_LIMITS type epsilon() noexcept
{
return __make_value(static_cast<__nv_fp8_storage_t>(0x34u));
}
// Maximum rounding error for round-to-nearest: storage 0x38 (presumably 0.5).
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_CONSTEXPR_FP8_LIMITS type round_error() noexcept
{
return __make_value(static_cast<__nv_fp8_storage_t>(0x38u));
}

static constexpr int min_exponent = -15;
static constexpr int min_exponent10 = -5;
static constexpr int max_exponent = 15;
static constexpr int max_exponent10 = 4;

// Unlike E4M3, E5M2 follows the IEEE-style special-value layout: it has
// infinities and both NaN flavors (reflected in the members below).
static constexpr bool has_infinity = true;
static constexpr bool has_quiet_NaN = true;
static constexpr bool has_signaling_NaN = true;
static constexpr float_denorm_style has_denorm = denorm_present;
static constexpr bool has_denorm_loss = false;
// Positive infinity: storage 0x7c (all exponent bits set, mantissa 0).
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_CONSTEXPR_FP8_LIMITS type infinity() noexcept
{
return __make_value(static_cast<__nv_fp8_storage_t>(0x7cu));
}
// Quiet NaN: storage 0x7e (all exponent bits set, high mantissa bit set).
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_CONSTEXPR_FP8_LIMITS type quiet_NaN() noexcept
{
return __make_value(static_cast<__nv_fp8_storage_t>(0x7eu));
}
// Signaling NaN: storage 0x7d (all exponent bits set, nonzero mantissa with
// the high mantissa bit clear).
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_CONSTEXPR_FP8_LIMITS type signaling_NaN() noexcept
{
return __make_value(static_cast<__nv_fp8_storage_t>(0x7du));
}
// Smallest positive subnormal: storage 0x01 (least-significant mantissa bit
// only).
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_CONSTEXPR_FP8_LIMITS type denorm_min() noexcept
{
return __make_value(static_cast<__nv_fp8_storage_t>(0x01u));
}

// NOTE(review): despite the IEEE-like special-value layout, is_iec559 is
// deliberately false here (8-bit formats are not IEC 559 interchange
// formats) — consistent with the E4M3 specialization.
static constexpr bool is_iec559 = false;
static constexpr bool is_bounded = true;
static constexpr bool is_modulo = false;

static constexpr bool traps = false;
static constexpr bool tinyness_before = false;
static constexpr float_round_style round_style = round_to_nearest;
};
#endif // _CCCL_HAS_NVFP8()

template <class _Tp>
class numeric_limits : public __numeric_limits_impl<_Tp>
{};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#define NUMERIC_LIMITS_MEMBERS_COMMON_H

// Disable all the extended floating point operations and conversions
#define __CUDA_NO_FP8_CONVERSIONS__ 1
#define __CUDA_NO_HALF_CONVERSIONS__ 1
#define __CUDA_NO_HALF_OPERATORS__ 1
#define __CUDA_NO_BFLOAT16_CONVERSIONS__ 1
Expand All @@ -24,6 +25,32 @@ __host__ __device__ bool float_eq(T x, T y)
return x == y;
}

#if _CCCL_HAS_NVFP8()
// Test helper: build an __nv_fp8_e4m3 from a double via the CUDA conversion
// intrinsic, with optional saturation behavior (defaults to no saturation).
__host__ __device__ inline __nv_fp8_e4m3 make_fp8_e4m3(double x, __nv_saturation_t sat = __NV_NOSAT)
{
  const __nv_fp8_storage_t bits = __nv_cvt_double_to_fp8(x, sat, __NV_E4M3);
  __nv_fp8_e4m3 out;
  out.__x = bits;
  return out;
}

// Test helper: build an __nv_fp8_e5m2 from a double via the CUDA conversion
// intrinsic, with optional saturation behavior (defaults to no saturation).
__host__ __device__ inline __nv_fp8_e5m2 make_fp8_e5m2(double x, __nv_saturation_t sat = __NV_NOSAT)
{
  const __nv_fp8_storage_t bits = __nv_cvt_double_to_fp8(x, sat, __NV_E5M2);
  __nv_fp8_e5m2 out;
  out.__x = bits;
  return out;
}

// Test helper: compare two E4M3 values by widening each to __half and
// delegating to the __half float_eq overload.
__host__ __device__ inline bool float_eq(__nv_fp8_e4m3 x, __nv_fp8_e4m3 y)
{
  const __half xh{__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3)};
  const __half yh{__nv_cvt_fp8_to_halfraw(y.__x, __NV_E4M3)};
  return float_eq(xh, yh);
}

// Test helper: compare two E5M2 values by widening each to __half and
// delegating to the __half float_eq overload.
__host__ __device__ inline bool float_eq(__nv_fp8_e5m2 x, __nv_fp8_e5m2 y)
{
  const __half xh{__nv_cvt_fp8_to_halfraw(x.__x, __NV_E5M2)};
  const __half yh{__nv_cvt_fp8_to_halfraw(y.__x, __NV_E5M2)};
  return float_eq(xh, yh);
}
#endif // _CCCL_HAS_NVFP8

#if defined(_LIBCUDACXX_HAS_NVFP16)
__host__ __device__ inline bool float_eq(__half x, __half y)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,10 @@ int main(int, char**)
#if defined(_LIBCUDACXX_HAS_NVBF16)
test_type<__nv_bfloat16>();
#endif // _LIBCUDACXX_HAS_NVBF16
#if _CCCL_HAS_NVFP8()
test_type<__nv_fp8_e4m3>();
test_type<__nv_fp8_e5m2>();
#endif // _CCCL_HAS_NVFP8()

return 0;
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ int main(int, char**)
#if defined(_LIBCUDACXX_HAS_NVBF16)
test<__nv_bfloat16>(__double2bfloat16(9.18354961579912115600575419705e-41));
#endif // _LIBCUDACXX_HAS_NVBF16
#if _CCCL_HAS_NVFP8()
test<__nv_fp8_e4m3>(make_fp8_e4m3(0.001953125));
test<__nv_fp8_e5m2>(make_fp8_e5m2(0.0000152587890625));
#endif // _CCCL_HAS_NVFP8()
#if !defined(__FLT_DENORM_MIN__) && !defined(FLT_TRUE_MIN)
# error Test has no expected values for floating point types
#endif
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,5 +61,10 @@ int main(int, char**)
#if defined(_LIBCUDACXX_HAS_NVBF16)
test<__nv_bfloat16, 8>();
#endif // _LIBCUDACXX_HAS_NVBF16
#if _CCCL_HAS_NVFP8()
test<__nv_fp8_e4m3, 3>();
test<__nv_fp8_e5m2, 2>();
#endif // _CCCL_HAS_NVFP8()

return 0;
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,25 @@

#include <cuda/std/cfloat>
#include <cuda/std/limits>
#include <cuda/std/type_traits>

#include "test_macros.h"

template <class T, int expected>
// Expected digits10 for an integral type T: floor(digits * log10(2)),
// evaluated in exact integer arithmetic using log10(2) ~= 30103/100000.
template <class T, cuda::std::enable_if_t<cuda::std::is_integral<T>::value, int> = 0>
__host__ __device__ constexpr int make_expected_digits10()
{
  return static_cast<int>((30103l * cuda::std::numeric_limits<T>::digits) / 100000l);
}

// Expected digits10 for a non-integral (floating-point) type T:
// floor((digits - 1) * log10(2)), evaluated in exact integer arithmetic
// using log10(2) ~= 30103/100000.
template <class T, cuda::std::enable_if_t<!cuda::std::is_integral<T>::value, int> = 0>
__host__ __device__ constexpr int make_expected_digits10()
{
  return static_cast<int>((30103l * (cuda::std::numeric_limits<T>::digits - 1)) / 100000l);
}

template <class T, int expected = make_expected_digits10<T>()>
__host__ __device__ void test()
{
static_assert(cuda::std::numeric_limits<T>::digits10 == expected, "digits10 test 1");
Expand All @@ -30,41 +45,45 @@ __host__ __device__ void test()

// NOTE(review): this listing is a unified-diff render that interleaves the
// REMOVED explicit-expectation calls (e.g. test<bool, 0>()) with their ADDED
// default-argument replacements (test<bool>()); only the latter set exists in
// the committed file, where `expected` defaults to make_expected_digits10<T>().
// Do not treat the duplicated calls below as intentional.
int main(int, char**)
{
test<bool, 0>();
test<char, 2>();
test<signed char, 2>();
test<unsigned char, 2>();
test<wchar_t, 5 * sizeof(wchar_t) / 2 - 1>(); // 4 -> 9 and 2 -> 4
test<bool>();
test<char>();
test<signed char>();
test<unsigned char>();
test<wchar_t>();
#if TEST_STD_VER > 2017 && defined(__cpp_char8_t)
test<char8_t, 2>();
test<char8_t>();
#endif
#ifndef _LIBCUDACXX_HAS_NO_UNICODE_CHARS
test<char16_t, 4>();
test<char32_t, 9>();
test<char16_t>();
test<char32_t>();
#endif // _LIBCUDACXX_HAS_NO_UNICODE_CHARS
test<short, 4>();
test<unsigned short, 4>();
test<int, 9>();
test<unsigned int, 9>();
test<long, sizeof(long) == 4 ? 9 : 18>();
test<unsigned long, sizeof(long) == 4 ? 9 : 19>();
test<long long, 18>();
test<unsigned long long, 19>();
test<short>();
test<unsigned short>();
test<int>();
test<unsigned int>();
test<long>();
test<unsigned long>();
test<long long>();
test<unsigned long long>();
#ifndef _LIBCUDACXX_HAS_NO_INT128
test<__int128_t, 38>();
test<__uint128_t, 38>();
test<__int128_t>();
test<__uint128_t>();
#endif
test<float, FLT_DIG>();
test<double, DBL_DIG>();
test<float>();
test<double>();
#ifndef _LIBCUDACXX_HAS_NO_LONG_DOUBLE
test<long double, LDBL_DIG>();
test<long double>();
#endif
#if defined(_LIBCUDACXX_HAS_NVFP16)
test<__half, 3>();
test<__half>();
#endif // _LIBCUDACXX_HAS_NVFP16
#if defined(_LIBCUDACXX_HAS_NVBF16)
test<__nv_bfloat16, 2>();
test<__nv_bfloat16>();
#endif // _LIBCUDACXX_HAS_NVBF16
// New in this commit: FP8 coverage, relying on the computed default expected
// value rather than a hard-coded one.
#if _CCCL_HAS_NVFP8()
test<__nv_fp8_e4m3>();
test<__nv_fp8_e5m2>();
#endif // _CCCL_HAS_NVFP8()

return 0;
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ int main(int, char**)
#if defined(_LIBCUDACXX_HAS_NVBF16)
test<__nv_bfloat16>(__double2bfloat16(0.0078125));
#endif // _LIBCUDACXX_HAS_NVBF16
#if _CCCL_HAS_NVFP8()
test<__nv_fp8_e4m3>(make_fp8_e4m3(0.125));
test<__nv_fp8_e5m2>(make_fp8_e5m2(0.25));
#endif // _CCCL_HAS_NVFP8()

return 0;
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ int main(int, char**)
#if defined(_LIBCUDACXX_HAS_NVBF16)
test<__nv_bfloat16, cuda::std::denorm_present>();
#endif // _LIBCUDACXX_HAS_NVBF16
#if _CCCL_HAS_NVFP8()
test<__nv_fp8_e4m3, cuda::std::denorm_present>();
test<__nv_fp8_e5m2, cuda::std::denorm_present>();
#endif // _CCCL_HAS_NVFP8()

return 0;
}
Loading

0 comments on commit 3a594e3

Please sign in to comment.