diff --git a/CHANGELOG.md b/CHANGELOG.md index fcf702137b3..500410552de 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added pre-gather reduction in the case of `dist_reduce_fx="cat"` to reduce communication cost ([#217](https://github.com/PyTorchLightning/metrics/pull/217)) +- Added better error message for `AUROC` when `num_classes` is not provided for multiclass input ([#244](https://github.com/PyTorchLightning/metrics/pull/244)) + + - Added support for unnormalized scores (e.g. logits) in `Accuracy`, `Precision`, `Recall`, `FBeta`, `F1`, `StatScore`, `Hamming`, `ConfusionMatrix` metrics ([#200](https://github.com/PyTorchLightning/metrics/pull/200)) diff --git a/tests/classification/test_auroc.py b/tests/classification/test_auroc.py index d3196e1ef9d..32edc04e10f 100644 --- a/tests/classification/test_auroc.py +++ b/tests/classification/test_auroc.py @@ -188,3 +188,11 @@ def test_error_on_different_mode(): with pytest.raises(ValueError, match=r"The mode of data.* should be constant.*"): # pass in multi-label data metric.update(torch.rand(10, 5), torch.randint(0, 2, (10, 5))) + + +def test_error_multiclass_no_num_classes(): + with pytest.raises( + ValueError, + match="Detected input to ``multiclass`` but you did not provide ``num_classes`` argument" + ): + _ = auroc(torch.randn(20, 3).softmax(dim=-1), torch.randint(3, (20, ))) diff --git a/torchmetrics/functional/classification/auroc.py b/torchmetrics/functional/classification/auroc.py index c8a92737079..7dc0739f0fb 100644 --- a/torchmetrics/functional/classification/auroc.py +++ b/torchmetrics/functional/classification/auroc.py @@ -85,6 +85,8 @@ def _auroc_compute( fpr = [o[0] for o in output] tpr = [o[1] for o in output] else: + if mode != 'binary' and num_classes is None: + raise ValueError('Detected input to ``multiclass`` but you did not provide ``num_classes`` argument') fpr, tpr, _ = roc(preds, target, num_classes, pos_label, sample_weights) # calculate standard roc auc score