Add bfloat16 support to Lightning Trainer #9049
Conversation
Codecov Report
@@ Coverage Diff @@
## master #9049 +/- ##
=======================================
- Coverage 93% 92% -1%
=======================================
Files 175 178 +3
Lines 14505 14687 +182
=======================================
+ Hits 13449 13519 +70
- Misses 1056 1168 +112
@SeanNaren how will we handle the precision input from argparse? we may have to allow passing the 16 and 32 as str too.

import torch
from argparse import ArgumentParser
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    model = BoringModel()

    parser = ArgumentParser()
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    trainer = Trainer.from_argparse_args(args)
    trainer.fit(model, train_data)


if __name__ == "__main__":
    run()
adrian@lambda-server4 ~/r/pytorch-lightning (feat/bfloat16)> python x.py --precision 16 --gpus 1 (pl)
Traceback (most recent call last):
File "x.py", line 61, in <module>
run()
File "x.py", line 56, in run
trainer = Trainer.from_argparse_args(args)
File "/home/adrian/repositories/pytorch-lightning/pytorch_lightning/trainer/properties.py", line 426, in from_argparse_args
return from_argparse_args(cls, args, **kwargs)
File "/home/adrian/repositories/pytorch-lightning/pytorch_lightning/utilities/argparse.py", line 65, in from_argparse_args
return cls(**trainer_kwargs)
File "/home/adrian/repositories/pytorch-lightning/pytorch_lightning/trainer/connectors/env_vars_connector.py", line 40, in insert_env_defaults
return fn(self, **kwargs)
File "/home/adrian/repositories/pytorch-lightning/pytorch_lightning/trainer/trainer.py", line 352, in __init__
self.accelerator_connector = AcceleratorConnector(
File "/home/adrian/repositories/pytorch-lightning/pytorch_lightning/trainer/connectors/accelerator_connector.py", line 164, in __init__
self.accelerator = self.select_accelerator()
File "/home/adrian/repositories/pytorch-lightning/pytorch_lightning/trainer/connectors/accelerator_connector.py", line 713, in select_accelerator
accelerator = acc_cls(training_type_plugin=self.training_type_plugin, precision_plugin=self.precision_plugin)
File "/home/adrian/repositories/pytorch-lightning/pytorch_lightning/trainer/connectors/accelerator_connector.py", line 344, in precision_plugin
self._precision_plugin = self.select_precision_plugin()
File "/home/adrian/repositories/pytorch-lightning/pytorch_lightning/trainer/connectors/accelerator_connector.py", line 604, in select_precision_plugin
raise NotImplementedError("We only support precisions 64, 32 and 16!")
NotImplementedError: We only support precisions 64, 32 and 16!

Meanwhile, on master it works fine.
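One way to handle the question above about accepting the precision value from argparse as either an int or a string is sketched below; the `_parse_precision` helper is hypothetical and not part of Lightning's argparse utilities.

# Hypothetical sketch: let --precision accept 16/32/64 as ints and "bf16" as a string.
from argparse import ArgumentParser
from typing import Union


def _parse_precision(value: str) -> Union[int, str]:
    # argparse always delivers a string; turn purely numeric values back into
    # ints and leave names such as "bf16" untouched.
    return int(value) if value.isdigit() else value


parser = ArgumentParser()
parser.add_argument("--precision", type=_parse_precision, default=32)

assert parser.parse_args(["--precision", "16"]).precision == 16
assert parser.parse_args(["--precision", "bf16"]).precision == "bf16"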
Do we also need to update the assertions here for CPU? https://github.com/PyTorchLightning/pytorch-lightning/blob/f3c5889aa3116bb8299a849f4f02dfb535b5624e/pytorch_lightning/accelerators/cpu.py#L29-L35
pytorch/pytorch#57386 ?
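Purely as an illustration of the kind of guard being asked about (this is not the actual pytorch_lightning/accelerators/cpu.py code), a CPU-side precision check might look roughly like this:

import torch


def check_cpu_precision(precision) -> None:
    # Sketch only: full and double precision always work on CPU.
    if precision in (32, 64):
        return
    if precision == "bf16":
        # bfloat16 autocast on CPU goes through torch.autocast("cpu", ...), which
        # only exists in newer PyTorch releases, so gate on its presence.
        if hasattr(torch, "autocast"):
            return
        raise RuntimeError("bf16 on CPU needs a PyTorch release that provides torch.autocast.")
    raise ValueError(f"CPU accelerator does not support precision={precision!r} (fp16 AMP is GPU-only).")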
Co-authored-by: Sean Naren <[email protected]>
@SeanNaren I installed pytorch-nightly (version
When I explicitly check I get:
Did you verify that
Thanks @yuvalkirstain! I have a follow-up PR to fix this, apologies for the issue with the import: #9089
@SeanNaren Sure!
@yuvalkirstain how coincidental! I also have the same :) I'm running google/mt5 with this command to benchmark:
I unfortunately do not have A100s and am running on 3090s, which I think are slower :( but I'm not seeing NaN losses for the model (albeit the loss seems to be stagnating, need to check that out!)
Great!
Out of curiosity, why does BFloat16 support require PyTorch 1.10? Per https://github.com/pytorch/pytorch/releases it should be supported as of version 1.9. Am I missing something?
This might be incorrect, but I think for a few layers (such as conv) bfloat16 support was only added in 1.10.
@cowwoc 1.9 adds bfloat16 support to some operators, but without 1.10 you're missing the bfloat16 autocast support (set_autocast_gpu_dtype/get_autocast_gpu_dtype): pytorch/pytorch@324673a
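A small check along these lines shows what 1.9 is missing; this is a sketch assuming PyTorch 1.10+ and a CUDA device for the bfloat16 branch, not Lightning's internal version gate.

import torch

# torch.autocast (and the dtype argument for GPU autocast) only arrived in
# PyTorch 1.10, which is why bf16 mixed precision needs 1.10 rather than 1.9.
if hasattr(torch, "autocast") and torch.cuda.is_available():
    model = torch.nn.Linear(8, 8).cuda()
    x = torch.randn(4, 8, device="cuda")
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        out = model(x)
    print(out.dtype)  # torch.bfloat16 inside the autocast region
else:
    print("bfloat16 GPU autocast is not available with this install")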
To use this, you just set
It would be great if there was some documentation for this new precision support :) Maybe here?
Information for the docs can be seen here: https://pytorch-lightning.readthedocs.io/en/latest/advanced/mixed_precision.html#bfloat16-mixed-precision but it's outdated since 1.10 is out! You pass
I agree, more information should be added to the docs! Will make an issue.
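For reference, the usage that docs page describes comes down to the following (assuming a Lightning version that includes this PR and PyTorch 1.10+):

from pytorch_lightning import Trainer

# Request bfloat16 mixed precision on a single GPU, as described in the linked
# mixed-precision docs section.
trainer = Trainer(gpus=1, precision="bf16")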
Thanks! I also wanted to clarify if this is bfloat16 multiply / fp32 accumulate, or bfloat16 for everything? On TPUs, they do the former (source):
If it's the latter, is there some kind of reference showing how well typical models train using bfloat16 exclusively? Because with fp16 we had to use NVIDIA APEX fp32 accumulation to get good results - people say bfloat16 is better, but is it enough better to not need fp32 accumulation? I'd be interested to see if people have looked into that.

For reference, my understanding is that on Ampere, exclusively using bfloat16 is a 4x speedup from single-precision, and using bfloat16 multiply/fp32 accumulate is a 2x speedup from single precision. Using "tf32" (19-bit floating point) and fp32 accumulation is also a 2x speedup, so if we have to do fp32 accumulation, I'm not sure I see the advantage of training models in bfloat16 over tf32.

Update: while bfloat16 with fp32 accumulation doesn't have a FLOPs advantage over tf32 multiply (now enabled by default in pytorch >=1.7), it does use ~2x less memory, so you can fit larger batches and use less memory bandwidth.
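A quick way to see the memory half of that argument, along with the TF32 switches mentioned above (both flags are standard PyTorch settings; their defaults have shifted across releases):

import torch

# bfloat16 stores 2 bytes per element versus 4 for float32; tf32 is a compute
# format only, so tensors stored in fp32 see no memory saving from it.
print(torch.tensor([], dtype=torch.bfloat16).element_size())  # 2
print(torch.tensor([], dtype=torch.float32).element_size())   # 4

# Allow TF32 tensor-core math for fp32 matmuls and cuDNN convolutions on Ampere.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True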
@zplizzi I spent some time going through the documentation etc. I think overall I'm convinced there is no accumulation effect happening for either FP16 or bfloat16 autocast. Autocast handles the instability slightly differently, by keeping a list of ops that can run safely in FP16, ones that are upcast to FP32, and ones that can only run in FP32. I may be mistaken, but if this is the case we should label bfloat16 as experimental, as it differs slightly from the norm (as you described), and we should benchmark to see the true performance.
Ah, I think I am understanding better now. In the original NVIDIA AMP package (https://nvidia.github.io/apex/amp.html) there were different versions of mixed precision. "O1" seems to be what was carried into the native pytorch AMP package - everything stays fp32 except certain operations are performed in fp16. "O2" I think is more like how TPUs work (not sure) - the model is cast to fp16 and everything (except certain blacklisted ops) happens in fp16, including grads (?), and in the optimizer step the weight updates are applied to a separate fp32 version of the weights ("fp32 accumulation"). I'm not too sure how you implemented bfloat16 in this PR, though - I don't see anywhere in the pytorch AMP docs where they mention AMP supporting bfloat16. I'd like to learn how this works.
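For context, the APEX opt levels referred to look like this in use; this is shown only to contrast O1 and O2 (it requires NVIDIA Apex and a GPU, and is unrelated to how this PR wires up bf16):

import torch
from apex import amp  # NVIDIA Apex, installed separately from PyTorch

model = torch.nn.Linear(32, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# O1 patches selected ops to run in fp16 while weights stay fp32 (closest to native AMP);
# O2 casts the model to fp16 and keeps fp32 "master" weights that the optimizer updates.
model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

loss = model(torch.randn(4, 32, device="cuda")).sum()
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()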
Ah, I see that you're still using native AMP, just passing in bfloat16 as the fast_dtype instead of the default fp16. So the op whitelist/blacklist is still being applied. And you're also disabling the grad scaling. I'm honestly not sure if this is entirely right or not. The whitelist/blacklist was created specifically with fp16 in mind, and the tradeoffs of bfloat16 are different (more dynamic range, less precision). But, it's probably approximately correct.
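Put concretely, the difference being described looks roughly like this; a sketch of native AMP usage on a CUDA device with PyTorch 1.10+, not the PR's internal plugin code:

import torch

model = torch.nn.Linear(32, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data = torch.randn(4, 32, device="cuda")

# fp16 path: the autocast op lists apply, and a GradScaler is needed because
# fp16's narrow dynamic range makes gradients prone to underflow.
scaler = torch.cuda.amp.GradScaler()
with torch.autocast(device_type="cuda", dtype=torch.float16):
    loss = model(data).sum()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

# bf16 path: the same op lists apply, but grad scaling is skipped because
# bfloat16 keeps fp32's exponent range and gives up mantissa precision instead.
optimizer.zero_grad()
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    loss = model(data).sum()
loss.backward()
optimizer.step()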
@SeanNaren When I use this feature and check the |
@yuvalkirstain do you mind making a new issue for this? Should be something we track!
@SeanNaren sure thing :)
What does this PR do?
Closes #8874.
Adds bfloat16 support, which is currently only available in PyTorch 1.10.0dev (installable via pytorch-nightly). This is a much-requested feature, primarily to support running XLA models that were trained in bfloat16 with pure PyTorch (lots of HF models have been taken from Google-trained sources).
Before submitting
PR review
Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the Review guidelines. In short, see the following bullet list:
Did you have fun?
Make sure you had fun coding 🙃