Skip to content

Commit

Permalink
Fix autocast with spearman metric (#1303)
Browse files Browse the repository at this point in the history
(cherry picked from commit 604ed80)
  • Loading branch information
SkafteNicki authored and Borda committed Oct 31, 2022
1 parent dccd432 commit 75df8e0
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 3 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed precision problems when `structural_similarity_index_measure` was used with autocast ([#1291](https://github.com/Lightning-AI/metrics/pull/1291))


- Fixed restrictive dtype checking in `spearman_corrcoef` when used with autocast ([#1303](https://github.com/Lightning-AI/metrics/pull/1303))


## [0.10.1] - 2022-10-21

### Fixed
Expand Down
5 changes: 2 additions & 3 deletions src/torchmetrics/functional/regression/spearman.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,9 @@ def _spearman_corrcoef_update(preds: Tensor, target: Tensor, num_outputs: int) -
target: Ground truth tensor
"""

if preds.dtype != target.dtype:
if not (preds.is_floating_point() and target.is_floating_point()):
raise TypeError(
"Expected `preds` and `target` to have the same data type."
f" Got preds: {preds.dtype} and target: {target.dtype}."
"Expected `preds` and `target` both to be floating point tensors, but got {pred.dtype} and {target.dtype}"
)
_check_same_shape(preds, target)
if preds.ndim > 2 or target.ndim > 2:
Expand Down
3 changes: 3 additions & 0 deletions tests/unittests/regression/test_spearman.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,9 @@ def test_spearman_corrcoef_half_gpu(self, preds, target):

def test_error_on_different_shape():
metric = SpearmanCorrCoef(num_outputs=1)
with pytest.raises(TypeError, match="Expected `preds` and `target` both to be floating point tensors.*"):
metric(torch.randint(5, (100,)), torch.randn(100))

with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same shape"):
metric(torch.randn(100), torch.randn(50))

Expand Down

0 comments on commit 75df8e0

Please sign in to comment.