Add bfloat16 support to Lightning Trainer #9049

Merged: 15 commits from feat/bfloat16 into master on Aug 24, 2021

Conversation

@SeanNaren (Contributor) commented Aug 23, 2021

What does this PR do?

Closes #8874.

Adds bfloat16 support, which is currently only available in PyTorch 1.10.0dev (installable via pytorch-nightly). This is a highly requested feature, primarily to support running models in pure PyTorch that were trained in bfloat16 on XLA/TPUs (many Hugging Face models originate from Google-trained checkpoints).
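
For reference, a minimal usage sketch of what this enables (hedged: assumes a CUDA-capable GPU, a PyTorch 1.10 nightly, and the bf16 precision flag discussed later in this thread; model is any LightningModule):

from pytorch_lightning import Trainer

# bf16 mixed precision needs torch >= 1.10; unlike fp16, no GradScaler is
# involved because bfloat16 keeps the fp32 exponent range.
trainer = Trainer(gpus=1, precision="bf16", max_epochs=1)
# trainer.fit(model, train_dataloader)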

Before submitting

  • Was this discussed/approved via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or internal minor changes/refactorings)

PR review

Anyone in the community is welcome to review the PR.
Before you start reviewing make sure you have read Review guidelines. In short, see the following bullet-list:

  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

Did you have fun?

Make sure you had fun coding 🙃

@SeanNaren added the feature label (Is an improvement or enhancement) on Aug 23, 2021
@SeanNaren added this to the v1.5 milestone on Aug 23, 2021
@SeanNaren self-assigned this on Aug 23, 2021
@codecov bot commented Aug 23, 2021

Codecov Report

Merging #9049 (92de556) into master (1e4d892) will decrease coverage by 1%.
The diff coverage is 89%.

@@           Coverage Diff           @@
##           master   #9049    +/-   ##
=======================================
- Coverage      93%     92%    -1%     
=======================================
  Files         175     178     +3     
  Lines       14505   14687   +182     
=======================================
+ Hits        13449   13519    +70     
- Misses       1056    1168   +112     

Outdated review threads (resolved):
pytorch_lightning/plugins/precision/native_amp.py
pytorch_lightning/trainer/trainer.py
tests/plugins/test_amp_plugins.py (2 threads)
@mergify bot removed the has conflicts label on Aug 23, 2021
@awaelchli (Contributor)

@SeanNaren how will we handle the precision input from argparse? We may have to allow passing 16 and 32 as strings too.

import torch
from argparse import ArgumentParser
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    parser = ArgumentParser()
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args()
    trainer = Trainer.from_argparse_args(args)
    trainer.fit(model, train_data)


if __name__ == "__main__":
    run()

Running the script above with --precision 16 --gpus 1 gives the following error:

adrian@lambda-server4 ~/r/pytorch-lightning (feat/bfloat16)> python x.py --precision 16 --gpus 1                                                                 (pl)
Traceback (most recent call last):
  File "x.py", line 61, in <module>
    run()
  File "x.py", line 56, in run
    trainer = Trainer.from_argparse_args(args)
  File "/home/adrian/repositories/pytorch-lightning/pytorch_lightning/trainer/properties.py", line 426, in from_argparse_args
    return from_argparse_args(cls, args, **kwargs)
  File "/home/adrian/repositories/pytorch-lightning/pytorch_lightning/utilities/argparse.py", line 65, in from_argparse_args
    return cls(**trainer_kwargs)
  File "/home/adrian/repositories/pytorch-lightning/pytorch_lightning/trainer/connectors/env_vars_connector.py", line 40, in insert_env_defaults
    return fn(self, **kwargs)
  File "/home/adrian/repositories/pytorch-lightning/pytorch_lightning/trainer/trainer.py", line 352, in __init__
    self.accelerator_connector = AcceleratorConnector(
  File "/home/adrian/repositories/pytorch-lightning/pytorch_lightning/trainer/connectors/accelerator_connector.py", line 164, in __init__
    self.accelerator = self.select_accelerator()
  File "/home/adrian/repositories/pytorch-lightning/pytorch_lightning/trainer/connectors/accelerator_connector.py", line 713, in select_accelerator
    accelerator = acc_cls(training_type_plugin=self.training_type_plugin, precision_plugin=self.precision_plugin)
  File "/home/adrian/repositories/pytorch-lightning/pytorch_lightning/trainer/connectors/accelerator_connector.py", line 344, in precision_plugin
    self._precision_plugin = self.select_precision_plugin()
  File "/home/adrian/repositories/pytorch-lightning/pytorch_lightning/trainer/connectors/accelerator_connector.py", line 604, in select_precision_plugin
    raise NotImplementedError("We only support precisions 64, 32 and 16!")
NotImplementedError: We only support precisions 64, 32 and 16!

On master, the same command works fine.
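
One possible direction, sketched outside Lightning (the _parse_precision helper below is hypothetical and not necessarily what Lightning ended up doing), is a lenient argparse type that returns an int when it can and otherwise passes the string through:

from argparse import ArgumentParser

def _parse_precision(value: str):
    # Hypothetical helper: "16"/"32"/"64" become ints, anything else
    # (e.g. "bf16") is passed through as a string.
    try:
        return int(value)
    except ValueError:
        return value

parser = ArgumentParser()
parser.add_argument("--precision", type=_parse_precision, default=32)
print(parser.parse_args(["--precision", "bf16"]).precision)  # bf16
print(parser.parse_args(["--precision", "16"]).precision)    # 16 (int)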

@mergify bot removed the has conflicts label on Aug 23, 2021
@ananthsub (Contributor) left a comment

@SeanNaren enabled auto-merge (squash) on August 23, 2021 22:07
@Borda requested a review from ananthsub on August 24, 2021 07:07
@SeanNaren mentioned this pull request on Aug 24, 2021
@SeanNaren merged commit 1feec8c into master on Aug 24, 2021
@SeanNaren deleted the feat/bfloat16 branch on August 24, 2021 09:47
@yuvalkirstain commented Aug 24, 2021

@SeanNaren
Thanks so much for this feature!!!!

I installed pytorch-nightly (version 1.10.0.dev20210824) and the bleeding edge of pytorch-lightning. However, I get the following error:

pytorch_lightning.utilities.exceptions.MisconfigurationException: Error instantiating 'pytorch_lightning.trainer.trainer.Trainer' : To use bfloat16 with native amp you must install torch greater or equal to 1.10.

When I explicitly check I get:

from pytorch_lightning.utilities.imports import operator, _compare_version
_compare_version("torch", operator.ge, "1.10.0")
>> False
_compare_version("torch", operator.ge, "1.10.0.dev20210824")
>> True
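
This comes down to how packaging orders pre-releases: a .dev build sorts before the final release it leads up to, so torch 1.10.0.dev20210824 compares as older than 1.10.0. A minimal sketch of the underlying comparison:

from packaging.version import Version

print(Version("1.10.0.dev20210824") < Version("1.10.0"))    # True: .dev builds are pre-releases
print(Version("1.10.0.dev20210824").release >= (1, 10, 0))  # True: the base release tuple is 1.10.0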

Did you verify that bf16 works? If so, how can I download the torch version that you have used?
Also, how much memory and time does this feature save? For example when testing it with T5?
If there is an easy way to check it, I can help :)

@SeanNaren mentioned this pull request on Aug 24, 2021
@SeanNaren (Contributor, Author)

Thanks @yuvalkirstain!

I have a follow-up PR to fix this; apologies for the issue with the import: #9089

@yuvalkirstain

@SeanNaren Sure!
I'd love to help with benchmarking this new feature.
I have a Lightning Transformers script ready to fine-tune T5 and a bunch of A100s :)
With a bit of guidance I'll be able to benchmark how effective this feature is.

@SeanNaren (Contributor, Author)

@yuvalkirstain how coincidental! I also have the same :)

I'm running google/mt5 with this command to benchmark:

python train.py task=nlp/translation dataset=nlp/translation/wmt16 trainer.gpus=8  backbone.pretrained_model_name_or_path=google/mt5-base training.batch_size=4 trainer.accelerator=ddp trainer.limit_train_batches=64

I unfortunately do not have A100s and am running on 3090s, which I think are slower :( but I'm not seeing NaN losses for the model (albeit the loss seems to be stagnating; need to check that out!)

@yuvalkirstain

Great!
It might also be worthwhile to benchmark the memory consumption and runtime with and without bf16.

@cowwoc (Contributor) commented Oct 14, 2021

Out of curiosity, why does BFloat16 support require PyTorch 1.10? Per https://github.com/pytorch/pytorch/releases it should be supported as of version 1.9. Am I missing something?

@SeanNaren (Contributor, Author)

Out of curiosity, why does BFloat16 support require PyTorch 1.10? Per https://github.com/pytorch/pytorch/releases it should be supported as of version 1.9. Am I missing something?

This might be incorrect, but I think bfloat16 support for a few layers (such as conv) was only added in 1.10.

@leezu (Contributor) commented Oct 14, 2021

@cowwoc 1.9 adds bfloat16 support to some operators. But without 1.10, you're missing the bfloat16 autocast support (set_autocast_gpu_dtype/get_autocast_gpu_dtype) pytorch/pytorch@324673a
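
For illustration, a minimal sketch of the 1.10 entry point this enables (assumes torch >= 1.10 and a CUDA device; this is plain PyTorch, not the Lightning plugin itself):

import torch

model = torch.nn.Linear(32, 2).cuda()
x = torch.randn(4, 32, device="cuda")

# torch.autocast accepts an explicit dtype from 1.10 onwards, which is what
# lets the native AMP machinery run autocast-eligible ops in bfloat16.
with torch.autocast("cuda", dtype=torch.bfloat16):
    out = model(x)
print(out.dtype)  # torch.bfloat16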

@zplizzi commented Oct 22, 2021

To use this, you just set precision="bfloat16" in the Trainer, right? Does this basically just cast everything to bfloat16, or is there some sort of fp32 accumulation going on?

It would be great if there was some documentation for this new precision support :) Maybe here?

@SeanNaren (Contributor, Author)

Information for the docs can be seen here: https://pytorch-lightning.readthedocs.io/en/latest/advanced/mixed_precision.html#bfloat16-mixed-precision but it's outdated now that 1.10 is out! You pass bf16 to the Trainer :)

I agree, more information should be added to the docs! I'll make an issue.

@zplizzi commented Oct 23, 2021

Thanks! I also wanted to clarify if this is bfloat16 multiply / fp32 accumulate, or bfloat16 for everything? On TPUs, they do the former (source):

More precisely, each multiply-accumulate operation in a matrix multiplication uses bfloat16 for the multiplication and 32-bit IEEE floating point for accumulation.

If it's the latter, is there some kind of reference showing how well typical models train using bfloat16 exclusively? With fp16 we had to use NVIDIA Apex's fp32 accumulation to get good results; people say bfloat16 is better, but is it better enough to not need fp32 accumulation? I'd be interested to see whether people have looked into that.

For reference, my understanding is that on Ampere, exclusively using bfloat16 is a 4x speedup from single-precision, and using bfloat16 multiply/fp32 accumulate is a 2x speedup from single precision. Using "tf32" (19-bit floating point) and fp32 accumulation is also a 2x speedup, so if we have to do fp32 accumulation, I'm not sure I see the advantage of training models in bfloat16 over tf32.

Update: while bfloat16 with fp32 accumulation doesn't have a FLOPs advantage over tf32 multiply (now enabled by default in pytorch >=1.7), it does use ~2x less memory, so you can fit larger batches and use less memory bandwidth.
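
For context, a minimal sketch of the TF32 toggle being compared against (these flags exist in PyTorch >= 1.7, only take effect on Ampere-class GPUs, and defaulted to True at the time of this thread):

import torch

# TF32 keeps the fp32 layout but uses a 10-bit mantissa inside tensor cores,
# so fp32 matmuls/convolutions get the ~2x speedup mentioned above without
# any change to model or activation dtypes (and therefore no memory saving).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True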

@SeanNaren (Contributor, Author)

@zplizzi I spent some time going through the documentation etc. Overall I'm convinced there is no accumulation happening for either FP16 or bfloat16 autocast. Autocast handles the instability slightly differently, by keeping a list of ops that can run safely in FP16, ops that are upcast to FP32, and ops that can only run in FP32.

I may be mistaken, but if that is the case we should label bfloat16 as experimental, since it differs slightly from the norm (as you described), and we should benchmark to see the true performance.

@zplizzi commented Nov 3, 2021

Ah, I think I am understanding better now. In the original NVIDIA AMP package (https://nvidia.github.io/apex/amp.html) there were different versions of mixed precision. "O1" seems to be what was carried into the native PyTorch AMP package: everything stays fp32 except that certain operations are performed in fp16. "O2" I think is more like how TPUs work (not sure): the model is cast to fp16 and everything (except certain blacklisted ops) happens in fp16, including grads (?), and in the optimizer step the weight updates are applied to a separate fp32 copy of the weights ("fp32 accumulation").
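
A minimal sketch of that O2-style "fp32 master weights" idea, for concreteness (illustrative only; this is not how Apex or Lightning actually implement it, and it assumes a CUDA device):

import torch

model = torch.nn.Linear(32, 2).cuda().half()                        # compute copy in fp16
master = [p.detach().clone().float() for p in model.parameters()]   # fp32 master weights
opt = torch.optim.SGD(master, lr=0.1)                               # optimizer steps the fp32 copy

x = torch.randn(4, 32, device="cuda", dtype=torch.half)
model(x).sum().backward()                                           # grads land on the fp16 params

for p, m in zip(model.parameters(), master):
    m.grad = p.grad.float()                                         # accumulate the update in fp32
opt.step()
opt.zero_grad()
model.zero_grad()

with torch.no_grad():
    for p, m in zip(model.parameters(), master):
        p.copy_(m)                                                  # cast updated weights back to fp16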

I'm not too sure how you implemented bfloat16 in this PR, though - I don't see anywhere in the pytorch AMP docs where they mention AMP supporting bfloat16. I'd like to learn how this works.

@zplizzi commented Nov 3, 2021

Ah, I see that you're still using native AMP, just passing in bfloat16 as the fast_dtype instead of the default fp16. So the op whitelist/blacklist is still being applied. And you're also disabling the grad scaling.

I'm honestly not sure whether this is entirely right. The whitelist/blacklist was created specifically with fp16 in mind, and the tradeoffs of bfloat16 are different (more dynamic range, less precision). But it's probably approximately correct.
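
To make the combination concrete, a minimal sketch outside Lightning of "autocast with bfloat16, grad scaling disabled" (assumes torch >= 1.10 and CUDA; not a copy of the plugin code):

import torch

model = torch.nn.Linear(32, 2).cuda()
opt = torch.optim.SGD(model.parameters(), lr=0.1)

use_bf16 = True
amp_dtype = torch.bfloat16 if use_bf16 else torch.float16
# GradScaler guards fp16 against gradient underflow; bf16 shares the fp32
# exponent range, so scaling can simply be disabled.
scaler = torch.cuda.amp.GradScaler(enabled=not use_bf16)

x = torch.randn(4, 32, device="cuda")
with torch.autocast("cuda", dtype=amp_dtype):
    loss = model(x).sum()
scaler.scale(loss).backward()
scaler.step(opt)
scaler.update()
opt.zero_grad()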

@yuvalkirstain commented Mar 13, 2022

@SeanNaren When I use this feature and check the dtype of the model, it seems like the model's precision is fp32 (and I do not see the memory gains I expect). On other frameworks that support bf16 (like fairseq) the model's dtype is torch.bfloat16. Is there a simple example that "proves" that this feature reduces the memory consumption as it should? I suspect that there might be something wrong (but of course, I might be wrong).
Thank you!
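
This observation is consistent with the autocast-based approach discussed above: parameters stay stored in fp32 and only the autocast-eligible ops compute in bf16, so any savings show up in activations rather than in the weights themselves. A minimal sketch of the dtype split (assumes torch >= 1.10 and CUDA):

import torch

model = torch.nn.Linear(32, 2).cuda()
x = torch.randn(4, 32, device="cuda")
with torch.autocast("cuda", dtype=torch.bfloat16):
    out = model(x)

print(model.weight.dtype)  # torch.float32 -- weights are not cast in place
print(out.dtype)           # torch.bfloat16 -- autocast ops produce bf16 activations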

@SeanNaren (Contributor, Author)

@yuvalkirstain do you mind making a new issue for this? It should be something we track!

@yuvalkirstain commented Apr 3, 2022

@SeanNaren sure thing :)
I opened a new issue: #12591

Labels: feature (Is an improvement or enhancement), ready (PRs ready to be merged)
Projects: None yet
Linked issues: support for bf16 in lightning trainer (#8874)