
SpearmanCorrCoef has unnecessary explicit dtype check #1294

Closed
dakinggg opened this issue Oct 26, 2022 · 3 comments · Fixed by #1303
Labels: bug / fix (Something isn't working), help wanted (Extra attention is needed)

Comments

@dakinggg

🐛 Bug

If you're running a model with torch.cuda.amp.autocast to bf16, you may have model outputs in bf16 and labels in fp32, and then run metric.compute() outside of the autocast. This works fine for most (perhaps all other?) metrics, but SpearmanCorrCoef has an explicit check that the dtypes of preds and target are the same (https://github.com/Lightning-AI/metrics/blob/70a844f5aa598035eae50f3268563cfab103c62d/src/torchmetrics/functional/regression/spearman.py#L65). I do not think this check is necessary: the only operation between the two tensors in the Spearman computation is a multiplication, and torch will automatically promote one of them there. I may be missing something, but it would be great to either remove this explicit check so that code using this metric does not need to cast its inputs explicitly, or to handle the casting inside the metric if it is necessary for some reason.
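For context, a minimal sketch (not part of the repro below) of the type promotion being relied on:

import torch

# Illustration (assumption: this mirrors the only cross-dtype op in the Spearman
# computation): multiplying a bf16 tensor with an fp32 tensor is promoted
# automatically, so no explicit cast is needed for the math itself.
a = torch.rand(4, dtype=torch.bfloat16)
b = torch.rand(4, dtype=torch.float32)
print((a * b).dtype)  # torch.float32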

To Reproduce

In [1]: import torch
   ...: from torchmetrics import MeanSquaredError, SpearmanCorrCoef
   ...: 
   ...: preds = torch.rand((100,), dtype=torch.bfloat16)
   ...: target = torch.rand((100,), dtype=torch.float)
   ...: fp32preds = preds.detach().clone().float()
   ...: 
   ...: sp1 = SpearmanCorrCoef()
   ...: sp2 = SpearmanCorrCoef()
   ...: 
   ...: # Spearman update errors
   ...: sp1.update(preds, target)
/workdisk/danielking/composer_venv/lib/python3.9/site-packages/torchmetrics/utilities/prints.py:36: UserWarning: Metric `SpearmanCorrcoef` will save all targets and predictions in the buffer. For large datasets, this may lead to large memory footprint.
  warnings.warn(*args, **kwargs)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-1-162d7ed78d22> in <cell line: 12>()
     10 
     11 # Spearman update errors
---> 12 sp1.update(preds, target)
     13 sp2.update(fp32preds, target)
     14 print(sp1.compute())

/workdisk/danielking/composer_venv/lib/python3.9/site-packages/torchmetrics/metric.py in wrapped_func(*args, **kwargs)
    265             self._update_called = True
    266             with torch.set_grad_enabled(self._enable_grad):
--> 267                 return update(*args, **kwargs)
    268 
    269         return wrapped_func

/workdisk/danielking/composer_venv/lib/python3.9/site-packages/torchmetrics/regression/spearman.py in update(self, preds, target)
     88             target: Ground truth values
     89         """
---> 90         preds, target = _spearman_corrcoef_update(preds, target)
     91         self.preds.append(preds)
     92         self.target.append(target)

/workdisk/danielking/composer_venv/lib/python3.9/site-packages/torchmetrics/functional/regression/spearman.py in _spearman_corrcoef_update(preds, target)
     62 
     63     if preds.dtype != target.dtype:
---> 64         raise TypeError(
     65             "Expected `preds` and `target` to have the same data type."
     66             f" Got preds: {preds.dtype} and target: {target.dtype}."

TypeError: Expected `preds` and `target` to have the same data type. Got preds: torch.bfloat16 and target: torch.float32.

and if you comment out the dtype check:

In [1]: import torch
   ...: from torchmetrics import MeanSquaredError, SpearmanCorrCoef
   ...: 
   ...: preds = torch.rand((100,), dtype=torch.bfloat16)
   ...: target = torch.rand((100,), dtype=torch.float)
   ...: fp32preds = preds.detach().clone().float()
   ...: 
   ...: sp1 = SpearmanCorrCoef()
   ...: sp2 = SpearmanCorrCoef()
   ...: 
   ...: # Spearman update errors
   ...: sp1.update(preds, target)
   ...: sp2.update(fp32preds, target)
   ...: 
   ...: 
   ...: print(sp1.compute())
   ...: print(sp2.compute())
/workdisk/danielking/composer_venv/lib/python3.9/site-packages/torchmetrics/utilities/prints.py:36: UserWarning: Metric `SpearmanCorrcoef` will save all targets and predictions in the buffer. For large datasets, this may lead to large memory footprint.
  warnings.warn(*args, **kwargs)
tensor(-0.0699)
tensor(-0.0699)

Code sample

See above

Expected behavior

Spearman computation works with preds and target having different dtype.

Environment

Confirmed that this is an issue on 0.10.1, and the check still exists on master (linked above).

@dakinggg dakinggg added bug / fix Something isn't working help wanted Extra attention is needed labels Oct 26, 2022
@github-actions

Hi! Thanks for your contribution, great first issue!

@SkafteNicki
Member

Hi @dakinggg, thanks for raising this issue. We can agree that the dtype check is probably too restrictive, but I still think it would be a good idea to check that preds.is_floating_point() == target.is_floating_point() so that we are not mixing float with int values. What do you say to this?
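A rough sketch of what such a relaxed check might look like (the helper name _check_dtype_kind is hypothetical; the actual change is whatever landed in #1303):

import torch
from torch import Tensor

# Hypothetical relaxed check: instead of requiring identical dtypes, only reject
# the case where one tensor is floating point and the other is not.
def _check_dtype_kind(preds: Tensor, target: Tensor) -> None:
    if preds.is_floating_point() != target.is_floating_point():
        raise TypeError(
            "Expected `preds` and `target` to both be floating point (or both integral)"
            f" tensors, but got preds: {preds.dtype} and target: {target.dtype}."
        )

# bf16 preds with fp32 target would now pass; int vs. float would still raise.
_check_dtype_kind(torch.rand(4, dtype=torch.bfloat16), torch.rand(4))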

@dakinggg
Author

Sure, that seems fair. I think the current SpearmanCorrCoef implementation actually just crashes if you use int inputs right now because of the .mean() call, so it might be worth adding a check around that too, but just allowing different float types definitely solves my problem.
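A small example of that separate failure mode (an illustration only, assuming the .mean() call is indeed what breaks on integer inputs):

import torch

# torch only defines mean() for floating point (and complex) dtypes, so calling
# it on an integer tensor raises a RuntimeError.
ints = torch.arange(5)  # dtype: torch.int64
try:
    ints.mean()
except RuntimeError as err:
    print(f"mean() on an int tensor fails: {err}")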

@SkafteNicki SkafteNicki added this to the v0.10 milestone Oct 31, 2022