Skip to content

Commit

Permalink
[rel-1.16.0] Cherry-pick 16940 and 17523 (#17506)
Browse files Browse the repository at this point in the history
  • Loading branch information
centwang authored Sep 14, 2023
1 parent 0772d54 commit 06ea28b
Show file tree
Hide file tree
Showing 6 changed files with 229 additions and 97 deletions.
3 changes: 1 addition & 2 deletions cmake/external/onnxruntime_external_deps.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ if (onnxruntime_BUILD_UNIT_TESTS)
FetchContent_Declare(
googletest
URL ${DEP_URL_googletest}
FIND_PACKAGE_ARGS 1.13.0...<2.0.0 NAMES GTest
URL_HASH SHA1=${DEP_SHA1_googletest}
OVERRIDE_FIND_PACKAGE
)
endif()

Expand Down Expand Up @@ -528,4 +528,3 @@ endif()

# Store the platform-native spelling of the build directory in ORT_BINARY_DIR
# and of the project source directory in ORT_SOURCE_DIR.
FILE(TO_NATIVE_PATH ${CMAKE_BINARY_DIR} ORT_BINARY_DIR)
FILE(TO_NATIVE_PATH ${PROJECT_SOURCE_DIR} ORT_SOURCE_DIR)

199 changes: 111 additions & 88 deletions include/onnxruntime/core/framework/float8.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,35 +44,37 @@ struct Float8E4M3FN {
std::memcpy(&b, &v, sizeof(b));

val = static_cast<uint8_t>((b & 0x80000000) >> 24); // sign
if ((b & 0x7fc00000) == 0x7fc00000) {
val |= 0x7f;
} else if ((b & 0x7fffffff) == 0x7f800000) {
if ((b & 0x7fffffff) == 0x7f800000) { // infinity
if (saturate) {
val |= 126;
} else {
val |= 0x7f;
}
} else if ((b & 0x7F800000) == 0x7F800000) { // NaN
val |= 0x7f;
} else {
uint8_t e = static_cast<uint8_t>((b & 0x7F800000) >> 23); // exponent
uint32_t m = static_cast<uint32_t>(b & 0x007FFFFF); // mantissa
if (e != 0) {
if (e < 117) { // 0b1110101
} else if (e < 118) { // 0b1110110
val |= 1;
if ((m >> 23) & 1) {
// rounding
val += 1;
if (e < 117) {
} else if (e < 121) {
// denormalized number
auto d = 120 - e;
if (d < 3) {
val |= 1 << (2 - d);
val |= m >> (21 + d);
} else if (m > 0) {
val |= 1;
}
} else if (e < 121) { // 127 - 7 + 1 // 0b1111001
auto d = 120 - e; // 0b1111000
val |= 1 << (2 - d);
val |= m >> (21 + d);
if ((m >> (20 + d)) & 1) {
auto mask = 1 << (20 + d);
if ((m & mask) &&
((val & 1) || ((m & (mask - 1)) > 0) || ((m & mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) {
// rounding
val += 1;
}
} else if (e < 136) { // 127 + 8 + 1 // 0b10001000
auto ex = e - 120; // 127 - 7
} else if (e < 136) {
// normalized number
auto ex = e - 120;
if (ex == 0) {
val |= 0x4;
val |= m >> 21;
Expand All @@ -83,7 +85,7 @@ struct Float8E4M3FN {
val &= 0xFE;
}
}
if ((m & 0x80000) && ((m & 0x100000) || (m & 0x7C000))) {
if ((m & 0x80000) && ((m & 0x100000) || (m & 0x7FFFF))) {
if ((val & 0x7F) < 0x7E) {
// rounding
val += 1;
Expand Down Expand Up @@ -147,14 +149,22 @@ struct Float8E4M3FN {
// Implicit (non-explicit) conversion to float; simply forwards to ToFloat().
inline ORT_HOST_DEVICE operator float() const { return ToFloat(); }

#if defined(CUDA_VERSION) && CUDA_VERSION >= 11080
explicit ORT_HOST_DEVICE Float8E4M3FN(const __nv_fp8_e4m3& value) { val = *reinterpret_cast<const unsigned char*>(&value); }
// Builds a Float8E4M3FN from CUDA's __nv_fp8_e4m3 by reading the first byte of
// `value` and storing it in `val` unchanged (no numeric conversion is performed).
// NOTE(review): assumes __nv_fp8_e4m3 keeps its 8-bit encoding in its first byte - confirm.
explicit ORT_HOST_DEVICE Float8E4M3FN(const __nv_fp8_e4m3& value) {
val = *reinterpret_cast<const unsigned char*>(&value);
}
// Explicit conversion to CUDA's __nv_fp8_e4m3: reinterprets the storage of `val`
// as that type and returns a copy (no numeric conversion is performed).
explicit ORT_HOST_DEVICE operator __nv_fp8_e4m3() const { return *reinterpret_cast<const __nv_fp8_e4m3*>(&val); }
#endif
};

inline ORT_HOST_DEVICE bool operator==(const Float8E4M3FN& left, const Float8E4M3FN& right) { return left.val == right.val; }
inline ORT_HOST_DEVICE bool operator!=(const Float8E4M3FN& left, const Float8E4M3FN& right) { return left.val != right.val; }
inline ORT_HOST_DEVICE bool operator<(const Float8E4M3FN& left, const Float8E4M3FN& right) { return left.val < right.val; }
// Equality on the stored encoding: true when both `val` members are identical.
// The comparison is on the encoded byte, not on the decoded float value.
inline ORT_HOST_DEVICE bool operator==(const Float8E4M3FN& left, const Float8E4M3FN& right) {
return left.val == right.val;
}
// Inequality on the stored encoding: true when the `val` members differ.
inline ORT_HOST_DEVICE bool operator!=(const Float8E4M3FN& left, const Float8E4M3FN& right) {
return left.val != right.val;
}
// Orders two values by their stored `val` members.
// NOTE(review): the float constructor places the sign in bit 0x80 of `val`, so if `val`
// is an unsigned byte (presumably uint8_t - confirm) sign-bit-set encodings sort after
// all others. This is a bit-pattern order, not a numeric one.
inline ORT_HOST_DEVICE bool operator<(const Float8E4M3FN& left, const Float8E4M3FN& right) {
return left.val < right.val;
}

// User defined suffixes to make it easier to declare
// initializers with MLFloat8E4M3FN and Float8E4M3FN from unsigned char
Expand All @@ -164,9 +174,7 @@ inline Float8E4M3FN operator"" _f8e4m3fn(unsigned long long int v) {
return Float8E4M3FN(narrow<uint8_t>(v), Float8E4M3FN::FromBits());
}

inline Float8E4M3FN operator"" _f8e4m3fnp8(long double v) {
return Float8E4M3FN(static_cast<float>(v), true);
}
// User-defined literal suffix `_f8e4m3fnp8` for floating-point literals: narrows the
// long double to float and constructs a Float8E4M3FN from it, passing `true` as the
// second constructor argument (presumably the `saturate` flag - confirm against the constructor).
inline Float8E4M3FN operator"" _f8e4m3fnp8(long double v) { return Float8E4M3FN(static_cast<float>(v), true); }

#endif

Expand Down Expand Up @@ -205,44 +213,46 @@ struct Float8E4M3FNUZ {
std::memcpy(&b, &v, sizeof(b));

val = static_cast<uint8_t>((b & 0x80000000) >> 24); // sign
if ((b & 0x7fc00000) == 0x7fc00000) {
val = 0x80;
} else if ((b & 0x7fffffff) == 0x7f800000) {
if ((b & 0x7fffffff) == 0x7f800000) { // infinity
if (saturate) {
val |= 0x7F;
} else {
// infinity
val = 0x80;
}
} else if ((b & 0x7F800000) == 0x7F800000) { // NaN
val = 0x80;
} else {
uint8_t e = static_cast<uint8_t>((b & 0x7F800000) >> 23); // exponent
uint32_t m = static_cast<uint32_t>(b & 0x007FFFFF); // mantissa
if (e != 0) {
if (e < 116) {
} else if (e < 117) {
val |= 1;
if ((m >> 23) & 1) {
// rounding
val += 1;
}
} else if (e < 120) { // 127 - 8 + 1
} else if (e < 120) {
// denormalized number
auto d = 119 - e;
val |= 1 << (2 - d);
val |= m >> (21 + d);
if ((m >> (20 + d)) & 1) {
if (d < 3) {
val |= 1 << (2 - d);
val |= m >> (21 + d);
} else if (m > 0) {
val |= 1;
}
auto mask = 1 << (20 + d);
if ((m & mask) &&
((val & 1) || ((m & (mask - 1)) > 0) || ((m & mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) {
// rounding
val += 1;
}
} else if (e < 135) { // 127 + 8
auto ex = e - 119; // 127 - 7
} else if (e < 135) {
// normalized number
auto ex = e - 119;
if (ex == 0) {
val |= 0x4;
val |= m >> 21;
} else {
val |= ex << 3;
val |= m >> 20;
}
if (m & 0x80000) {
if ((m & 0x80000) && ((m & 0x100000) || (m & 0x7FFFF))) {
if ((val & 0x7F) < 0x7F) {
// rounding
val += 1;
Expand Down Expand Up @@ -303,9 +313,15 @@ struct Float8E4M3FNUZ {
// Implicit (non-explicit) conversion to float; simply forwards to ToFloat().
inline ORT_HOST_DEVICE operator float() const { return ToFloat(); }
};

inline ORT_HOST_DEVICE bool operator==(const Float8E4M3FNUZ& left, const Float8E4M3FNUZ& right) { return left.val == right.val; }
inline ORT_HOST_DEVICE bool operator!=(const Float8E4M3FNUZ& left, const Float8E4M3FNUZ& right) { return left.val != right.val; }
inline ORT_HOST_DEVICE bool operator<(const Float8E4M3FNUZ& left, const Float8E4M3FNUZ& right) { return left.val < right.val; }
// Equality on the stored encoding: true when both `val` members are identical.
// The comparison is on the encoded byte, not on the decoded float value.
// NOTE(review): the visible float-conversion code encodes NaN as 0x80, so two values
// built from NaN compare equal here, unlike IEEE float comparison - confirm this is intended.
inline ORT_HOST_DEVICE bool operator==(const Float8E4M3FNUZ& left, const Float8E4M3FNUZ& right) {
return left.val == right.val;
}
// Inequality on the stored encoding: true when the `val` members differ.
inline ORT_HOST_DEVICE bool operator!=(const Float8E4M3FNUZ& left, const Float8E4M3FNUZ& right) {
return left.val != right.val;
}
// Orders two values by their stored `val` members.
// NOTE(review): the float constructor places the sign in bit 0x80 of `val`, so if `val`
// is an unsigned byte (presumably uint8_t - confirm) sign-bit-set encodings sort after
// all others. This is a bit-pattern order, not a numeric one.
inline ORT_HOST_DEVICE bool operator<(const Float8E4M3FNUZ& left, const Float8E4M3FNUZ& right) {
return left.val < right.val;
}

// User-defined literal suffixes that make it easier to declare
// Float8E4M3FNUZ initializers from unsigned char
Expand All @@ -315,9 +331,7 @@ inline Float8E4M3FNUZ operator"" _f8e4m3p8fnuz(unsigned long long int v) {
return Float8E4M3FNUZ(narrow<uint8_t>(v), Float8E4M3FNUZ::FromBits());
}

inline Float8E4M3FNUZ operator"" _f8e4m3fnuzp8(long double v) {
return Float8E4M3FNUZ(static_cast<float>(v), true);
}
// User-defined literal suffix `_f8e4m3fnuzp8` for floating-point literals: narrows the
// long double to float and constructs a Float8E4M3FNUZ from it, passing `true` as the
// second constructor argument (presumably the `saturate` flag - confirm against the constructor).
inline Float8E4M3FNUZ operator"" _f8e4m3fnuzp8(long double v) { return Float8E4M3FNUZ(static_cast<float>(v), true); }

#endif

Expand Down Expand Up @@ -357,32 +371,33 @@ struct Float8E5M2 {
uint32_t b;
std::memcpy(&b, &v, sizeof(b));

val = (b & 0x80000000) >> 24; // sign
if ((b & 0x7fc00000) == 0x7fc00000) {
val |= 0x7f;
} else if ((b & 0x7FFFFFFF) == 0x7F800000) { // inf
val = (b & 0x80000000) >> 24; // sign
if ((b & 0x7FFFFFFF) == 0x7F800000) { // inf
if (saturate) {
val |= 0x7B;
} else {
val |= 0x7C;
}
} else if ((b & 0x7F800000) == 0x7F800000) { // NaN
val |= 0x7f;
} else {
uint32_t e = (b & 0x7F800000) >> 23; // exponent
uint32_t m = b & 0x007FFFFF; // mantissa

if (e != 0) {
if (e < 110) {
} else if (e < 111) {
val |= 1;
if ((m >> 23) & 1) {
// rounding
val += 1;
}
} else if (e < 113) { // 127 - 15 + 1
} else if (e < 113) {
// denormalized number
auto d = 112 - e;
val |= 1 << (1 - d);
val |= m >> (22 + d);
if ((m >> (21 + d)) & 1) {
if (d < 2) {
val |= 1 << (1 - d);
val |= m >> (22 + d);
} else if (m > 0) {
val |= 1;
}
auto mask = 1 << (21 + d);
if ((m & mask) &&
((val & 1) || ((m & (mask - 1)) > 0) || ((m & mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) {
// rounding
val += 1;
}
Expand Down Expand Up @@ -461,8 +476,12 @@ struct Float8E5M2 {
#endif
};

inline ORT_HOST_DEVICE bool operator==(const Float8E5M2& left, const Float8E5M2& right) { return left.val == right.val; }
inline ORT_HOST_DEVICE bool operator!=(const Float8E5M2& left, const Float8E5M2& right) { return left.val != right.val; }
// Equality on the stored encoding: true when both `val` members are identical.
// The comparison is on the encoded byte, not on the decoded float value.
inline ORT_HOST_DEVICE bool operator==(const Float8E5M2& left, const Float8E5M2& right) {
return left.val == right.val;
}
// Inequality on the stored encoding: true when the `val` members differ.
inline ORT_HOST_DEVICE bool operator!=(const Float8E5M2& left, const Float8E5M2& right) {
return left.val != right.val;
}
// Orders two values by their stored `val` members.
// NOTE(review): the float constructor places the sign in bit 0x80 of `val`, so if `val`
// is an unsigned byte (presumably uint8_t - confirm) sign-bit-set encodings sort after
// all others. This is a bit-pattern order, not a numeric one.
inline ORT_HOST_DEVICE bool operator<(const Float8E5M2& left, const Float8E5M2& right) { return left.val < right.val; }

// User defined suffixes to make it easier to declare
Expand All @@ -473,9 +492,7 @@ inline Float8E5M2 operator"" _f8e5m2fn(unsigned long long int v) {
return Float8E5M2(narrow<uint8_t>(v), Float8E5M2::FromBits());
}

inline Float8E5M2 operator"" _f8e5m2fnp8(long double v) {
return Float8E5M2(static_cast<float>(v), true);
}
// User-defined literal suffix `_f8e5m2fnp8` for floating-point literals: narrows the
// long double to float and constructs a Float8E5M2 from it, passing `true` as the
// second constructor argument (presumably the `saturate` flag - confirm against the constructor).
inline Float8E5M2 operator"" _f8e5m2fnp8(long double v) { return Float8E5M2(static_cast<float>(v), true); }

#endif

Expand Down Expand Up @@ -513,48 +530,50 @@ struct Float8E5M2FNUZ {
uint32_t b;
std::memcpy(&b, &v, sizeof(b));

val = (b & 0x80000000) >> 24; // sign
if ((b & 0x7fc00000) == 0x7fc00000) {
val = 0x80;
} else if ((b & 0x7FFFFFFF) == 0x7F800000) { // inf
val = (b & 0x80000000) >> 24; // sign
if ((b & 0x7FFFFFFF) == 0x7F800000) { // inf
if (saturate) {
val |= 0x7F;
} else {
val = 0x80;
}
} else if ((b & 0x7F800000) == 0x7F800000) { // NaN
val = 0x80;
} else {
uint32_t e = (b & 0x7F800000) >> 23; // exponent
uint32_t m = b & 0x007FFFFF; // mantissa

if (e != 0) {
if (e < 109) {
} else if (e < 110) {
val |= 1;
if ((m >> 23) & 1) {
// rounding
val += 1;
}
} else if (e < 112) { // 127 - 16 + 1
} else if (e < 112) {
// denormalized number
auto d = 111 - e;
val |= 1 << (1 - d);
val |= m >> (22 + d);
if ((m >> (21 + d)) & 1) {
if (d < 2) {
val |= 1 << (1 - d);
val |= m >> (22 + d);
} else if (m > 0) {
val |= 1;
}
auto mask = 1 << (21 + d);
if ((m & mask) &&
((val & 1) || ((m & (mask - 1)) > 0) || ((m & mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) {
// rounding
val += 1;
}
} else if (e < 143) { // 127 + 15 + 1
} else if (e < 143) {
// normalized number
auto ex = e - 111;
val |= ex << 2;
val |= m >> 21;
if (m & 0x100000) {
if ((m & 0x100000) && ((m & 0xFFFFF) || (m & 0x200000))) {
if ((val & 0x7F) < 0x7F) {
// rounding
val += 1;
} else if (!saturate) {
val = 0x80;
}
}
} else if ((e == 255) && (m == 0)) { // inf
} else if ((e == 255) && (m == 0)) {
val = 0x80;
} else if (saturate) {
val |= 0x7F;
Expand Down Expand Up @@ -605,9 +624,15 @@ struct Float8E5M2FNUZ {
// Implicit (non-explicit) conversion to float; simply forwards to ToFloat().
inline ORT_HOST_DEVICE operator float() const { return ToFloat(); }
};

inline ORT_HOST_DEVICE bool operator==(const Float8E5M2FNUZ& left, const Float8E5M2FNUZ& right) { return left.val == right.val; }
inline ORT_HOST_DEVICE bool operator!=(const Float8E5M2FNUZ& left, const Float8E5M2FNUZ& right) { return left.val != right.val; }
inline ORT_HOST_DEVICE bool operator<(const Float8E5M2FNUZ& left, const Float8E5M2FNUZ& right) { return left.val < right.val; }
// Equality on the stored encoding: true when both `val` members are identical.
// The comparison is on the encoded byte, not on the decoded float value.
// NOTE(review): the visible float-conversion code encodes NaN as 0x80, so two values
// built from NaN compare equal here, unlike IEEE float comparison - confirm this is intended.
inline ORT_HOST_DEVICE bool operator==(const Float8E5M2FNUZ& left, const Float8E5M2FNUZ& right) {
return left.val == right.val;
}
// Inequality on the stored encoding: true when the `val` members differ.
inline ORT_HOST_DEVICE bool operator!=(const Float8E5M2FNUZ& left, const Float8E5M2FNUZ& right) {
return left.val != right.val;
}
// Orders two values by their stored `val` members.
// NOTE(review): the float constructor places the sign in bit 0x80 of `val`, so if `val`
// is an unsigned byte (presumably uint8_t - confirm) sign-bit-set encodings sort after
// all others. This is a bit-pattern order, not a numeric one.
inline ORT_HOST_DEVICE bool operator<(const Float8E5M2FNUZ& left, const Float8E5M2FNUZ& right) {
return left.val < right.val;
}

// User-defined literal suffixes that make it easier to declare
// Float8E5M2FNUZ initializers from unsigned char
Expand All @@ -617,9 +642,7 @@ inline Float8E5M2FNUZ operator"" _f8e5m2fnuz(unsigned long long int v) {
return Float8E5M2FNUZ(narrow<uint8_t>(v), Float8E5M2FNUZ::FromBits());
}

inline Float8E5M2FNUZ operator"" _f8e5m2fnuzp8(long double v) {
return Float8E5M2FNUZ(static_cast<float>(v), true);
}
// User-defined literal suffix `_f8e5m2fnuzp8` for floating-point literals: narrows the
// long double to float and constructs a Float8E5M2FNUZ from it, passing `true` as the
// second constructor argument (presumably the `saturate` flag - confirm against the constructor).
inline Float8E5M2FNUZ operator"" _f8e5m2fnuzp8(long double v) { return Float8E5M2FNUZ(static_cast<float>(v), true); }

#endif

Expand Down
Loading

0 comments on commit 06ea28b

Please sign in to comment.