Skip to content

Commit

Permalink
Remove error in ROC and AUROC (#583)
Browse files Browse the repository at this point in the history
Co-authored-by: Jirka Borovec <[email protected]>
  • Loading branch information
SkafteNicki and Borda authored Oct 27, 2021
1 parent e8ad72e commit a23916b
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `half`, `double`, `float` will no longer change the dtype of the metric states. Use `metric.set_dtype` instead ([#493](https://github.com/PyTorchLightning/metrics/pull/493))
- Renamed `AverageMeter` to `MeanMetric` ([#506](https://github.com/PyTorchLightning/metrics/pull/506))
- Changed `is_differentiable` from property to a constant attribute ([#551](https://github.com/PyTorchLightning/metrics/pull/551))
- `ROC` and `AUROC` will no longer throw an error when either the positive or negative class is missing. Instead, a score of 0 is returned together with a warning ([#583](https://github.com/PyTorchLightning/metrics/pull/583))

### Deprecated

Expand Down
21 changes: 21 additions & 0 deletions tests/classification/test_auroc.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,3 +216,24 @@ def test_weighted_with_empty_classes():
target = torch.zeros_like(target)
with pytest.raises(ValueError, match="Found 1 non-empty class in `multiclass` AUROC calculation"):
_ = auroc(preds, target, average="weighted", num_classes=num_classes + 1)


def test_warnings_on_missing_class():
    """Test that a warning is given if either the positive or negative class is missing."""
    metric = AUROC()

    # All-zero targets: no positive samples, so the score is meaningless and should be 0.
    expected_warning = (
        "No positive samples in targets, true positive value should be meaningless."
        " Returning zero tensor in true positive score"
    )
    with pytest.warns(UserWarning, match=expected_warning):
        result = metric(torch.randn(10).sigmoid(), torch.zeros(10).int())
    assert result == 0

    # All-one targets: no negative samples, same expectation for the opposite class.
    expected_warning = (
        "No negative samples in targets, false positive value should be meaningless."
        " Returning zero tensor in false positive score"
    )
    with pytest.warns(UserWarning, match=expected_warning):
        result = metric(torch.randn(10).sigmoid(), torch.ones(10).int())
    assert result == 0
21 changes: 21 additions & 0 deletions tests/classification/test_roc.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,3 +144,24 @@ def test_roc_curve(pred, target, expected_tpr, expected_fpr):
assert fpr.size(0) == thresh.size(0)
assert torch.allclose(fpr, tensor(expected_fpr).to(fpr))
assert torch.allclose(tpr, tensor(expected_tpr).to(tpr))


def test_warnings_on_missing_class():
    """Test that a warning is given if either the positive or negative class is missing."""
    metric = ROC()

    # All-zero targets: no positive samples, so tpr should come back as all zeros.
    expected_warning = (
        "No positive samples in targets, true positive value should be meaningless."
        " Returning zero tensor in true positive score"
    )
    with pytest.warns(UserWarning, match=expected_warning):
        _, tpr, _ = metric(torch.randn(10).sigmoid(), torch.zeros(10))
    assert all(tpr == 0)

    # All-one targets: no negative samples, so fpr should come back as all zeros.
    expected_warning = (
        "No negative samples in targets, false positive value should be meaningless."
        " Returning zero tensor in false positive score"
    )
    with pytest.warns(UserWarning, match=expected_warning):
        fpr, _, _ = metric(torch.randn(10).sigmoid(), torch.ones(10))
    assert all(fpr == 0)
5 changes: 5 additions & 0 deletions torchmetrics/classification/auroc.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ class AUROC(Metric):
dimension more than the ``target`` tensor the input will be interpreted as
multiclass.
.. note::
If either the positive class or negative class is completely missing in the target tensor,
the auroc score is meaningless in this case and a score of 0 will be returned together
with a warning.
Args:
num_classes: integer with number of classes for multi-label and multiclass problems.
Should be set to ``None`` for binary problems
Expand Down
5 changes: 5 additions & 0 deletions torchmetrics/classification/roc.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ class ROC(Metric):
- ``target`` (long tensor): ``(N, ...)`` or ``(N, C, ...)`` with integer labels
.. note::
If either the positive class or negative class is completely missing in the target tensor,
the roc values are not well defined in this case and a tensor of zeros will be returned (either fpr
or tpr depending on what class is missing) together with a warning.
Args:
num_classes: integer with number of classes for multi-label and multiclass problems.
Should be set to ``None`` for binary problems
Expand Down
10 changes: 10 additions & 0 deletions torchmetrics/functional/classification/auroc.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,16 @@ def auroc(
) -> Tensor:
"""Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_)
For non-binary input, if the ``preds`` and ``target`` tensor have the same
size the input will be interpreted as multilabel and if ``preds`` have one
dimension more than the ``target`` tensor the input will be interpreted as
multiclass.
.. note::
If either the positive class or negative class is completely missing in the target tensor,
the auroc score is meaningless in this case and a score of 0 will be returned together
with a warning.
Args:
preds: predictions from model (logits or probabilities)
target: Ground truth labels
Expand Down
26 changes: 22 additions & 4 deletions torchmetrics/functional/classification/roc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
_binary_clf_curve,
_precision_recall_curve_update,
)
from torchmetrics.utilities import rank_zero_warn


def _roc_update(
Expand Down Expand Up @@ -73,12 +74,24 @@ def _roc_compute_single_class(
thresholds = torch.cat([thresholds[0][None] + 1, thresholds])

if fps[-1] <= 0:
raise ValueError("No negative samples in targets, false positive value should be meaningless")
fpr = fps / fps[-1]
rank_zero_warn(
"No negative samples in targets, false positive value should be meaningless."
" Returning zero tensor in false positive score",
UserWarning,
)
fpr = torch.zeros_like(thresholds)
else:
fpr = fps / fps[-1]

if tps[-1] <= 0:
raise ValueError("No positive samples in targets, true positive value should be meaningless")
tpr = tps / tps[-1]
rank_zero_warn(
"No positive samples in targets, true positive value should be meaningless."
" Returning zero tensor in true positive score",
UserWarning,
)
tpr = torch.zeros_like(thresholds)
else:
tpr = tps / tps[-1]

return fpr, tpr, thresholds

Expand Down Expand Up @@ -196,6 +209,11 @@ def roc(
"""Computes the Receiver Operating Characteristic (ROC). Works with both binary, multiclass and multilabel
input.
.. note::
If either the positive class or negative class is completely missing in the target tensor,
the roc values are not well defined in this case and a tensor of zeros will be returned (either fpr
or tpr depending on what class is missing) together with a warning.
Args:
preds: predictions from model (logits or probabilities)
target: ground truth values
Expand Down

0 comments on commit a23916b

Please sign in to comment.