CUDA BFloat16 Refactor #10085
Changes from 9 commits
```diff
@@ -3,9 +3,19 @@
 #pragma once
 
 #include "endian.h"
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
+#include "cuda_bf16.h"
+#endif
 
-namespace onnxruntime
-{
+#include "core/common/common.h"
+
+namespace onnxruntime {
+
+#if defined(__CUDACC__) || defined(__HIPCC__)
+#define ORT_HOST_DEVICE __host__ __device__
+#else
+#define ORT_HOST_DEVICE
+#endif
 
 // MLFloat16
 struct MLFloat16 {
```
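For context, a minimal hypothetical sketch of how the `ORT_HOST_DEVICE` macro introduced above is used: under nvcc or hipcc it expands to `__host__ __device__`, so one function body serves both host callers and device kernels, while for a plain host compiler it expands to nothing. The `Square` helper below is illustrative only and not part of this PR.

```cpp
// Assumes the ORT_HOST_DEVICE definition from the hunk above is in scope.
// With nvcc/hipcc this compiles as a __host__ __device__ function; with a
// host-only compiler the annotation vanishes and it is an ordinary inline function.
inline ORT_HOST_DEVICE float Square(float v) { return v * v; }
```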
```diff
@@ -17,53 +27,64 @@ struct MLFloat16 {
 
   float ToFloat() const;
 
-  operator float() const {
-    return ToFloat();
-  }
+  operator float() const { return ToFloat(); }
 };
 
-inline bool operator==(const MLFloat16& left, const MLFloat16& right) {
-  return left.val == right.val;
-}
-
-inline bool operator!=(const MLFloat16& left, const MLFloat16& right) {
-  return left.val != right.val;
-}
-
-inline bool operator<(const MLFloat16& left, const MLFloat16& right) {
-  return left.val < right.val;
-}
+inline bool operator==(const MLFloat16& left, const MLFloat16& right) { return left.val == right.val; }
+inline bool operator!=(const MLFloat16& left, const MLFloat16& right) { return left.val != right.val; }
+inline bool operator<(const MLFloat16& left, const MLFloat16& right) { return left.val < right.val; }
 
-//BFloat16
+// BFloat16
 struct BFloat16 {
   uint16_t val{0};
-  explicit BFloat16() = default;
-  explicit BFloat16(uint16_t v) : val(v) {}
-  explicit BFloat16(float v) {
+#if defined(USE_ROCM)
+  ORT_HOST_DEVICE BFloat16() = default;
+#else
+  BFloat16() = default;
+#endif
+
+  struct FromBitsT {};
+  static constexpr ORT_HOST_DEVICE FromBitsT FromBits() { return FromBitsT(); }
+  constexpr ORT_HOST_DEVICE BFloat16(unsigned short bits, FromBitsT) : val(bits){};
+
+  inline ORT_HOST_DEVICE BFloat16(float v) {
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+    val = __bfloat16_as_ushort(__float2bfloat16(v));
+#else
     ORT_IF_CONSTEXPR(endian::native == endian::little) {
       std::memcpy(&val, reinterpret_cast<char*>(&v) + sizeof(uint16_t), sizeof(uint16_t));
-    } else {
+    }
+    else {
       std::memcpy(&val, &v, sizeof(uint16_t));
     }
+#endif
   }
 
-  float ToFloat() const {
+  inline ORT_HOST_DEVICE float ToFloat() const {
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
+    return __bfloat162float(*reinterpret_cast<const __nv_bfloat16*>(&val));
+#else
     float result;
     char* const first = reinterpret_cast<char*>(&result);
     char* const second = first + sizeof(uint16_t);
     ORT_IF_CONSTEXPR(endian::native == endian::little) {
       std::memset(first, 0, sizeof(uint16_t));
       std::memcpy(second, &val, sizeof(uint16_t));
-    } else {
+    }
+    else {
       std::memcpy(first, &val, sizeof(uint16_t));
       std::memset(second, 0, sizeof(uint16_t));
     }
     return result;
+#endif
   }
 
-  operator float() const {
-    return ToFloat();
-  }
+  inline ORT_HOST_DEVICE operator float() const { return ToFloat(); }
+
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
+  ORT_HOST_DEVICE BFloat16(const __nv_bfloat16& value) { val = *reinterpret_cast<const unsigned short*>(&value); }
+  explicit ORT_HOST_DEVICE operator __nv_bfloat16() const { return *reinterpret_cast<const __nv_bfloat16*>(&val); }
+#endif
 };
 
 inline void BFloat16ToFloat(const BFloat16* blf, float* flt, size_t size) {
```

An inline review thread is attached to the `struct FromBitsT {};` line above:

Question: What is the reason to introduce `struct FromBitsT`?

Reply: This idea is from PyTorch. If a BFloat16 is constructed with FromBitsT, the bits are assigned to `val` directly (so the real value of the BFloat16 instance is not, in general, equal to `bits`); a constructor without the tag, for example a `BFloat16(unsigned short value)`, should instead initialize a BFloat16 numerically equal to `value` (in which case the `val` member is not equal to `value`). This is critical for some casting cases, for example `BFloat16(1)`, which casts an `int` to BFloat16: without FromBitsT, the compiler reports an ambiguous-constructor error because it cannot choose between `BFloat16(unsigned short)` and `BFloat16(float)`. Even without that ambiguity, if the compiler picked `BFloat16(unsigned short)` and assigned the 1 to the `val` member directly, we would get a wrong BFloat16 instance. Our MLFloat16 actually has the same bug, but since we have no code such as `MLFloat16(1)`, we haven't encountered the compiler error so far.
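As a concrete illustration of the ambiguity described above, here is a minimal sketch of the tag-dispatch pattern; `BF16` is a simplified, hypothetical stand-in for the header's BFloat16 and is not the actual onnxruntime code. The tagged constructor copies the raw bit pattern, while the `float` constructor converts a numeric value, so an expression like `BF16(1)` stays unambiguous and means "the value 1.0".

```cpp
#include <cstdint>
#include <cstring>

// Hypothetical, simplified stand-in for BFloat16 to show why the tag type helps.
struct BF16 {
  uint16_t val{0};

  struct FromBitsT {};
  static constexpr FromBitsT FromBits() { return FromBitsT(); }

  // Raw-bits constructor: val receives the bit pattern verbatim.
  constexpr BF16(uint16_t bits, FromBitsT) : val(bits) {}

  // Value constructor: keeps the upper 16 bits of the float32 representation
  // (little-endian case shown for brevity).
  explicit BF16(float v) { std::memcpy(&val, reinterpret_cast<char*>(&v) + 2, 2); }
};

int main() {
  BF16 a(1);                         // unambiguous: int -> float, value 1.0, val == 0x3F80
  BF16 b(0x3F80, BF16::FromBits());  // the same number, constructed from its raw bits
  return a.val == b.val ? 0 : 1;
}
```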
```diff
@@ -82,16 +103,4 @@ inline void FloatToBFloat16(const float* flt, BFloat16* blf, size_t size) {
   }
 }
 
-inline bool operator==(const BFloat16& left, const BFloat16& right) {
-  return left.val == right.val;
-}
-
-inline bool operator!=(const BFloat16& left, const BFloat16& right) {
-  return left.val != right.val;
-}
-
-inline bool operator<(const BFloat16& left, const BFloat16& right) {
-  return left.val < right.val;
-}
-
-}
+}  // namespace onnxruntime
```
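As a small usage note, the buffer helpers this header keeps (`FloatToBFloat16` and `BFloat16ToFloat`, whose signatures appear in the hunk above) are called like this; the buffer names and sizes here are illustrative only and not from the PR.

```cpp
// Round-trip a small buffer through BFloat16 using the helpers above.
float src[4] = {1.0f, 0.5f, -2.0f, 3.25f};
onnxruntime::BFloat16 packed[4];
float widened[4];

onnxruntime::FloatToBFloat16(src, packed, 4);     // narrow float32 -> bfloat16
onnxruntime::BFloat16ToFloat(packed, widened, 4); // widen back; values needing more than bf16 precision are truncated
```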
A second review thread is attached to the `ORT_HOST_DEVICE BFloat16() = default;` line inside the `#if defined(USE_ROCM)` block:

Question: Out of curiosity, why is that line specific to ROCm?

Reply: I saw that PyTorch does it this way for all default constructors, so I followed the same approach. Maybe hipcc requires it, but I couldn't find any documentation to support that.

Reply: OK, let's leave it as it is and revisit when supporting BF16 on AMD GPUs.