CUDA BFloat16 Refactor #10085

Merged
merged 10 commits into master from weicwang/bfloat16 on Jan 14, 2022

Conversation

Contributor

@centwang centwang commented Dec 20, 2021

The previous code cast BFloat16 to CUDA's nv_bfloat16 type for computation, which required an A100 to run because nv_bfloat16 arithmetic is supported on A100 only. PyTorch uses its own type c10::BFloat16 for computation. This PR refactors our code to follow the same idea and use our own onnxruntime::BFloat16 type for computation. The general implementation is to cast BFloat16 to float for computation, and to use nv_bfloat16 on A100 via a __CUDA_ARCH__ >= 800 check.
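
As a rough illustration of that approach (a minimal sketch with hypothetical names such as BF16, ToFloat and Add; not the code in this PR), an element-wise op can do the math in float everywhere and only switch to the cuda_bf16.h intrinsics when compiling for sm_80 or newer:

#include <cstdint>
#include <cstring>
#include <cuda_bf16.h>

struct BF16 { uint16_t val; };  // stand-in for onnxruntime::BFloat16

__host__ __device__ inline float ToFloat(BF16 b) {
  // bfloat16 is the upper 16 bits of an IEEE-754 float32.
  uint32_t bits = static_cast<uint32_t>(b.val) << 16;
  float f;
  memcpy(&f, &bits, sizeof(f));
  return f;
}

__host__ __device__ inline BF16 FromFloat(float f) {
  uint32_t bits;
  memcpy(&bits, &f, sizeof(bits));
  return BF16{static_cast<uint16_t>(bits >> 16)};  // truncation for brevity
}

__device__ inline BF16 Add(BF16 a, BF16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  // Ampere (sm_80+): use the native bfloat16 intrinsic.
  __nv_bfloat16 r = __hadd(__float2bfloat16(ToFloat(a)), __float2bfloat16(ToFloat(b)));
  return FromFloat(__bfloat162float(r));
#else
  // Older devices (e.g. V100): compute in float and convert back.
  return FromFloat(ToFloat(a) + ToFloat(b));
#endif
}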

With this implementation, we can support BFloat16 on most NVIDIA devices, not only the A100.

I tested the code using ORTModule in two ways (the latest nightly PyTorch and ONNX are needed for some BFloat16 support):

  • Add a cast to torch.bfloat16 in the Module; this runs on both V100 and A100 and produces the same calculation results.
  • Use torch.autocast. PyTorch supports BFloat16 autocast on A100 only. I tested both PyTorch and ORT with torch.autocast on A100 running the BERT model from transformers. We get the same results (within the margin of error), and ORT's perf is better than PyTorch's (same as with Float16 autocast).

Note that PyTorch also uses its own c10::Float16 type for float16 computation in CUDA, while ORT casts to CUDA's half type. This is OK since half is supported by most NVIDIA devices. This PR doesn't touch any logic related to the float16 case.

@centwang centwang added the training label (issues related to ONNX Runtime training; typically submitted using template) on Dec 20, 2021
Contributor

@weixingzhang weixingzhang left a comment

Thanks for making the change. In general, it looks good to me.

explicit BFloat16(uint16_t v) : val(v) {}
explicit BFloat16(float v) {
#if defined(USE_ROCM)
ORT_HOST_DEVICE BFloat16() = default;
Contributor

Out of curiosity, why is the line above specific to ROCM?

Contributor Author

I saw that PyTorch does it this way for all default constructors, so I followed the same approach. Maybe hipcc requires this? But I didn't find any documentation to support that.

Contributor

OK, let's leave it as it is and re-visit it when we support BF16 on AMD GPUs.

BFloat16() = default;
#endif

struct FromBitsT {};
Contributor

What is the reason to introduce struct FromBitsT?

Contributor Author

The idea is from PyTorch. If a BFloat16 is constructed from FromBitsT, the bits are assigned to val directly (so the real value of the BFloat16 instance is not equal to the bits); if not, for example BFloat16(unsigned short value), it initializes a BFloat16 equal to value (and then the val member in the object is not equal to value). This is critical for some casting cases, for example BFloat16(1), which casts an int to BFloat16: if we didn't have FromBitsT, the compiler would report an ambiguous-constructor error because it cannot choose between BFloat16(unsigned short) and BFloat16(float). Even without the ambiguity problem, if the compiler chose BFloat16(unsigned short) but assigned the 1 to the val member directly, we would get a wrong BFloat16 instance. Our MLFloat16 actually has the same bug, but we don't have code such as MLFloat16(1), so we haven't encountered the compiler error so far.
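
For reference, here is a minimal, self-contained sketch of the tag-dispatch pattern being described. The names (BF16, FromBits) are simplified and the float conversion uses plain truncation, so this is an illustration of the idea rather than the actual onnxruntime::BFloat16 code:

#include <cassert>
#include <cstdint>
#include <cstring>

struct BF16 {
  struct FromBitsT {};  // tag: "treat the argument as raw storage bits"
  static constexpr FromBitsT FromBits() { return FromBitsT(); }

  uint16_t val = 0;

  BF16() = default;
  // Tagged constructor: val gets exactly the bits passed in.
  constexpr BF16(uint16_t bits, FromBitsT) : val(bits) {}
  // Value constructor: encode the float, so the object represents v.
  explicit BF16(float v) {
    uint32_t bits;
    std::memcpy(&bits, &v, sizeof(bits));
    val = static_cast<uint16_t>(bits >> 16);  // truncation for brevity
  }
};

int main() {
  // Because the raw-bits constructor requires the tag, BF16(1) can only
  // resolve to the float constructor and ends up representing 1.0.
  BF16 one(1);
  // 0x3F80 is the upper half of 1.0f's bit pattern (0x3F800000).
  BF16 raw(0x3F80, BF16::FromBits());
  assert(one.val == raw.val);
  return 0;
}

Because the raw-bits constructor now needs the tag argument, an expression like BF16(1) has only one viable constructor and yields a value of 1.0 rather than garbage bits.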

@centwang centwang merged commit 44e2db9 into master Jan 14, 2022
@centwang centwang deleted the weicwang/bfloat16 branch January 14, 2022 11:38