🐛 Bug
If you're running a model with torch.cuda.amp.autocast to bf16, you may have model outputs in bf16 and labels in fp32, and then run metric.compute() outside of the autocast. Everything works fine for most (perhaps all other?) metrics, but SpearmanCorrCoef has an explicit check that the dtypes of preds and target are the same (https://github.com/Lightning-AI/metrics/blob/70a844f5aa598035eae50f3268563cfab103c62d/src/torchmetrics/functional/regression/spearman.py#L65). I do not think this check is necessary, because torch will automatically promote one of the tensors when they are multiplied together, which is the only operation between the two tensors that happens while computing Spearman. I may be missing something, but it would be great to remove this explicit check so that code using this metric does not need to explicitly cast the inputs, or to handle the casting inside the metric if it is necessary for some reason.
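As a quick illustration (not part of the original report), PyTorch's standard type-promotion rules already handle the bf16/fp32 mix that the multiplication inside the Spearman computation relies on:

import torch

# PyTorch promotes bfloat16 * float32 to float32 automatically,
# so the element-wise multiplication used by the metric is safe.
a = torch.rand(4, dtype=torch.bfloat16)
b = torch.rand(4, dtype=torch.float32)
print(torch.promote_types(torch.bfloat16, torch.float32))  # torch.float32
print((a * b).dtype)                                       # torch.float32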
To Reproduce
In [1]: import torch
...: from torchmetrics import MeanSquaredError, SpearmanCorrCoef
...:
...: preds = torch.rand((100,), dtype=torch.bfloat16)
...: target = torch.rand((100,), dtype=torch.float)
...: fp32preds = preds.detach().clone().float()
...:
...: sp1 = SpearmanCorrCoef()
...: sp2 = SpearmanCorrCoef()
...:
...: # Spearman update errors
...: sp1.update(preds, target)
/workdisk/danielking/composer_venv/lib/python3.9/site-packages/torchmetrics/utilities/prints.py:36: UserWarning: Metric `SpearmanCorrcoef` will save all targets and predictions in the buffer. For large datasets, this may lead to large memory footprint.
warnings.warn(*args, **kwargs)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-1-162d7ed78d22> in <cell line: 12>()
10
11 # Spearman update errors
---> 12 sp1.update(preds, target)
13 sp2.update(fp32preds, target)
14 print(sp1.compute())
/workdisk/danielking/composer_venv/lib/python3.9/site-packages/torchmetrics/metric.py in wrapped_func(*args, **kwargs)
265 self._update_called = True
266 with torch.set_grad_enabled(self._enable_grad):
--> 267 return update(*args, **kwargs)
268
269 return wrapped_func
/workdisk/danielking/composer_venv/lib/python3.9/site-packages/torchmetrics/regression/spearman.py in update(self, preds, target)
88 target: Ground truth values
89 """
---> 90 preds, target = _spearman_corrcoef_update(preds, target)
91 self.preds.append(preds)
92 self.target.append(target)
/workdisk/danielking/composer_venv/lib/python3.9/site-packages/torchmetrics/functional/regression/spearman.py in _spearman_corrcoef_update(preds, target)
62
63 if preds.dtype != target.dtype:
---> 64 raise TypeError(
65 "Expected `preds` and `target` to have the same data type."
66 f" Got preds: {preds.dtype} and target: {target.dtype}."
TypeError: Expected `preds` and `target` to have the same data type. Got preds: torch.bfloat16 and target: torch.float32.
And if you comment out the dtype check, both metrics agree:
In [1]: import torch
...: from torchmetrics import MeanSquaredError, SpearmanCorrCoef
...:
...: preds = torch.rand((100,), dtype=torch.bfloat16)
...: target = torch.rand((100,), dtype=torch.float)
...: fp32preds = preds.detach().clone().float()
...:
...: sp1 = SpearmanCorrCoef()
...: sp2 = SpearmanCorrCoef()
...:
...: # Spearman update no longer errors with the check removed
...: sp1.update(preds, target)
...: sp2.update(fp32preds, target)
...:
...:
...: print(sp1.compute())
...: print(sp2.compute())
/workdisk/danielking/composer_venv/lib/python3.9/site-packages/torchmetrics/utilities/prints.py:36: UserWarning: Metric `SpearmanCorrcoef` will save all targets and predictions in the buffer. For large datasets, this may lead to large memory footprint.
warnings.warn(*args, **kwargs)
tensor(-0.0699)
tensor(-0.0699)
Code sample
See above
Expected behavior
Spearman computation works with preds and target having different dtypes.
Environment
Checked that it is an issue on 0.10.1, and the check still exists on master (linked above)
Hi @dakinggg, thanks for raising this issue. We can agree that the dtype check is probably too restrictive, but I still think it would be a good idea to check that preds.is_floating_point() == target.is_floating_point() so that we are not mixing float with int values. What do you say to this?
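For reference, a minimal sketch of the kind of relaxed check this suggests (hypothetical helper, not the actual torchmetrics patch):

from torch import Tensor

def _check_same_kind(preds: Tensor, target: Tensor) -> None:
    # Hypothetical relaxed check: only require that both tensors are
    # floating point (or both integral), rather than identical dtypes.
    if preds.is_floating_point() != target.is_floating_point():
        raise TypeError(
            "Expected `preds` and `target` to both be floating point or both be integral."
            f" Got preds: {preds.dtype} and target: {target.dtype}."
        )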
Sure, that seems fair. I think the current SpearmanCorrCoef implementation actually just crashes if you use int right now, because of the .mean() call. Might want to add some check around that too, but just allowing different float types definitely solves my problem.
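To illustrate the int point above (a separate minimal repro, not from the original thread), .mean() is not defined for integer dtypes, so integer inputs would fail later in the computation regardless of the dtype check:

import torch

t = torch.arange(5)  # dtype torch.int64
try:
    t.mean()         # raises RuntimeError because mean() requires a floating point (or complex) dtype
except RuntimeError as err:
    print(err)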