DeepSpeed still changes metric states from fp32 to fp16 #1561
Comments
Hi! Thanks for your contribution, great first issue!
@justusschock do you have experience with the internals of DeepSpeed? Is this something we can get around or is this a limitation?
@SkafteNicki I don't have too much experience with DeepSpeed internals, but I'll have a look at why/where this happens.
@justusschock Hi! One thing I noticed is that the
Hey @FarzanT , |
@justusschock

import torch
from pytorch_lightning import LightningModule
from torchmetrics import PearsonCorrCoef, MeanAbsoluteError


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 32)
        self.pearson = PearsonCorrCoef()
        self.mae = MeanAbsoluteError()

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        pred = self.forward(batch)
        loss = pred.sum()
        # update both metric states with the flattened predictions/targets
        self.pearson.update(torch.flatten(pred), torch.flatten(batch))
        self.mae.update(torch.flatten(pred), torch.flatten(batch))
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


model = BoringModel()
print(model.mae.sum_abs_error.dtype)
print(model.pearson.mean_x.dtype)

model = model.half()
print(model.mae.sum_abs_error.dtype)
print(model.pearson.mean_x.dtype)

model = model.float()
print(model.mae.sum_abs_error.dtype)
print(model.pearson.mean_x.dtype)

Output:
Any pointers on how to protect the metric states from this?
@justusschock @SkafteNicki I think I figured out the source of the problem. May I open a pull request?
🐛 Bug
Following #484, PR #493 introduced set_dtype() to prevent .half() calls from changing the precision of metric states. However, at least for PearsonCorrCoef, DeepSpeed still somehow modifies the dtype. During initialization the metric states have the default dtype of torch.float32, but this changes as soon as DeepSpeed is initialized (refer to the code example below).

I tried this with MeanAbsoluteError as well, and interestingly, of its sum_abs_error and total metric states, only sum_abs_error changes from torch.float32 to torch.float16. The total metric state remains torch.int64, so DeepSpeed is probably only converting floats and not ints.

This is especially problematic for epoch-level metrics such as PearsonCorrCoef, since the numbers they hold can easily become larger than the 65,504 maximum finite value representable in fp16.

To Reproduce
Refer to the minimal code sample.
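For context on the overflow risk mentioned above: fp16's largest finite value is 65,504, so a running-sum state that has been cast to half precision saturates to inf as soon as it passes that bound. A minimal check in plain torch:

```python
import torch

# fp16 holds finite values only up to 65504; a half-precision running sum
# overflows to inf once it passes that bound.
running_sum = torch.tensor(60000.0, dtype=torch.float16)
running_sum = running_sum + running_sum  # 120000 > 65504
print(running_sum)  # tensor(inf, dtype=torch.float16)

# the same accumulation in fp32 is fine
print(torch.tensor(60000.0) + torch.tensor(60000.0))  # tensor(120000.)
```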
Code sample
Output on my setup
Expected behavior
DeepSpeed should not affect the dtype of metric states; that is, they should remain torch.float32 even when DeepSpeed is used with precision=16.

Environment
collect_env results
Additional context