Skip to content

Commit

Permalink
Add numeric_limits for MLFloat16 and BFloat16
Browse files Browse the repository at this point in the history
  • Loading branch information
tianleiwu committed Sep 24, 2024
1 parent b636b27 commit 718e05b
Show file tree
Hide file tree
Showing 3 changed files with 212 additions and 33 deletions.
70 changes: 37 additions & 33 deletions csharp/src/Microsoft.ML.OnnxRuntime/OrtFloat16.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ internal static int LeadingZeroCount(uint num)
/// <summary>
/// Extracts single precision number bit representation as uint
/// so its bits can be manipulated.
///
///
/// This API is the reverse of UInt32BitsToSingle().
///
///
/// </summary>
/// <param name="single">float value</param>
/// <returns></returns>
Expand All @@ -79,11 +79,11 @@ internal static uint SingleToUInt32Bits(float single)
/// <summary>
/// Needed because BitConverter impl is not available until
/// later versions. This API is the reverse of SingleToUInt32Bits().
///
///
/// For the exact bit representation of float see IEEE 754 standard for single precision.
///
///
/// </summary>
/// <param name="singleBits">bit representation of float either obtained from
/// <param name="singleBits">bit representation of float either obtained from
/// SingleToUInt32Bits or assembled using bitwise operators</param>
/// <returns></returns>
internal static float UInt32BitsToSingle(uint singleBits)
Expand All @@ -99,7 +99,7 @@ internal static float UInt32BitsToSingle(uint singleBits)
/// <summary>
/// Converts single precision bits representation which can be obtained using
/// SingleToUInt32Bits() or manually constructed according to IEEE 754 standard.
///
///
/// </summary>
/// <param name="singleBits">bits representation of a single precision number (float)</param>
/// <returns></returns>
Expand Down Expand Up @@ -177,8 +177,8 @@ internal static float CreateSingle(bool sign, byte exponent, uint significand)
/// do not have to be copied to be passed to native memory but simply pinned and read by native code. Thus,
/// one can create a Tensor on top of an array of these structures and feed it directly to Onnxruntime library.
/// Binary wise, it is the same as ushort[] (uint16_t in C++). However, we would like a separate type for type dispatching.
///
/// The implementation is derived from
///
/// The implementation is derived from
/// https://source.dot.net/#System.Private.CoreLib/src/libraries/System.Private.CoreLib/src/System/Half.cs,7895d5942d33f974
/// </summary>
[StructLayout(LayoutKind.Sequential)]
Expand Down Expand Up @@ -215,6 +215,7 @@ internal static float CreateSingle(bool sign, byte exponent, uint significand)

private const ushort OneBits = 0x3C00;

// Minimum positive normalized value. It corresponds to numeric_limits<float16>::min() in C++.
private const ushort EpsilonBits = 0x0400;

private const ushort PositiveInfinityBits = 0x7C00;
Expand All @@ -238,7 +239,7 @@ internal static float CreateSingle(bool sign, byte exponent, uint significand)
/// <summary>
/// Float16 Epsilon value
/// </summary>
public static Float16 Epsilon => new Float16(EpsilonBits); // 5.9604645E-08
public static Float16 Epsilon => new Float16(EpsilonBits); // 0.00006103515625

/// <summary>
/// Float16 Pi value
Expand All @@ -248,17 +249,17 @@ internal static float CreateSingle(bool sign, byte exponent, uint significand)
/// <summary>
/// Float16 Positive Infinity value
/// </summary>
public static Float16 PositiveInfinity => new Float16(PositiveInfinityBits); // 1.0 / 0.0;
public static Float16 PositiveInfinity => new Float16(PositiveInfinityBits);

/// <summary>
/// Float16 Negative Infinity value
/// </summary>
public static Float16 NegativeInfinity => new Float16(NegativeInfinityBits); // -1.0 / 0.0
public static Float16 NegativeInfinity => new Float16(NegativeInfinityBits);

