Add TPUPrecisionPlugin #10020
Conversation
- class TPUHalfPrecisionPlugin(PrecisionPlugin):
+ class TPUHalfPrecisionPlugin(TPUPrecisionPlugin):
@SeanNaren why was this not done as:

    class TPUBf16PrecisionPlugin():
        precision: str = "bf16"

    # accelerator connector
    if self.precision == 16:
        raise Unsupported
    elif self.precision == "bf16":
        return TPUBf16PrecisionPlugin

The current format can be confusing for users, perhaps making them believe AMP and TPU work together.
We can easily change this, as this is unreleased.
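A slightly fuller, runnable version of that idea might look like the sketch below. The import paths, the connector helper name, and the error message are my assumptions here, not the merged implementation:

```python
from pytorch_lightning.plugins.precision import TPUPrecisionPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class TPUBf16PrecisionPlugin(TPUPrecisionPlugin):
    """Dedicated bf16 plugin, so AMP-style 16-bit is never implied on TPUs."""

    precision: str = "bf16"


def select_tpu_precision_plugin(precision):
    # hypothetical accelerator-connector branch
    if precision == 16:
        raise MisconfigurationException(
            "precision=16 (AMP) is not supported on TPUs, use precision='bf16' instead"
        )
    if precision == "bf16":
        return TPUBf16PrecisionPlugin()
    return TPUPrecisionPlugin()
```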
Similarly, I don't get why the NativeMixedPrecisionPlugin was used for bf16 support, considering it doesn't even define the scaler. Why wasn't a Bf16PrecisionPlugin used?
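For illustration, a standalone bf16 plugin along the lines this comment suggests could be quite small. The class name and the `forward_context` hook name are assumptions (and it presumes a PyTorch version with bfloat16 autocast support); this is not Lightning's actual implementation:

```python
from contextlib import contextmanager

import torch
from pytorch_lightning.plugins.precision import PrecisionPlugin


class Bf16PrecisionPlugin(PrecisionPlugin):
    """bf16 autocast only: no GradScaler, since bf16 does not need loss scaling."""

    precision: str = "bf16"

    @contextmanager
    def forward_context(self):
        # hypothetical hook: whatever context wraps the forward pass
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            yield
```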
This comment does not block this PR.
I've opened a branch with my ideas: refactor/tpu-precision-plugin...refactor/untangle-bf16
As for the first question, "why do we make it seem precision=16 is possible for TPUs": AFAIK it is because this was the original default before "bf16" was introduced. We could go ahead and deprecate precision=16 for TPUs; however, precision=16 doesn't really make sense for AMP either, so should we also deprecate it there and go for precision='mixed'? Personally I don't think so, and would advise against it.
Regarding the question of having two separate plugins for bf16 and for native AMP, unfortunately I can't find where the discussion around this took place. I originally implemented this as two separate plugins, but there was enough cross-over with the original plugin to keep them in a single plugin. Maybe it makes sense to have them separate; I'm indifferent here.
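For reference, this is roughly how the flag surfaces to users (a sketch only; the exact accepted values depend on the Lightning version):

```python
from pytorch_lightning import Trainer

# bf16 on TPU: resolves to the TPU bf16 precision path
trainer = Trainer(tpu_cores=8, precision="bf16")

# 16 on TPU: historically accepted because it was the only option before
# "bf16" existed, which is the source of the confusion discussed above
trainer = Trainer(tpu_cores=8, precision=16)
```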
LGTM!
small comments
Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
Codecov Report
@@            Coverage Diff            @@
##           master   #10020     +/-   ##
==========================================
+ Coverage      89%      93%      +4%
==========================================
  Files         179      180       +1
  Lines       15810    15852      +42
==========================================
+ Hits        14021    14676     +655
+ Misses       1789     1176     -613
What does this PR do?
Part of #9287
This is necessary because the optimizer_step will be moved to the PrecisionPlugin, but it currently lives in the TPUAccelerator. This PR just introduces the empty class.
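As a rough idea of the shape of the change (not the exact committed code), the new class is essentially an empty subclass for now:

```python
from pytorch_lightning.plugins.precision import PrecisionPlugin


class TPUPrecisionPlugin(PrecisionPlugin):
    """Precision plugin for TPU training; optimizer_step moves here in a follow-up PR."""
```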
Does your PR introduce any breaking changes? If yes, please list them.
None
Before submitting
PR review