CUDA BFloat16 Refactor #10085
Conversation
Thanks for making the change. In general, it looks good to me.
  explicit BFloat16(uint16_t v) : val(v) {}
  explicit BFloat16(float v) {
#if defined(USE_ROCM)
  ORT_HOST_DEVICE BFloat16() = default;
Out of curiosity, why is the line above specific to ROCM?
I saw PyTorch does it this way for all default constructors, so I followed the same pattern. Maybe hipcc requires this? But I didn't find any documentation to support that.
OK, let's leave it as it is and revisit it when we support BF16 on AMD GPUs.
  BFloat16() = default;
#endif

  struct FromBitsT {};
What is the reason to introduce struct FromBitsT?
This idea is from PyTorch. If the instance is constructed with the FromBitsT tag, the bits are assigned to val directly (so the real value of the BFloat16 instance is not equal to bits). Without the tag, for example BFloat16(unsigned short value), the constructor is expected to produce a BFloat16 whose value equals value (while the val member holds the encoded bits, not value itself). This is critical for some casting cases, for example BFloat16(1), which casts an int to BFloat16: without FromBitsT, the compiler reports an ambiguous-constructor error because it cannot decide between BFloat16(unsigned short) and BFloat16(float). Even if there were no ambiguity, if the compiler picked BFloat16(unsigned short) and assigned the 1 to the val member directly, we would get a wrong BFloat16 instance. Our MLFloat16 actually has the same bug, but since we don't have code such as MLFloat16(1), we haven't hit the compiler error so far. The tag-dispatch pattern is sketched below.
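To make that concrete, here is a minimal sketch of the pattern, assuming a simplified type (BFloat16Sketch, its FromBits() helper, and the truncating float conversion are illustrative, not the actual onnxruntime definitions):

```cpp
#include <cstdint>
#include <cstring>

struct BFloat16Sketch {
  struct FromBitsT {};
  static constexpr FromBitsT FromBits() { return FromBitsT(); }

  uint16_t val = 0;

  BFloat16Sketch() = default;

  // Raw-bit constructor: the tag says "these are already bfloat16 bits",
  // so they are stored in val as-is.
  constexpr BFloat16Sketch(uint16_t bits, FromBitsT) : val(bits) {}

  // Value constructor: keep the upper 16 bits of the IEEE-754 float
  // representation (plain truncation here, for brevity).
  explicit BFloat16Sketch(float v) {
    uint32_t bits32;
    std::memcpy(&bits32, &v, sizeof(bits32));
    val = static_cast<uint16_t>(bits32 >> 16);
  }
};
```

With this layout, BFloat16Sketch(1) has only one viable constructor (the int converts to float) and represents the value 1.0, while BFloat16Sketch(0x3F80, BFloat16Sketch::FromBits()) stores the bit pattern 0x3F80 in val directly.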
The previous code cast BFloat16 to CUDA's nv_bfloat16 type for computation, which required an A100 to run because nv_bfloat16 arithmetic is only available there. PyTorch uses its own c10::BFloat16 type for computation, and this PR refactors our code to follow the same idea and use our own onnxruntime::BFloat16. The general implementation casts BFloat16 to float for computation, and switches to nv_bfloat16 on A100 via the __CUDA_ARCH__ >= 800 guard, as sketched below.
With this implementation, we can support BFloat16 on most NVIDIA devices, not just A100.
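A minimal sketch of that dispatch, assuming the bfloat16 value is carried as raw uint16_t bits (AddBF16Bits is a hypothetical helper name, not the actual onnxruntime code):

```cuda
#include <cuda_bf16.h>
#include <cstdint>

// Compute in float everywhere, but use native __nv_bfloat16 math when the
// kernel is compiled for sm_80 (A100) or newer.
__device__ __forceinline__ uint16_t AddBF16Bits(uint16_t a, uint16_t b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  // Native bfloat16 add on Ampere and newer.
  __nv_bfloat16 r = __hadd(__ushort_as_bfloat16(a), __ushort_as_bfloat16(b));
  return __bfloat16_as_ushort(r);
#else
  // Fallback: bfloat16 is the top 16 bits of an IEEE-754 float, so widen,
  // add in float, and truncate back.
  float fa = __uint_as_float(static_cast<uint32_t>(a) << 16);
  float fb = __uint_as_float(static_cast<uint32_t>(b) << 16);
  return static_cast<uint16_t>(__float_as_uint(fa + fb) >> 16);
#endif
}
```

The same __CUDA_ARCH__ >= 800 split generalizes to the other operators, which is what lets the kernels run on pre-Ampere GPUs while still using the native path on A100.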
Tested the code using ORTModule in two ways (requires the latest nightly PyTorch and ONNX for some BFloat16 support):
Note that PyTorch also uses its own c10::Float16 type for float16 computation in CUDA, while ORT casts to CUDA's half type. This is fine because half is supported by most NVIDIA devices. This PR doesn't touch any logic related to the float16 case.