From 718e05b21845a1353b32006668ff8683e36b0365 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 24 Sep 2024 08:26:30 +0000 Subject: [PATCH] Add numeric_limits for MLFloat16 and BFloat16 --- .../OrtFloat16.shared.cs | 70 +++++---- include/onnxruntime/core/framework/float16.h | 142 ++++++++++++++++++ onnxruntime/test/framework/data_types_test.cc | 33 ++++ 3 files changed, 212 insertions(+), 33 deletions(-) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtFloat16.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtFloat16.shared.cs index 7c22e1b213b41..c71218ba62f7c 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtFloat16.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtFloat16.shared.cs @@ -60,9 +60,9 @@ internal static int LeadingZeroCount(uint num) /// /// Extracts single precision number bit representation as uint /// so its bits can be manipulated. - /// + /// /// This API is the reverse of UInt32BitsToSingle(). - /// + /// /// /// float value /// @@ -79,11 +79,11 @@ internal static uint SingleToUInt32Bits(float single) /// /// Needed because BitConverter impl is not available until /// later versions. This API is the reverse of SingleToUInt32Bits(). - /// + /// /// For the exact bit representation of float see IEEE 754 standard for single precision. - /// + /// /// - /// bit representation of float either obtained from + /// bit representation of float either obtained from /// SingleToUInt32Bits or assembled using bitwise operators /// internal static float UInt32BitsToSingle(uint singleBits) @@ -99,7 +99,7 @@ internal static float UInt32BitsToSingle(uint singleBits) /// /// Converts single precision bits representation which can be obtained using /// SingleToUInt32Bits() or manually constructed according to IEEE 754 standard. - /// + /// /// /// bits representation of a single precision number (float) /// @@ -177,8 +177,8 @@ internal static float CreateSingle(bool sign, byte exponent, uint significand) /// do not have to be copied to be passed to native memory but simply pinned and read by native code. Thus, /// one can create a Tensor on top of an array of these structures and feed it directly to Onnxruntime library. /// Binary wise, it is the same as ushort[] (uint16_t in C++). However, we would like a separate type for type dispatching. - /// - /// The implementation is derived from + /// + /// The implementation is derived from /// https://source.dot.net/#System.Private.CoreLib/src/libraries/System.Private.CoreLib/src/System/Half.cs,7895d5942d33f974 /// [StructLayout(LayoutKind.Sequential)] @@ -215,6 +215,7 @@ internal static float CreateSingle(bool sign, byte exponent, uint significand) private const ushort OneBits = 0x3C00; + // Minimum positive normalized value. It is corresponding to numeric_limits::min() in C++. private const ushort EpsilonBits = 0x0400; private const ushort PositiveInfinityBits = 0x7C00; @@ -238,7 +239,7 @@ internal static float CreateSingle(bool sign, byte exponent, uint significand) /// /// Float16 Epsilon value /// - public static Float16 Epsilon => new Float16(EpsilonBits); // 5.9604645E-08 + public static Float16 Epsilon => new Float16(EpsilonBits); // 0.00006103515625 /// /// Float16 Pi value @@ -248,17 +249,17 @@ internal static float CreateSingle(bool sign, byte exponent, uint significand) /// /// Float16 Positive Infinity value /// - public static Float16 PositiveInfinity => new Float16(PositiveInfinityBits); // 1.0 / 0.0; + public static Float16 PositiveInfinity => new Float16(PositiveInfinityBits); /// /// Float16 Negative Infinity value /// - public static Float16 NegativeInfinity => new Float16(NegativeInfinityBits); // -1.0 / 0.0 + public static Float16 NegativeInfinity => new Float16(NegativeInfinityBits); /// /// Float16 NaN /// - public static Float16 NaN => new Float16(NegativeQNaNBits); // 0.0 / 0.0 + public static Float16 NaN => new Float16(PositiveQNaNBits); // quiet NaN /// /// Float16 Zero value @@ -276,14 +277,14 @@ internal static float CreateSingle(bool sign, byte exponent, uint significand) public static Float16 NegativeZero => new Float16(NegativeZeroBits); // -0.0 /// - /// Float16 Min value + /// Float16 Lowest value /// - public static Float16 MinValue => new Float16(MinValueBits); // 64,511 + public static Float16 MinValue => new Float16(MinValueBits); // -65504.0 /// /// Float16 Max value /// - public static Float16 MaxValue => new Float16(MaxValueBits); // 31,743 + public static Float16 MaxValue => new Float16(MaxValueBits); // 65504.0 /// /// float16 representation bits @@ -348,7 +349,7 @@ internal static ushort ExtractTrailingSignificandFromBits(ushort bits) /// /// Compares values of two Float16 - /// + /// /// /// left hand side /// right hand side @@ -376,7 +377,7 @@ internal static ushort ExtractTrailingSignificandFromBits(ushort bits) /// /// Compares values of two Float16 - /// + /// /// /// left hand side /// right hand side @@ -388,7 +389,7 @@ internal static ushort ExtractTrailingSignificandFromBits(ushort bits) /// /// Compares values of two Float16 - /// + /// /// /// left hand side /// right hand side @@ -429,7 +430,7 @@ internal static ushort ExtractTrailingSignificandFromBits(ushort bits) /// /// Compares values of two Float16 for binary equality. /// If either of the values is NaN, this will return false. - /// + /// /// /// left hand side /// right hand side @@ -479,7 +480,7 @@ public static bool IsInfinity(Float16 value) /// /// Determines whether the specified value is NaN. /// - /// + /// /// Float16 instance /// true if the value is not a number public static bool IsNaN(Float16 value) @@ -500,7 +501,7 @@ public static bool IsNegative(Float16 value) /// /// Determines whether the specified value is negative infinity. /// - /// + /// /// Float16 instance /// true if the value is negative infinity public static bool IsNegativeInfinity(Float16 value) @@ -549,7 +550,7 @@ public static bool IsSubnormal(Float16 value) /// /// Compares this object to another object, returning an integer that indicates the relationship. /// - /// + /// /// Object to compare to /// A value less than zero if this is less than , /// zero if this is equal to , or a value greater than zero @@ -570,7 +571,7 @@ public int CompareTo(object obj) /// /// Object to compare to /// A value less than zero if this is less than , - /// zero if this is equal to , + /// zero if this is equal to , /// or a value greater than zero if this is greater than . public int CompareTo(Float16 other) { @@ -864,10 +865,13 @@ private static ushort RoundPackToFloat16(bool sign, short exp, ushort sig) private const ushort PositiveQNaNBits = 0x7FC1; private const ushort NegativeQNaNBits = 0xFFC1; + // Lowest finite value. It is corresponding to numeric_limits::lowest() in C++. private const ushort MinValueBits = 0xFF7F; // 1b0_11111110_1111111 + private const ushort MaxValueBits = 0x7F7F; // 0b0_11111110_1111111 - private const ushort EpsilonBits = 0x0080; // the smallest positive normal value + // Minimum positive normalized value. It is corresponding to numeric_limits::min() in C++. + private const ushort EpsilonBits = 0x0080; private const ushort PiBits = 0x4049; // 0b0_10000000_1001001 @@ -899,7 +903,7 @@ private static ushort RoundPackToFloat16(bool sign, short exp, ushort sig) /// /// BFloat16 NaN /// - public static BFloat16 NaN => new BFloat16(NegativeQNaNBits); + public static BFloat16 NaN => new BFloat16(PositiveQNaNBits); // quiet NaN /// /// BFloat16 Positive Zero @@ -919,13 +923,13 @@ private static ushort RoundPackToFloat16(bool sign, short exp, ushort sig) /// /// BFloat16 Min value /// - public static BFloat16 MinValue => new BFloat16(MinValueBits); // 65,407 + public static BFloat16 MinValue => new BFloat16(MinValueBits); // -3.38953139e38 /// /// BFloat16 Max value /// - public static BFloat16 MaxValue => new BFloat16(MaxValueBits); // 32,639 + public static BFloat16 MaxValue => new BFloat16(MaxValueBits); // 3.38953139e38 /// /// bfloat16 representation bits @@ -1051,7 +1055,7 @@ internal static ushort ExtractTrailingSignificandFromBits(ushort bits) /// /// Compares values of two BFloat16 for binary equality. /// If either of the values is NaN, this will return false. - /// + /// /// /// left hand side /// right hand side @@ -1102,7 +1106,7 @@ public static bool IsInfinity(BFloat16 value) /// /// Determines whether the specified value is NaN. /// - /// + /// /// BFloat16 instance /// true if the value is not a number public static bool IsNaN(BFloat16 value) @@ -1123,7 +1127,7 @@ public static bool IsNegative(BFloat16 value) /// /// Determines whether the specified value is negative infinity. /// - /// + /// /// BFloat16 instance /// true if the value is negative infinity public static bool IsNegativeInfinity(BFloat16 value) @@ -1170,7 +1174,7 @@ public static bool IsSubnormal(BFloat16 value) /// /// Compares this object to another object, returning an integer that indicates the relationship. /// - /// + /// /// Object to compare to /// A value less than zero if this is less than , /// zero if this is equal to , or a value greater than zero @@ -1191,7 +1195,7 @@ public int CompareTo(object obj) /// /// Object to compare to /// A value less than zero if this is less than , - /// zero if this is equal to , + /// zero if this is equal to , /// or a value greater than zero if this is greater than . public int CompareTo(BFloat16 other) { @@ -1368,4 +1372,4 @@ private static uint StripSign(BFloat16 value) #endregion } -} \ No newline at end of file +} diff --git a/include/onnxruntime/core/framework/float16.h b/include/onnxruntime/core/framework/float16.h index 1f2f175c6e691..64ffe3ccdfc46 100644 --- a/include/onnxruntime/core/framework/float16.h +++ b/include/onnxruntime/core/framework/float16.h @@ -295,3 +295,145 @@ inline void FloatToBFloat16(const float* flt, BFloat16* blf, size_t size) { } } // namespace onnxruntime + +namespace std { + +template <> +class numeric_limits { + public: + static constexpr onnxruntime::MLFloat16 min() noexcept { + return onnxruntime::MLFloat16::FromBits(0x0400U); // Minimum positive normalized value: 0.00006103515625 + } + + static constexpr onnxruntime::MLFloat16 max() noexcept { + return onnxruntime::MLFloat16::FromBits(0x7BFFU); // Largest representable value: 65504 + } + + static constexpr onnxruntime::MLFloat16 lowest() noexcept { + return onnxruntime::MLFloat16::FromBits(0xFBFFU); // Smallest representable value: -65504 + } + + static constexpr onnxruntime::MLFloat16 infinity() noexcept { + return onnxruntime::MLFloat16::FromBits(0x7C00U); // Bits: sign(0), exponent(111,11), fraction(00,0000,0000) + } + + static constexpr onnxruntime::MLFloat16 quiet_NaN() noexcept { + return onnxruntime::MLFloat16::FromBits(0x7E00U); // Bits: sign(0), exponent(111,11), fraction(10,0000,0000) + } + + static constexpr onnxruntime::MLFloat16 signaling_NaN() noexcept { + return onnxruntime::MLFloat16::FromBits(0x7D00U); // Bits: sign(0), exponent(111,11), fraction(01,0000,0000) + } + + static constexpr onnxruntime::MLFloat16 denorm_min() noexcept { + return onnxruntime::MLFloat16::FromBits(0x0001U); // Minimum subnormal value: 0.000000059604645 + } + + static constexpr onnxruntime::MLFloat16 epsilon() noexcept { + return onnxruntime::MLFloat16::FromBits(0x1400U); // Difference between 1.0 and the next value: 2^-10 = 0.0009765625 + } + + static constexpr onnxruntime::MLFloat16 round_error() noexcept { + return onnxruntime::MLFloat16::FromBits(0x3800U); // 0.5 + } + + static constexpr bool is_specialized = true; + + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = true; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = true; + static constexpr float_denorm_style has_denorm = denorm_present; + static constexpr bool has_denorm_loss = false; + + static constexpr bool is_bounded = true; + static constexpr bool is_iec559 = true; + static constexpr bool is_modulo = false; + + static constexpr int digits = 11; // Number of significant digits (mantissa) + static constexpr int digits10 = 3; // Decimal digits of precision + static constexpr int max_digits10 = 5; // Max decimal digits required for precision + static constexpr int radix = 2; + static constexpr int min_exponent = -13; + static constexpr int min_exponent10 = -4; + static constexpr int max_exponent = 16; + static constexpr int max_exponent10 = 4; + + static constexpr bool traps = false; + static constexpr bool tinyness_before = false; + static constexpr std::float_round_style round_style = std::round_to_nearest; +}; + +template <> +class numeric_limits { + public: + static constexpr onnxruntime::BFloat16 min() noexcept { + return onnxruntime::BFloat16::FromBits(0x0080U); // Minimum positive normalized value: 1.175494e-38 + } + + static constexpr onnxruntime::BFloat16 max() noexcept { + return onnxruntime::BFloat16::FromBits(0x7F7FU); // Largest representable value: 3.38953139e38 + } + + static constexpr onnxruntime::BFloat16 lowest() noexcept { + return onnxruntime::BFloat16::FromBits(0xFF7FU); // Smallest representable value: -3.38953139e38 + } + + static constexpr onnxruntime::BFloat16 infinity() noexcept { + return onnxruntime::BFloat16::FromBits(0x7F80U); // Bits: sign(0), exponent(111,1111,1), fraction(000,0000) + } + + static constexpr onnxruntime::BFloat16 quiet_NaN() noexcept { + // The most significant fraction bit shall be 1, and no limitation on other fraction bits. + // Note that Torch, Tensorflow, OpenVino, nGraph uses 0x7FC0; Paddle uses 0x7FC1; CUDA uses 0x7FFF. + return onnxruntime::BFloat16::FromBits(0x7FC1U); // Bits: sign(0), exponent(111,1111,1), fraction(100,0001) + } + + static constexpr onnxruntime::BFloat16 signaling_NaN() noexcept { + // The most significant fraction bit shall be 0, and there is at least one 1 in other fraction bits. + return onnxruntime::BFloat16::FromBits(0x7F81U); // Bits: sign(0), exponent(111,1111,1), fraction(000,0001) + } + + static constexpr onnxruntime::BFloat16 denorm_min() noexcept { + return onnxruntime::BFloat16::FromBits(0x0001U); // Minimum subnormal value: 9.1835e-41 + } + + static constexpr onnxruntime::BFloat16 epsilon() noexcept { + return onnxruntime::BFloat16::FromBits(0x3C00U); // Difference between 1.0 and the next value: 2^-7 = 0.0078125 + } + + static constexpr onnxruntime::BFloat16 round_error() noexcept { + return onnxruntime::BFloat16::FromBits(0x3F00U); // 0.5 + } + + static constexpr bool is_specialized = true; + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = true; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = true; + static constexpr float_denorm_style has_denorm = denorm_present; + static constexpr bool has_denorm_loss = false; + + static constexpr bool is_bounded = true; + static constexpr bool is_iec559 = false; + static constexpr bool is_modulo = false; + + static constexpr int digits = 8; + static constexpr int digits10 = 2; + static constexpr int max_digits10 = 4; + static constexpr int radix = 2; + static constexpr int min_exponent = -125; + static constexpr int min_exponent10 = -37; + static constexpr int max_exponent = 128; + static constexpr int max_exponent10 = 38; + + static constexpr bool traps = false; + static constexpr bool tinyness_before = false; + static constexpr float_round_style round_style = round_to_nearest; +}; + +} // namespace std diff --git a/onnxruntime/test/framework/data_types_test.cc b/onnxruntime/test/framework/data_types_test.cc index 871b255831029..6947d2a3995ba 100644 --- a/onnxruntime/test/framework/data_types_test.cc +++ b/onnxruntime/test/framework/data_types_test.cc @@ -494,6 +494,18 @@ TEST_F(DataTypeTest, MLFloat16Comparision) { } TEST_F(DataTypeTest, MLFloat16TestNAN) { + const MLFloat16 qNAN = std::numeric_limits::quiet_NaN(); + EXPECT_TRUE(qNAN.IsNaN()); + EXPECT_TRUE(qNAN.IsNaNOrZero()); + EXPECT_NE(MLFloat16::NaN, qNAN); // NaN are not equal to each other + EXPECT_TRUE(std::isnan(qNAN.ToFloat())); + + const MLFloat16 sNAN = std::numeric_limits::signaling_NaN(); + EXPECT_TRUE(sNAN.IsNaN()); + EXPECT_TRUE(sNAN.IsNaNOrZero()); + EXPECT_NE(MLFloat16::NaN, sNAN); // NaN are not equal to each other + EXPECT_TRUE(std::isnan(sNAN.ToFloat())); + const MLFloat16 fp16NANFromSingle(std::numeric_limits::quiet_NaN()); EXPECT_TRUE(fp16NANFromSingle.IsNaN()); EXPECT_TRUE(fp16NANFromSingle.IsNaNOrZero()); @@ -520,6 +532,11 @@ TEST_F(DataTypeTest, MLFloat16NaNComparision) { } TEST_F(DataTypeTest, MLFloat16Infinity) { + const MLFloat16 fp16_infinity(std::numeric_limits::infinity()); + EXPECT_TRUE(fp16_infinity.IsInfinity()); + EXPECT_FALSE(fp16_infinity.IsFinite()); + EXPECT_FALSE(fp16_infinity.IsNegative()); + EXPECT_FALSE(MLFloat16::MaxValue.Negate().IsInfinity()); EXPECT_FALSE(MLFloat16::MaxValue.IsInfinity()); EXPECT_TRUE(MLFloat16::MaxValue.IsFinite()); @@ -550,6 +567,8 @@ TEST_F(DataTypeTest, MLFloat16NormalSubnormal) { EXPECT_TRUE(smallest_subnormal.IsSubnormal()); EXPECT_FALSE(smallest_subnormal.IsNormal()); + EXPECT_EQ(smallest_subnormal, std::numeric_limits::denorm_min()); + // float smallest positive subnormal is ~1.40129846432481707092E-45, and // in float the same number above would be normal const float float_from_smallest_subnormal = static_cast(smallest_subnormal); @@ -639,6 +658,18 @@ TEST_F(DataTypeTest, BFloat16Comparision) { } TEST_F(DataTypeTest, BFloat16TestNAN) { + const BFloat16 qNAN = std::numeric_limits::quiet_NaN(); + EXPECT_TRUE(qNAN.IsNaN()); + EXPECT_TRUE(qNAN.IsNaNOrZero()); + EXPECT_NE(BFloat16::NaN, qNAN); + EXPECT_TRUE(std::isnan(qNAN.ToFloat())); + + const BFloat16 sNAN = std::numeric_limits::signaling_NaN(); + EXPECT_TRUE(sNAN.IsNaN()); + EXPECT_TRUE(sNAN.IsNaNOrZero()); + EXPECT_NE(BFloat16::NaN, sNAN); + EXPECT_TRUE(std::isnan(sNAN.ToFloat())); + const BFloat16 fp16NANFromSingle = std::numeric_limits::quiet_NaN(); EXPECT_TRUE(fp16NANFromSingle.IsNaN()); EXPECT_TRUE(fp16NANFromSingle.IsNaNOrZero()); @@ -695,6 +726,8 @@ TEST_F(DataTypeTest, BFloat16NormalSubnormal) { EXPECT_TRUE(smallest_subnormal.IsSubnormal()); EXPECT_FALSE(smallest_subnormal.IsNormal()); + EXPECT_EQ(smallest_subnormal, std::numeric_limits::denorm_min()); + const float float_from_smallest_subnormal = (float)smallest_subnormal; EXPECT_FALSE(std::isnormal(float_from_smallest_subnormal));