/// <summary>
/// Float16 NaN
/// </summary>
public static Float16 NaN => new Float16(NegativeQNaNBits); // 0.0 / 0.0
public static Float16 NaN => new Float16(PositiveQNaNBits); // quiet NaN

/// <summary>
/// Float16 Zero value
Expand All @@ -276,14 +277,14 @@ internal static float CreateSingle(bool sign, byte exponent, uint significand)
public static Float16 NegativeZero => new Float16(NegativeZeroBits); // -0.0

/// <summary>
/// Float16 Min value
/// Float16 Lowest value
/// </summary>
public static Float16 MinValue => new Float16(MinValueBits); // 64,511
public static Float16 MinValue => new Float16(MinValueBits); // -65504.0

/// <summary>
/// Float16 Max value
/// </summary>
public static Float16 MaxValue => new Float16(MaxValueBits); // 31,743
public static Float16 MaxValue => new Float16(MaxValueBits); // 65504.0

/// <summary>
/// float16 representation bits
Expand Down Expand Up @@ -348,7 +349,7 @@ internal static ushort ExtractTrailingSignificandFromBits(ushort bits)

/// <summary>
/// Compares values of two Float16
///
///
/// </summary>
/// <param name="left">left hand side</param>
/// <param name="right">right hand side</param>
Expand Down Expand Up @@ -376,7 +377,7 @@ internal static ushort ExtractTrailingSignificandFromBits(ushort bits)

/// <summary>
/// Compares values of two Float16
///
///
/// </summary>
/// <param name="left">left hand side</param>
/// <param name="right">right hand side</param>
Expand All @@ -388,7 +389,7 @@ internal static ushort ExtractTrailingSignificandFromBits(ushort bits)

