
Metric ddp bugfix #4482

Merged — SkafteNicki merged 19 commits into Lightning-AI:master from metrics/fix_states on Nov 10, 2020

Conversation

@SkafteNicki (Member)

What does this PR do?

Fixes
#3728
#4396
#4361

Metric calculations do not work in DDP mode. Because we register metric states as buffers (using self.register_buffer), on each forward pass the buffers of rank 0 are broadcast to, and overwrite, the buffers of the other ranks, leading to wrong accumulation. Running this simple example shows what goes wrong (thanks to @rotabulo):

import torch

import pytorch_lightning as pl


class SumMetric(pl.metrics.Metric):
    def __init__(self):
        super().__init__(dist_sync_on_step=False)
        self.add_state('sum', default=torch.tensor(0.), dist_reduce_fx='sum')

    def update(self, preds, target=None):
        self.sum += preds.sum()

    def compute(self):
        return self.sum


data = torch.arange(10)[:, None].float()
dataset = torch.utils.data.TensorDataset(data)


class IdentityModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.metric = SumMetric()
        # this is here just to have some fake parameters in the model
        self.params = torch.nn.Conv2d(10, 10, 1)

    def forward(self, x):
        return x

    def training_step(self, batch, batch_idx):
        self.metric(batch[0])
        return None

    def training_epoch_end(self, training_step_outputs):
        print(f"Computed sum: {self.metric.compute()} GPU {torch.distributed.get_rank()}")
        print(f"Actual sum: {data.sum()} GPU {torch.distributed.get_rank()}")

    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            dataset, batch_size=1,
            sampler=torch.utils.data.distributed.DistributedSampler(dataset, shuffle=False))

    def configure_optimizers(self):
        return None


pl.Trainer(gpus=2, max_epochs=1,
           distributed_backend='ddp', progress_bar_refresh_rate=0,
           replace_sampler_ddp=False).fit(IdentityModel())

The correct result would be 45 (the sum of the digits 0-9), but the calculated result is 41.

The solution is to stop registering metric states as buffers, while still keeping the advantages that buffers provide (automatic device placement, inclusion in the state dict). This PR implements that.
Additionally, a .persistent(mode) method is added so that the user can easily decide (after initialization) whether metric states should be added to state dicts.
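
For illustration, here is a minimal sketch of the idea (hypothetical names, not the actual code added by this PR): metric states are kept as plain attributes, but the metric overrides _apply so the states follow the module across devices, and state_dict so they can optionally be saved:

import torch
from torch import nn


class MetricStateSketch(nn.Module):
    """Minimal sketch: states are plain attributes rather than registered
    buffers, so DDP will not broadcast them from rank 0 on every forward pass."""

    def __init__(self):
        super().__init__()
        self._state_names = []
        self._persistent = False  # by default, states are not saved

    def add_state(self, name, default):
        # plain attribute instead of self.register_buffer(name, default)
        setattr(self, name, default)
        self._state_names.append(name)

    def persistent(self, mode: bool = True):
        # let the user decide, after initialization, whether states are saved
        self._persistent = mode

    def _apply(self, fn):
        # move/cast metric states together with the module (.to, .cuda, ...)
        module = super()._apply(fn)
        for name in self._state_names:
            setattr(module, name, fn(getattr(module, name)))
        return module

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        destination = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
        if self._persistent:
            for name in self._state_names:
                destination[prefix + name] = getattr(self, name)
        return destination

Since the states are no longer buffers, DDP's per-iteration rank-0 broadcast does not touch them; synchronization happens only when the metric explicitly reduces its states (dist_reduce_fx).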

I have no idea why our current metric tests do not catch this problem.

Still needs more testing and documentation...

Before submitting

  • Was this discussed/approved via a GitHub issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together? Otherwise, we ask you to create a separate PR for every change.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?
  • Did you verify new and existing tests pass locally with your changes?
  • If you made a notable change (that affects users), did you update the CHANGELOG?

PR review

  • Is this pull request ready for review? (if not, please submit in draft mode)

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@rotabulo left a comment

Overriding _apply should be done with care. There is the risk that an fp16 library such as apex calls .half() or .to(torch.float16) on the LightningModule containing the metric, which reduces the precision of the metric state and thus introduces approximation errors into the metric computation. One could add an assert or a warning to ensure that the data type was not changed within the _apply method.
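
As a rough illustration of that suggestion (hypothetical names, not the PR's actual code), _apply could warn whenever the dtype of a metric state changes:

import warnings

import torch
from torch import nn


class DtypeGuardedMetric(nn.Module):
    """Sketch of the safeguard suggested above: warn if a call like .half()
    silently changes the dtype of a metric state inside _apply."""

    def __init__(self):
        super().__init__()
        self.sum = torch.tensor(0.)  # example metric state, kept in fp32
        self._state_names = ['sum']

    def _apply(self, fn):
        module = super()._apply(fn)
        for name in self._state_names:
            old = getattr(module, name)
            new = fn(old)
            if isinstance(old, torch.Tensor) and new.dtype != old.dtype:
                warnings.warn(
                    f"Metric state '{name}' was cast from {old.dtype} to {new.dtype}; "
                    "this may introduce approximation errors in the metric computation."
                )
            setattr(module, name, new)
        return module


DtypeGuardedMetric().half()  # emits a warning instead of silently degrading the state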

@SkafteNicki (Member, Author)

> Overriding _apply should be done with care. There is the risk that an fp16 library such as apex calls .half() or .to(torch.float16) on the LightningModule containing the metric, which reduces the precision of the metric state and thus introduces approximation errors into the metric computation. One could add an assert or a warning to ensure that the data type was not changed within the _apply method.

I agree that there is the chance of reducing the precision of the metric state, but isn't that a general concern when running fp16?

@rotabulo commented Nov 2, 2020

> Overriding _apply should be done with care. There is the risk that an fp16 library such as apex calls .half() or .to(torch.float16) on the LightningModule containing the metric, which reduces the precision of the metric state and thus introduces approximation errors into the metric computation. One could add an assert or a warning to ensure that the data type was not changed within the _apply method.
>
> I agree that there is the chance of reducing the precision of the metric state, but isn't that a general concern when running fp16?

Yes, but I usually don't use fp16 in the metric, even if I go for fp16 computation in the network, so I would not expect the metric to be cast to fp16 in general. To my knowledge, using native fp16 with Lightning is safe from this perspective because it uses autocast in the forward pass, but if one uses the apex backend, an explicit cast of the model to fp16 might occur.

@teddykoker (Contributor)

I didn't realize buffers got copied across DDP processes... is there any documentation of this? Not sure why this wouldn't have broken the tests...

@rotabulo commented Nov 2, 2020

> I didn't realize buffers got copied across DDP processes... is there any documentation of this? Not sure why this wouldn't have broken the tests...

Here, in a note: https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html?highlight=distributeddataparallel#torch.nn.parallel.DistributedDataParallel

> Parameters are never broadcast between processes. The module performs an all-reduce step on gradients and assumes that they will be modified by the optimizer in all processes in the same way. Buffers (e.g. BatchNorm stats) are broadcast from the module in process of rank 0, to all other replicas in the system in every iteration.
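
For reference, this broadcast is controlled by DistributedDataParallel's broadcast_buffers argument, which defaults to True. A rough sketch, assuming a process group has already been initialised (e.g. inside a DDP worker):

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

# toy module with a registered buffer, standing in for a metric state
model = torch.nn.Linear(4, 4)
model.register_buffer('running_sum', torch.tensor(0.))

# with the default broadcast_buffers=True, rank 0's 'running_sum' is broadcast
# to (and overwrites) the buffer on every other rank at each iteration;
# passing broadcast_buffers=False disables this, at the cost of DDP not
# keeping any buffer in sync
ddp_model = DDP(model, broadcast_buffers=True)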

@rohitgr7 self-requested a review on November 2, 2020, 17:15
@s-rog (Contributor) commented Nov 3, 2020

Are all metrics going to have their states as parameters now by default?

@SkafteNicki (Member, Author)

@s-rog, no, their states will be plain attributes of the metric (so non-trainable).
However, they will behave similarly to PyTorch buffers: they will automatically move to the correct device (when calling .to(...)), and the state will also automatically be added to metric.state_dict().
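
A short usage sketch of the behaviour described above (based on the method names given in this PR description, not verified against the final code), continuing the SumMetric example from the top of the thread:

import torch

metric = SumMetric()  # the metric defined in the example above

# states follow the metric when moving devices, just like buffers would
metric = metric.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))

# states are not written to state_dict() by default; opt in explicitly
metric.persistent(True)
print(metric.state_dict())  # should now also contain the 'sum' state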

@pep8speaks commented Nov 3, 2020

Hello @SkafteNicki! Thanks for updating this PR.

There are currently no PEP 8 issues detected in this Pull Request. Cheers! 🍻

Comment last updated at 2020-11-10 07:36:47 UTC

@codecov bot commented Nov 3, 2020

Codecov Report

Merging #4482 (e1e7935) into master (4f3160b) will decrease coverage by 0%.
The diff coverage is 54%.

@@          Coverage Diff           @@
##           master   #4482   +/-   ##
======================================
- Coverage      93%     93%   -0%     
======================================
  Files         116     116           
  Lines        8837    8855   +18     
======================================
+ Hits         8241    8245    +4     
- Misses        596     610   +14     

@teddykoker (Contributor)

Ah, thanks @rotabulo! I believe the tests must be passing because we never perform a backward step within them, so the buffers are never registered.

pytorch_lightning/metrics/metric.py — review thread (outdated, resolved)
@SkafteNicki changed the title from "[WIP] metric ddp bugfix" to "Metric ddp bugfix" on Nov 9, 2020
@Borda (Member) commented Nov 9, 2020

it seems that it significantly increases the running time

pl_times = [22.039139796979725, 21.899155146093108]
pt_times = [18.828625079942867, 19.016960786888376]

for the failed benchmark benchmarks/test_parity.py::test_pytorch_parity[ParityModuleMNIST-0.8].
Can we profile what the main source of latency is?

@SkafteNicki (Member, Author)

@Borda I don't understand how this could influence the parity test in any way, as it does not use our Metric package.

@SkafteNicki merged commit 465ec75 into Lightning-AI:master on Nov 10, 2020
@SkafteNicki deleted the metrics/fix_states branch on November 10, 2020, 08:16
SeanNaren pushed a commit that referenced this pull request Nov 10, 2020
* changes

* fix spelling

* small note

* trying to fix ddp test

* fix ddp

* fix for test

* suggestion

* CHANGELOG

* Update pytorch_lightning/metrics/metric.py

Co-authored-by: chaton <[email protected]>
Co-authored-by: Sean Naren <[email protected]>
Co-authored-by: Sean Naren <[email protected]>

(cherry picked from commit 465ec75)
SeanNaren pushed a commit that referenced this pull request Nov 11, 2020 (cherry picked from commit 465ec75)

Borda pushed a commit that referenced this pull request Nov 11, 2020 (cherry picked from commit 465ec75)

rohitgr7 pushed a commit that referenced this pull request Nov 21, 2020
Labels: bug (Something isn't working)
Projects: None yet
10 participants