Metric ddp bugfix #4482
Conversation
Overwriting `_apply` should be done with care. There is a risk that an fp16 library such as apex calls `.half()` or `.to(torch.float16)` on the LightningModule containing the metric, reducing the precision of the metric state and thus introducing approximation errors in the metric computation. One could add an assert or warning inside `_apply` to ensure that the data type has not changed.
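A hedged sketch of the suggested guard (hypothetical code, not part of this PR): override `_apply` so that metric states kept as plain attributes are still moved/converted by `nn.Module`, but a dtype change (e.g. from an explicit `.half()` call) is flagged with a warning.

```python
# Hypothetical guard for _apply; names such as GuardedMetric and _state_names are illustrative.
import warnings

import torch
from torch import nn


class GuardedMetric(nn.Module):
    def __init__(self):
        super().__init__()
        # metric state kept as a plain attribute (not a buffer), per this PR's approach
        self._state_names = ["total"]
        self.total = torch.tensor(0.0)

    def _apply(self, fn):
        # let nn.Module move/convert parameters and buffers as usual
        this = super()._apply(fn)
        # then apply fn to the plain-attribute states ourselves and check the dtype
        for name in self._state_names:
            before = getattr(this, name)
            after = fn(before)
            if after.dtype != before.dtype:
                warnings.warn(
                    f"Metric state '{name}' was cast from {before.dtype} to "
                    f"{after.dtype}; this may reduce the precision of the metric."
                )
            setattr(this, name, after)
        return this


metric = GuardedMetric()
metric.half()  # emits a warning: state cast from torch.float32 to torch.float16
```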
I agree that there is a chance of reducing the precision of the metric state, but isn't that a general concern when running fp16?
Yes, but I usually don't use fp16 in the metric, even if I go for fp16 computation in the network, so I would not expect the metric to be cast to fp16 in general. To my knowledge, using native fp16 with Lightning is safe from this perspective because it uses autocast in the forward pass, but if one uses the apex backend, an explicit cast of the model to fp16 might occur.
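For illustration, a minimal sketch of why the explicit cast is the risky case (standard PyTorch behaviour, not code from this PR): `model.half()`, as an apex-style O2 backend may perform, downcasts every floating-point buffer, which is exactly what would happen to a metric state stored that way, whereas autocast in the forward pass never rewrites the module's tensors.

```python
# Illustrative module; the buffer "state" stands in for a metric state.
import torch
from torch import nn


class WithBufferState(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("state", torch.zeros(1))


module = WithBufferState()
print(module.state.dtype)  # torch.float32
module.half()              # explicit cast of the whole module
print(module.state.dtype)  # torch.float16 - the state lost precision
```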
I didn't realize buffers got copied across DDP processes... is there any documentation of this? Not sure why this wouldn't have broken the tests...
Are all metrics going to have their states as parameters now by default?
@s-rog, no, their states will be simple attributes of the metric (so non-trainable).
Hello @SkafteNicki! Thanks for updating this PR. There are currently no PEP 8 issues detected in this Pull Request. Cheers! 🍻 Comment last updated at 2020-11-10 07:36:47 UTC
Codecov Report
@@           Coverage Diff            @@
##           master   #4482    +/-    ##
=========================================
- Coverage      93%      93%     -0%
=========================================
  Files         116      116
  Lines        8837     8855    +18
=========================================
+ Hits         8241     8245     +4
- Misses        596      610    +14
Ah, thanks @rotabulo! I believe tests must be passing as we are never performing a […]
It seems that it significantly increases the running time for the parity test.
@Borda I don't understand how this should influence the parity test in any way, as these tests are not using our Metric package.
* changes
* fix spelling
* small note
* trying to fix ddp test
* fix ddp
* fix for test
* suggestion
* CHANGELOG
* Update pytorch_lightning/metrics/metric.py

Co-authored-by: chaton <[email protected]>
Co-authored-by: Sean Naren <[email protected]>
Co-authored-by: Sean Naren <[email protected]>
(cherry picked from commit 465ec75)
What does this PR do?
Fixes #3728, #4396, #4361
Metric calculations do not work in DDP mode. Because we register metric states as buffers (using `self.register_buffer`), on each forward pass the buffers of rank 0 are copied to, and override, the buffers of the other ranks, leading to wrong accumulation. Running a simple example (thanks to @rotabulo) shows what is wrong: the correct result would be 45 (the sum of the digits 0-9), but the calculated result is 41.
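Since the snippet itself is not reproduced above, here is a hedged reconstruction of such a repro (the class name, data split, and port are illustrative assumptions, not the original code): two processes each feed half of the digits 0-9 into a module whose "metric state" is a registered buffer; because DDP broadcasts buffers from rank 0 at every forward pass, the second rank's partial sum is repeatedly overwritten and the total comes out as 41 instead of 45.

```python
# Illustrative reconstruction of the bug; BufferSum and the rank-based data split are assumptions.
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


class BufferSum(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("total", torch.tensor(0.0))  # metric state stored as a buffer
        self.layer = torch.nn.Linear(1, 1)  # DDP needs at least one parameter

    def forward(self, x):
        self.total += x.sum()  # accumulate the "metric"
        return self.layer(x.view(-1, 1))


def run(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    model = DDP(BufferSum())  # broadcast_buffers=True is the default
    data = torch.arange(10, dtype=torch.float32)[rank::world_size]  # rank 0: 0,2,4,... rank 1: 1,3,5,...
    for value in data:
        model(value.unsqueeze(0))  # buffers are re-broadcast from rank 0 before each forward

    # sum the per-rank states; with buffer broadcasting this prints 41, not 45
    total = model.module.total.clone()
    dist.all_reduce(total)
    if rank == 0:
        print("accumulated sum:", total.item())
    dist.destroy_process_group()


if __name__ == "__main__":
    torch.multiprocessing.spawn(run, args=(2,), nprocs=2)
```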
The solution is to not add metric states as buffers, while still keeping the advantages that buffers have (automatic device placement, inclusion in the state dict). This PR implements that.
Additionally, a `.persistent(mode)` method is added so that the user can easily choose, after initialization, whether metric states should be added to the state dict or not. I have no idea why our current metric tests do not show this problem.
Still need more testing and documentation...
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