/// <summary>
/// Compares values of two Float16
///
///
/// </summary>
/// <param name="left">left hand side</param>
/// <param name="right">right hand side</param>
Expand Down Expand Up @@ -429,7 +430,7 @@ internal static ushort ExtractTrailingSignificandFromBits(ushort bits)
/// <summary>
/// Compares values of two Float16 for binary equality.
/// If either of the values is NaN, this will return false.
///
///
/// </summary>
/// <param name="left">left hand side</param>
/// <param name="right">right hand side</param>
Expand Down Expand Up @@ -479,7 +480,7 @@ public static bool IsInfinity(Float16 value)
/// <summary>
/// Determines whether the specified value is NaN.
/// </summary>
///
///
/// <param name="value">Float16 instance</param>
/// <returns>true if the value is not a number</returns>
public static bool IsNaN(Float16 value)
Expand All @@ -500,7 +501,7 @@ public static bool IsNegative(Float16 value)
/// <summary>
/// Determines whether the specified value is negative infinity.
/// </summary>
///
///
/// <param name="value">Float16 instance</param>
/// <returns>true if the value is negative infinity</returns>
public static bool IsNegativeInfinity(Float16 value)
Expand Down Expand Up @@ -549,7 +550,7 @@ public static bool IsSubnormal(Float16 value)
/// <summary>
/// Compares this object to another object, returning an integer that indicates the relationship.
/// </summary>
///
///
/// <param name="obj">Object to compare to</param>
/// <returns>A value less than zero if this is less than <paramref name="obj"/>,
/// zero if this is equal to <paramref name="obj"/>, or a value greater than zero
Expand All @@ -570,7 +571,7 @@ public int CompareTo(object obj)
/// </summary>
/// <param name="other">Object to compare to</param>
/// <returns>A value less than zero if this is less than <paramref name="other"/>,
/// zero if this is equal to <paramref name="other"/>,
/// zero if this is equal to <paramref name="other"/>,
/// or a value greater than zero if this is greater than <paramref name="other"/>.</returns>
public int CompareTo(Float16 other)
{
Expand Down Expand Up @@ -864,10 +865,13 @@ private static ushort RoundPackToFloat16(bool sign, short exp, ushort sig)
private const ushort PositiveQNaNBits = 0x7FC1;
private const ushort NegativeQNaNBits = 0xFFC1;

// Lowest finite value. It corresponds to numeric_limits<BFloat16>::lowest() in C++.
private const ushort MinValueBits = 0xFF7F; // 0b1_11111110_1111111

private const ushort MaxValueBits = 0x7F7F; // 0b0_11111110_1111111

private const ushort EpsilonBits = 0x0080; // the smallest positive normal value
// Minimum positive normalized value. It corresponds to numeric_limits<BFloat16>::min() in C++.
private const ushort EpsilonBits = 0x0080;

private const ushort PiBits = 0x4049; // 0b0_10000000_1001001

Expand Down Expand Up @@ -899,7 +903,7 @@ private static ushort RoundPackToFloat16(bool sign, short exp, ushort sig)
/// <summary>
/// BFloat16 NaN
/// </summary>
public static BFloat16 NaN => new BFloat16(NegativeQNaNBits);
public static BFloat16 NaN => new BFloat16(PositiveQNaNBits); // quiet NaN

/// <summary>
/// BFloat16 Positive Zero
Expand All @@ -919,13 +923,13 @@ private static ushort RoundPackToFloat16(bool sign, short exp, ushort sig)
/// <summary>
/// BFloat16 Lowest value (the most negative finite value)
/// </summary>
public static BFloat16 MinValue => new BFloat16(MinValueBits); // 65,407
public static BFloat16 MinValue => new BFloat16(MinValueBits); // -3.38953139e38

/// <summary>
/// BFloat16 Max value
/// </summary>

public static BFloat16 MaxValue => new BFloat16(MaxValueBits); // 32,639
public static BFloat16 MaxValue => new BFloat16(MaxValueBits); // 3.38953139e38

/// <summary>
/// bfloat16 representation bits
Expand Down Expand Up @@ -1051,7 +1055,7 @@ internal static ushort ExtractTrailingSignificandFromBits(ushort bits)
/// <summary>
/// Compares values of two BFloat16 for binary equality.
/// If either of the values is NaN, this will return false.
///
///
/// </summary>
/// <param name="left">left hand side</param>
/// <param name="right">right hand side</param>
Expand Down Expand Up @@ -1102,7 +1106,7 @@ public static bool IsInfinity(BFloat16 value)
/// <summary>
/// Determines whether the specified value is NaN.
/// </summary>
///
///
/// <param name="value">BFloat16 instance</param>
/// <returns>true if the value is not a number</returns>
public static bool IsNaN(BFloat16 value)
Expand All @@ -1123,7 +1127,7 @@ public static bool IsNegative(BFloat16 value)
/// <summary>
/// Determines whether the specified value is negative infinity.
/// </summary>
///
///
/// <param name="value">BFloat16 instance</param>
/// <returns>true if the value is negative infinity</returns>
public static bool IsNegativeInfinity(BFloat16 value)
Expand Down Expand Up @@ -1170,7 +1174,7 @@ public static bool IsSubnormal(BFloat16 value)
/// <summary>
/// Compares this object to another object, returning an integer that indicates the relationship.
/// </summary>
///
///
/// <param name="obj">Object to compare to</param>
/// <returns>A value less than zero if this is less than <paramref name="obj"/>,
/// zero if this is equal to <paramref name="obj"/>, or a value greater than zero
Expand All @@ -1191,7 +1195,7 @@ public int CompareTo(object obj)
/// </summary>
/// <param name="other">Object to compare to</param>
/// <returns>A value less than zero if this is less than <paramref name="other"/>,
/// zero if this is equal to <paramref name="other"/>,
/// zero if this is equal to <paramref name="other"/>,
/// or a value greater than zero if this is greater than <paramref name="other"/>.</returns>
public int CompareTo(BFloat16 other)
{
Expand Down Expand Up @@ -1368,4 +1372,4 @@ private static uint StripSign(BFloat16 value)

#endregion
}
}
}
Loading

0 comments on commit 718e05b

Please sign in to comment.