From 17a4920f17f19b4bd175da3b804a84c4c48f827c Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Thu, 15 Jun 2023 22:12:43 +0200 Subject: [PATCH] Fix bug in macro average for a number of classification metrics (#1821) Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> Co-authored-by: Jirka --- CHANGELOG.md | 3 + src/torchmetrics/classification/f_beta.py | 4 +- .../classification/precision_recall.py | 4 +- .../classification/specificity.py | 4 +- .../functional/classification/accuracy.py | 7 +- .../functional/classification/f_beta.py | 10 ++- .../functional/classification/hamming.py | 7 +- .../functional/classification/jaccard.py | 5 +- .../classification/precision_recall.py | 16 +++-- .../functional/classification/specificity.py | 10 ++- src/torchmetrics/utilities/compute.py | 16 ++++- tests/unittests/classification/inputs.py | 65 +++++++++++++------ .../unittests/classification/test_accuracy.py | 31 +++++++-- tests/unittests/classification/test_f_beta.py | 23 +++++-- .../classification/test_hamming_distance.py | 16 +++-- .../unittests/classification/test_jaccard.py | 25 +++++++ .../classification/test_precision_recall.py | 38 +++++++++-- .../classification/test_specificity.py | 27 ++++++-- 18 files changed, 228 insertions(+), 83 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index eae22a7eb84..6e355c900ad 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -218,6 +218,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed support for half precision in `PearsonCorrCoef` ([#1819](https://github.com/Lightning-AI/torchmetrics/pull/1819)) +- Fixed number of bugs related to `average="macro"` in classification metrics ([#1821](https://github.com/Lightning-AI/torchmetrics/pull/1821)) + + ## [0.11.4] - 2023-03-10 ### Fixed diff --git a/src/torchmetrics/classification/f_beta.py b/src/torchmetrics/classification/f_beta.py index 2ff0f47e749..52710cda19b 100644 --- a/src/torchmetrics/classification/f_beta.py +++ b/src/torchmetrics/classification/f_beta.py @@ -488,7 +488,9 @@ def __init__( def compute(self) -> Tensor: """Compute metric.""" tp, fp, tn, fn = self._final_state() - return _fbeta_reduce(tp, fp, tn, fn, self.beta, average=self.average, multidim_average=self.multidim_average) + return _fbeta_reduce( + tp, fp, tn, fn, self.beta, average=self.average, multidim_average=self.multidim_average, multilabel=True + ) def plot( self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None diff --git a/src/torchmetrics/classification/precision_recall.py b/src/torchmetrics/classification/precision_recall.py index 1aeebf3bd1f..27610abd2fb 100644 --- a/src/torchmetrics/classification/precision_recall.py +++ b/src/torchmetrics/classification/precision_recall.py @@ -405,7 +405,7 @@ def compute(self) -> Tensor: """Compute metric.""" tp, fp, tn, fn = self._final_state() return _precision_recall_reduce( - "precision", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average + "precision", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average, multilabel=True ) def plot( @@ -819,7 +819,7 @@ def compute(self) -> Tensor: """Compute metric.""" tp, fp, tn, fn = self._final_state() return _precision_recall_reduce( - "recall", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average + "recall", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average, multilabel=True ) def plot( diff --git 
a/src/torchmetrics/classification/specificity.py b/src/torchmetrics/classification/specificity.py index 456d948b910..283e82e001d 100644 --- a/src/torchmetrics/classification/specificity.py +++ b/src/torchmetrics/classification/specificity.py @@ -383,7 +383,9 @@ class MultilabelSpecificity(MultilabelStatScores): def compute(self) -> Tensor: """Compute metric.""" tp, fp, tn, fn = self._final_state() - return _specificity_reduce(tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average) + return _specificity_reduce( + tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average, multilabel=True + ) def plot( self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None diff --git a/src/torchmetrics/functional/classification/accuracy.py b/src/torchmetrics/functional/classification/accuracy.py index 61bf8b8b860..efd3686ad6a 100644 --- a/src/torchmetrics/functional/classification/accuracy.py +++ b/src/torchmetrics/functional/classification/accuracy.py @@ -31,7 +31,7 @@ _multilabel_stat_scores_tensor_validation, _multilabel_stat_scores_update, ) -from torchmetrics.utilities.compute import _safe_divide +from torchmetrics.utilities.compute import _adjust_weights_safe_divide, _safe_divide from torchmetrics.utilities.enums import ClassificationTask @@ -83,10 +83,7 @@ def _accuracy_reduce( return _safe_divide(tp, tp + fn) score = _safe_divide(tp + tn, tp + tn + fp + fn) if multilabel else _safe_divide(tp, tp + fn) - if average is None or average == "none": - return score - weights = tp + fn if average == "weighted" else torch.ones_like(score) - return _safe_divide(weights * score, weights.sum(-1, keepdim=True)).sum(-1) + return _adjust_weights_safe_divide(score, average, multilabel, tp, fp, fn) def binary_accuracy( diff --git a/src/torchmetrics/functional/classification/f_beta.py b/src/torchmetrics/functional/classification/f_beta.py index e04a58e31fe..4bbab58ef12 100644 --- a/src/torchmetrics/functional/classification/f_beta.py +++ b/src/torchmetrics/functional/classification/f_beta.py @@ -31,7 +31,7 @@ _multilabel_stat_scores_tensor_validation, _multilabel_stat_scores_update, ) -from torchmetrics.utilities.compute import _safe_divide +from torchmetrics.utilities.compute import _adjust_weights_safe_divide, _safe_divide from torchmetrics.utilities.enums import ClassificationTask @@ -43,6 +43,7 @@ def _fbeta_reduce( beta: float, average: Optional[Literal["binary", "micro", "macro", "weighted", "none"]], multidim_average: Literal["global", "samplewise"] = "global", + multilabel: bool = False, ) -> Tensor: beta2 = beta**2 if average == "binary": @@ -54,10 +55,7 @@ def _fbeta_reduce( return _safe_divide((1 + beta2) * tp, (1 + beta2) * tp + beta2 * fn + fp) fbeta_score = _safe_divide((1 + beta2) * tp, (1 + beta2) * tp + beta2 * fn + fp) - if average is None or average == "none": - return fbeta_score - weights = tp + fn if average == "weighted" else torch.ones_like(fbeta_score) - return _safe_divide(weights * fbeta_score, weights.sum(-1, keepdim=True)).sum(-1) + return _adjust_weights_safe_divide(fbeta_score, average, multilabel, tp, fp, fn) def _binary_fbeta_score_arg_validation( @@ -375,7 +373,7 @@ def multilabel_fbeta_score( _multilabel_stat_scores_tensor_validation(preds, target, num_labels, multidim_average, ignore_index) preds, target = _multilabel_stat_scores_format(preds, target, num_labels, threshold, ignore_index) tp, fp, tn, fn = _multilabel_stat_scores_update(preds, target, multidim_average) - return _fbeta_reduce(tp, fp, tn, fn, 
beta, average=average, multidim_average=multidim_average) + return _fbeta_reduce(tp, fp, tn, fn, beta, average=average, multidim_average=multidim_average, multilabel=True) def binary_f1_score( diff --git a/src/torchmetrics/functional/classification/hamming.py b/src/torchmetrics/functional/classification/hamming.py index 844b318359b..7e9895baa70 100644 --- a/src/torchmetrics/functional/classification/hamming.py +++ b/src/torchmetrics/functional/classification/hamming.py @@ -31,7 +31,7 @@ _multilabel_stat_scores_tensor_validation, _multilabel_stat_scores_update, ) -from torchmetrics.utilities.compute import _safe_divide +from torchmetrics.utilities.compute import _adjust_weights_safe_divide, _safe_divide from torchmetrics.utilities.enums import ClassificationTask @@ -80,10 +80,7 @@ def _hamming_distance_reduce( return 1 - _safe_divide(tp, tp + fn) score = 1 - _safe_divide(tp + tn, tp + tn + fp + fn) if multilabel else 1 - _safe_divide(tp, tp + fn) - if average is None or average == "none": - return score - weights = tp + fn if average == "weighted" else torch.ones_like(score) - return _safe_divide(weights * score, weights.sum(-1, keepdim=True)).sum(-1) + return _adjust_weights_safe_divide(score, average, multilabel, tp, fp, fn) def binary_hamming_distance( diff --git a/src/torchmetrics/functional/classification/jaccard.py b/src/torchmetrics/functional/classification/jaccard.py index 6cc9e6aaf5d..59c946ffe27 100644 --- a/src/torchmetrics/functional/classification/jaccard.py +++ b/src/torchmetrics/functional/classification/jaccard.py @@ -66,7 +66,8 @@ def _jaccard_index_reduce( return confmat[1, 1] / (confmat[0, 1] + confmat[1, 0] + confmat[1, 1]) ignore_index_cond = ignore_index is not None and 0 <= ignore_index <= confmat.shape[0] - if confmat.ndim == 3: # multilabel + multilabel = confmat.ndim == 3 + if multilabel: num = confmat[:, 1, 1] denom = confmat[:, 1, 1] + confmat[:, 0, 1] + confmat[:, 1, 0] else: # multiclass @@ -87,6 +88,8 @@ def _jaccard_index_reduce( weights = torch.ones_like(jaccard) if ignore_index_cond: weights[ignore_index] = 0.0 + if not multilabel: + weights[confmat.sum(1) + confmat.sum(0) == 0] = 0.0 return ((weights * jaccard) / weights.sum()).sum() diff --git a/src/torchmetrics/functional/classification/precision_recall.py b/src/torchmetrics/functional/classification/precision_recall.py index 9718a75f97c..5b37fc1ab1c 100644 --- a/src/torchmetrics/functional/classification/precision_recall.py +++ b/src/torchmetrics/functional/classification/precision_recall.py @@ -31,7 +31,7 @@ _multilabel_stat_scores_tensor_validation, _multilabel_stat_scores_update, ) -from torchmetrics.utilities.compute import _safe_divide +from torchmetrics.utilities.compute import _adjust_weights_safe_divide, _safe_divide from torchmetrics.utilities.enums import ClassificationTask @@ -43,6 +43,7 @@ def _precision_recall_reduce( fn: Tensor, average: Optional[Literal["binary", "micro", "macro", "weighted", "none"]], multidim_average: Literal["global", "samplewise"] = "global", + multilabel: bool = False, ) -> Tensor: different_stat = fp if stat == "precision" else fn # this is what differs between the two scores if average == "binary": @@ -54,10 +55,7 @@ def _precision_recall_reduce( return _safe_divide(tp, tp + different_stat) score = _safe_divide(tp, tp + different_stat) - if average is None or average == "none": - return score - weights = tp + fn if average == "weighted" else torch.ones_like(score) - return _safe_divide(weights * score, weights.sum(-1, keepdim=True)).sum(-1) + return 
_adjust_weights_safe_divide(score, average, multilabel, tp, fp, fn) def binary_precision( @@ -336,7 +334,9 @@ def multilabel_precision( _multilabel_stat_scores_tensor_validation(preds, target, num_labels, multidim_average, ignore_index) preds, target = _multilabel_stat_scores_format(preds, target, num_labels, threshold, ignore_index) tp, fp, tn, fn = _multilabel_stat_scores_update(preds, target, multidim_average) - return _precision_recall_reduce("precision", tp, fp, tn, fn, average=average, multidim_average=multidim_average) + return _precision_recall_reduce( + "precision", tp, fp, tn, fn, average=average, multidim_average=multidim_average, multilabel=True + ) def binary_recall( @@ -615,7 +615,9 @@ def multilabel_recall( _multilabel_stat_scores_tensor_validation(preds, target, num_labels, multidim_average, ignore_index) preds, target = _multilabel_stat_scores_format(preds, target, num_labels, threshold, ignore_index) tp, fp, tn, fn = _multilabel_stat_scores_update(preds, target, multidim_average) - return _precision_recall_reduce("recall", tp, fp, tn, fn, average=average, multidim_average=multidim_average) + return _precision_recall_reduce( + "recall", tp, fp, tn, fn, average=average, multidim_average=multidim_average, multilabel=True + ) def precision( diff --git a/src/torchmetrics/functional/classification/specificity.py b/src/torchmetrics/functional/classification/specificity.py index 2096b3ff4ad..472f0efe038 100644 --- a/src/torchmetrics/functional/classification/specificity.py +++ b/src/torchmetrics/functional/classification/specificity.py @@ -31,7 +31,7 @@ _multilabel_stat_scores_tensor_validation, _multilabel_stat_scores_update, ) -from torchmetrics.utilities.compute import _safe_divide +from torchmetrics.utilities.compute import _adjust_weights_safe_divide, _safe_divide from torchmetrics.utilities.enums import ClassificationTask @@ -42,6 +42,7 @@ def _specificity_reduce( fn: Tensor, average: Optional[Literal["binary", "micro", "macro", "weighted", "none"]], multidim_average: Literal["global", "samplewise"] = "global", + multilabel: bool = False, ) -> Tensor: if average == "binary": return _safe_divide(tn, tn + fp) @@ -51,10 +52,7 @@ def _specificity_reduce( return _safe_divide(tn, tn + fp) specificity_score = _safe_divide(tn, tn + fp) - if average is None or average == "none": - return specificity_score - weights = tp + fn if average == "weighted" else torch.ones_like(specificity_score) - return _safe_divide(weights * specificity_score, weights.sum(-1, keepdim=True)).sum(-1) + return _adjust_weights_safe_divide(specificity_score, average, multilabel, tp, fp, fn) def binary_specificity( @@ -333,7 +331,7 @@ def multilabel_specificity( _multilabel_stat_scores_tensor_validation(preds, target, num_labels, multidim_average, ignore_index) preds, target = _multilabel_stat_scores_format(preds, target, num_labels, threshold, ignore_index) tp, fp, tn, fn = _multilabel_stat_scores_update(preds, target, multidim_average) - return _specificity_reduce(tp, fp, tn, fn, average=average, multidim_average=multidim_average) + return _specificity_reduce(tp, fp, tn, fn, average=average, multidim_average=multidim_average, multilabel=True) def specificity( diff --git a/src/torchmetrics/utilities/compute.py b/src/torchmetrics/utilities/compute.py index 4213af8f46c..71a49057706 100644 --- a/src/torchmetrics/utilities/compute.py +++ b/src/torchmetrics/utilities/compute.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple +from typing import Optional, Tuple import torch from torch import Tensor @@ -55,6 +55,20 @@ def _safe_divide(num: Tensor, denom: Tensor) -> Tensor: return num / denom +def _adjust_weights_safe_divide( + score: Tensor, average: Optional[str], multilabel: bool, tp: Tensor, fp: Tensor, fn: Tensor +) -> Tensor: + if average is None or average == "none": + return score + if average == "weighted": + weights = tp + fn + else: + weights = torch.ones_like(score) + if not multilabel: + weights[tp + fp + fn == 0] = 0.0 + return _safe_divide(weights * score, weights.sum(-1, keepdim=True)).sum(-1) + + def _auc_format_inputs(x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]: """Check that auc input is correct.""" x = x.squeeze() if x.ndim > 1 else x diff --git a/tests/unittests/classification/inputs.py b/tests/unittests/classification/inputs.py index 14d35fb723d..7d1ae3be8f8 100644 --- a/tests/unittests/classification/inputs.py +++ b/tests/unittests/classification/inputs.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections import namedtuple +from typing import Any import pytest import torch @@ -78,85 +79,107 @@ def _logsoftmax(x: Tensor, dim: int = -1) -> Tensor: preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), ), - id="input[single dim-labels]", + id="input[single_dim-labels]", ), pytest.param( Input(preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE))), - id="input[single dim-probs]", + id="input[single_dim-probs]", ), pytest.param( Input( preds=_inv_sigmoid(torch.rand(NUM_BATCHES, BATCH_SIZE)), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), ), - id="input[single dim-logits]", + id="input[single_dim-logits]", ), pytest.param( Input( preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), ), - id="input[multi dim-labels]", + id="input[multi_dim-labels]", ), pytest.param( Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), ), - id="input[multi dim-probs]", + id="input[multi_dim-probs]", ), pytest.param( Input( preds=_inv_sigmoid(torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), ), - id="input[multi dim-logits]", + id="input[multi_dim-logits]", ), ) +def _multiclass_with_missing_class(*shape: Any, num_classes=NUM_CLASSES): + """Generate multiclass input where a class is missing. 
+ + Args: + shape: shape of the tensor + num_classes: number of classes + + Returns: + tensor with missing classes + """ + x = torch.randint(0, num_classes, shape) + x[x == 0] = 2 + return x + + _multiclass_cases = ( pytest.param( Input( preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), ), - id="input[single dim-labels]", + id="input[single_dim-labels]", ), pytest.param( Input( preds=torch.randn(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES).softmax(-1), target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), ), - id="input[single dim-probs]", + id="input[single_dim-probs]", ), pytest.param( Input( preds=_logsoftmax(torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES), -1), target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), ), - id="input[single dim-logits]", + id="input[single_dim-logits]", ), pytest.param( Input( preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), ), - id="input[multi dim-labels]", + id="input[multi_dim-labels]", ), pytest.param( Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM).softmax(-2), target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), ), - id="input[multi dim-probs]", + id="input[multi_dim-probs]", ), pytest.param( Input( preds=_logsoftmax(torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM), -2), target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), ), - id="input[multi dim-logits]", + id="input[multi_dim-logits]", + ), + pytest.param( + Input( + preds=_multiclass_with_missing_class(NUM_BATCHES, BATCH_SIZE, num_classes=NUM_CLASSES), + target=_multiclass_with_missing_class(NUM_BATCHES, BATCH_SIZE, num_classes=NUM_CLASSES), + ), + id="input[single_dim-labels-missing_class]", ), ) @@ -167,42 +190,42 @@ def _logsoftmax(x: Tensor, dim: int = -1) -> Tensor: preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), ), - id="input[single dim-labels]", + id="input[single_dim-labels]", ), pytest.param( Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), ), - id="input[single dim-probs]", + id="input[single_dim-probs]", ), pytest.param( Input( preds=_inv_sigmoid(torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), ), - id="input[single dim-logits]", + id="input[single_dim-logits]", ), pytest.param( Input( preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), ), - id="input[multi dim-labels]", + id="input[multi_dim-labels]", ), pytest.param( Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), ), - id="input[multi dim-probs]", + id="input[multi_dim-probs]", ), pytest.param( Input( preds=_inv_sigmoid(torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), ), - id="input[multi dim-logits]", + id="input[multi_dim-logits]", ), ) @@ -214,7 +237,7 @@ def _logsoftmax(x: Tensor, dim: int = -1) -> Tensor: 
target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), groups=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), ), - id="input[single dim-labels]", + id="input[single_dim-labels]", ), pytest.param( GroupInput( @@ -222,7 +245,7 @@ def _logsoftmax(x: Tensor, dim: int = -1) -> Tensor: target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), groups=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), ), - id="input[single dim-probs]", + id="input[single_dim-probs]", ), pytest.param( GroupInput( @@ -230,7 +253,7 @@ def _logsoftmax(x: Tensor, dim: int = -1) -> Tensor: target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), groups=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), ), - id="input[single dim-logits]", + id="input[single_dim-logits]", ), ) diff --git a/tests/unittests/classification/test_accuracy.py b/tests/unittests/classification/test_accuracy.py index f98bc58cf50..9189b897b74 100644 --- a/tests/unittests/classification/test_accuracy.py +++ b/tests/unittests/classification/test_accuracy.py @@ -98,7 +98,7 @@ def test_accuracy_functional_raises_invalid_task(): class TestBinaryAccuracy(MetricTester): """Test class for `BinaryAccuracy` metric.""" - @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) @pytest.mark.parametrize("ddp", [False, True]) def test_binary_accuracy(self, ddp, inputs, ignore_index, multidim_average): @@ -122,7 +122,7 @@ def test_binary_accuracy(self, ddp, inputs, ignore_index, multidim_average): metric_args={"threshold": THRESHOLD, "ignore_index": ignore_index, "multidim_average": multidim_average}, ) - @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) def test_binary_accuracy_functional(self, inputs, ignore_index, multidim_average): """Test functional implementation of metric.""" @@ -201,6 +201,9 @@ def _sklearn_accuracy_multiclass(preds, target, ignore_index, multidim_average, acc_per_class = confmat.diagonal() / confmat.sum(axis=1) acc_per_class[np.isnan(acc_per_class)] = 0.0 if average == "macro": + acc_per_class = acc_per_class[ + (np.bincount(preds, minlength=NUM_CLASSES) + np.bincount(target, minlength=NUM_CLASSES)) != 0.0 + ] return acc_per_class.mean() if average == "weighted": weights = confmat.sum(1) @@ -221,7 +224,10 @@ def _sklearn_accuracy_multiclass(preds, target, ignore_index, multidim_average, acc_per_class = confmat.diagonal() / confmat.sum(axis=1) acc_per_class[np.isnan(acc_per_class)] = 0.0 if average == "macro": - res.append(acc_per_class.mean()) + acc_per_class = acc_per_class[ + (np.bincount(pred, minlength=NUM_CLASSES) + np.bincount(true, minlength=NUM_CLASSES)) != 0.0 + ] + res.append(acc_per_class.mean() if len(acc_per_class) > 0 else 0.0) elif average == "weighted": weights = confmat.sum(1) score = ((weights * acc_per_class) / weights.sum()).sum() @@ -433,7 +439,7 @@ class TestMultilabelAccuracy(MetricTester): """Test class for `MultilabelAccuracy` metric.""" @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) def test_multilabel_accuracy(self, ddp, inputs, ignore_index, multidim_average, average): @@ 
-466,7 +472,7 @@ def test_multilabel_accuracy(self, ddp, inputs, ignore_index, multidim_average, }, ) - @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) def test_multilabel_accuracy_functional(self, inputs, ignore_index, multidim_average, average): @@ -536,3 +542,18 @@ def test_multilabel_accuracy_half_gpu(self, inputs, dtype): metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, dtype=dtype, ) + + +def test_corner_cases(): + """Issue: https://github.com/Lightning-AI/torchmetrics/issues/1691.""" + # simulate the output of a perfect predictor (i.e. preds == target) + target = torch.tensor([0, 1, 2, 0, 1, 2]) + preds = target + + metric = MulticlassAccuracy(num_classes=3, average="none", ignore_index=0) + res = metric(preds, target) + assert torch.allclose(res, torch.tensor([0.0, 1.0, 1.0])) + + metric = MulticlassAccuracy(num_classes=3, average="macro", ignore_index=0) + res = metric(preds, target) + assert res == 1.0 diff --git a/tests/unittests/classification/test_f_beta.py b/tests/unittests/classification/test_f_beta.py index 949982f956b..3a692d33c1e 100644 --- a/tests/unittests/classification/test_f_beta.py +++ b/tests/unittests/classification/test_f_beta.py @@ -84,7 +84,7 @@ def _sklearn_fbeta_score_binary(preds, target, sk_fn, ignore_index, multidim_ave class TestBinaryFBetaScore(MetricTester): """Test class for `BinaryFBetaScore` metric.""" - @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) @pytest.mark.parametrize("ddp", [False, True]) def test_binary_fbeta_score(self, ddp, inputs, module, functional, compare, ignore_index, multidim_average): @@ -108,7 +108,7 @@ def test_binary_fbeta_score(self, ddp, inputs, module, functional, compare, igno metric_args={"threshold": THRESHOLD, "ignore_index": ignore_index, "multidim_average": multidim_average}, ) - @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) def test_binary_fbeta_score_functional(self, inputs, module, functional, compare, ignore_index, multidim_average): """Test functional implementation of metric.""" @@ -181,7 +181,7 @@ def _sklearn_fbeta_score_multiclass(preds, target, sk_fn, ignore_index, multidim preds = preds.numpy().flatten() target = target.numpy().flatten() target, preds = remove_ignore_index(target, preds, ignore_index) - return sk_fn(target, preds, average=average) + return sk_fn(target, preds, average=average, labels=list(range(NUM_CLASSES)) if average is None else None) preds = preds.numpy() target = target.numpy() @@ -190,7 +190,8 @@ def _sklearn_fbeta_score_multiclass(preds, target, sk_fn, ignore_index, multidim pred = pred.flatten() true = true.flatten() true, pred = remove_ignore_index(true, pred, ignore_index) - res.append(sk_fn(true, pred, average=average, labels=list(range(NUM_CLASSES)))) + r = sk_fn(true, pred, average=average, labels=list(range(NUM_CLASSES)) if average is None else None) + res.append(0.0 if np.isnan(r).any() else r) return np.stack(res, 0) @@ -457,7 +458,7 @@ class TestMultilabelFBetaScore(MetricTester): """Test class for `MultilabelFBetaScore` metric.""" @pytest.mark.parametrize("ddp", [True, False]) - 
@pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) def test_multilabel_fbeta_score( @@ -493,7 +494,7 @@ def test_multilabel_fbeta_score( }, ) - @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) def test_multilabel_fbeta_score_functional( @@ -566,3 +567,13 @@ def test_multilabel_fbeta_score_half_gpu(self, inputs, module, functional, compa metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, dtype=dtype, ) + + +def test_corner_case(): + """Issue: https://github.com/Lightning-AI/torchmetrics/issues/1664.""" + target = torch.tensor([2, 1, 0, 0]) + preds = torch.tensor([2, 1, 0, 1]) + for i in range(3, 9): + f1_score = MulticlassF1Score(num_classes=i, average="macro") + res = f1_score(preds, target) + assert res == torch.tensor([0.77777779]) diff --git a/tests/unittests/classification/test_hamming_distance.py b/tests/unittests/classification/test_hamming_distance.py index 28bba43cf47..0a59a7902e6 100644 --- a/tests/unittests/classification/test_hamming_distance.py +++ b/tests/unittests/classification/test_hamming_distance.py @@ -73,7 +73,7 @@ def _sklearn_hamming_distance_binary(preds, target, ignore_index, multidim_avera class TestBinaryHammingDistance(MetricTester): """Test class for `BinaryHammingDistance` metric.""" - @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) @pytest.mark.parametrize("ddp", [False, True]) def test_binary_hamming_distance(self, ddp, inputs, ignore_index, multidim_average): @@ -97,7 +97,7 @@ def test_binary_hamming_distance(self, ddp, inputs, ignore_index, multidim_avera metric_args={"threshold": THRESHOLD, "ignore_index": ignore_index, "multidim_average": multidim_average}, ) - @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) def test_binary_hamming_distance_functional(self, inputs, ignore_index, multidim_average): """Test functional implementation of metric.""" @@ -172,6 +172,9 @@ def _sklearn_hamming_distance_multiclass_global(preds, target, ignore_index, ave hamming_per_class = 1 - confmat.diagonal() / confmat.sum(axis=1) hamming_per_class[np.isnan(hamming_per_class)] = 1.0 if average == "macro": + hamming_per_class = hamming_per_class[ + (np.bincount(preds, minlength=NUM_CLASSES) + np.bincount(target, minlength=NUM_CLASSES)) != 0.0 + ] return hamming_per_class.mean() if average == "weighted": weights = confmat.sum(1) @@ -194,7 +197,10 @@ def _sklearn_hamming_distance_multiclass_local(preds, target, ignore_index, aver hamming_per_class = 1 - confmat.diagonal() / confmat.sum(axis=1) hamming_per_class[np.isnan(hamming_per_class)] = 1.0 if average == "macro": - res.append(hamming_per_class.mean()) + hamming_per_class = hamming_per_class[ + (np.bincount(pred, minlength=NUM_CLASSES) + np.bincount(true, minlength=NUM_CLASSES)) != 0.0 + ] + res.append(hamming_per_class.mean() if len(hamming_per_class) > 0 else 0.0) elif average == "weighted": weights = confmat.sum(1) score = ((weights * 
hamming_per_class) / weights.sum()).sum() @@ -399,7 +405,7 @@ class TestMultilabelHammingDistance(MetricTester): """Test class for `MultilabelHammingDistance` metric.""" @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) @pytest.mark.parametrize("average", ["micro", "macro", None]) def test_multilabel_hamming_distance(self, ddp, inputs, ignore_index, multidim_average, average): @@ -432,7 +438,7 @@ def test_multilabel_hamming_distance(self, ddp, inputs, ignore_index, multidim_a }, ) - @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) @pytest.mark.parametrize("average", ["micro", "macro", None]) def test_multilabel_hamming_distance_functional(self, inputs, ignore_index, multidim_average, average): diff --git a/tests/unittests/classification/test_jaccard.py b/tests/unittests/classification/test_jaccard.py index 3af62d28d78..5a2ccc1fa43 100644 --- a/tests/unittests/classification/test_jaccard.py +++ b/tests/unittests/classification/test_jaccard.py @@ -135,6 +135,8 @@ def _sklearn_jaccard_index_multiclass(preds, target, ignore_index=None, average= labels = [i for i in range(NUM_CLASSES) if i != ignore_index] res = sk_jaccard_index(y_true=target, y_pred=preds, average=average, labels=labels) return np.insert(res, ignore_index, 0.0) if average is None else res + if average is None: + return sk_jaccard_index(y_true=target, y_pred=preds, average=average, labels=list(range(NUM_CLASSES))) return sk_jaccard_index(y_true=target, y_pred=preds, average=average) @@ -334,3 +336,26 @@ def test_multilabel_jaccard_index_dtype_gpu(self, inputs, dtype): metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, dtype=dtype, ) + + +def test_corner_case(): + """Issue: https://github.com/Lightning-AI/torchmetrics/issues/1693.""" + # edge case: class 2 is not present in the target AND the prediction + target = torch.tensor([0, 1, 0, 0]) + preds = torch.tensor([0, 1, 0, 1]) + + metric = MulticlassJaccardIndex(num_classes=3, average="none") + res = metric(preds, target) + assert torch.allclose(res, torch.tensor([2.0 / 3.0, 0.5000, 0.0000])) + + metric = MulticlassJaccardIndex(num_classes=3, average="macro") + res = metric(preds, target) + assert torch.allclose(res, torch.tensor(0.5833333)) + + target = torch.tensor([0, 1]) + pred = torch.tensor([0, 1]) + out = torch.tensor([1, 1, 0, 0, 0, 0, 0, 0, 0, 0]).float() + res = multiclass_jaccard_index(pred, target, num_classes=10) + assert torch.allclose(res, torch.ones_like(res)) + res = multiclass_jaccard_index(pred, target, num_classes=10, average="none") + assert torch.allclose(res, out) diff --git a/tests/unittests/classification/test_precision_recall.py b/tests/unittests/classification/test_precision_recall.py index 6aa4b65c0b7..5fda29f1e30 100644 --- a/tests/unittests/classification/test_precision_recall.py +++ b/tests/unittests/classification/test_precision_recall.py @@ -84,7 +84,7 @@ def _sklearn_precision_recall_binary(preds, target, sk_fn, ignore_index, multidi class TestBinaryPrecisionRecall(MetricTester): """Test class for `BinaryPrecisionRecall` metric.""" - @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) 
@pytest.mark.parametrize("ddp", [False, True]) def test_binary_precision_recall(self, ddp, inputs, module, functional, compare, ignore_index, multidim_average): @@ -111,7 +111,7 @@ def test_binary_precision_recall(self, ddp, inputs, module, functional, compare, metric_args={"threshold": THRESHOLD, "ignore_index": ignore_index, "multidim_average": multidim_average}, ) - @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) def test_binary_precision_recall_functional( self, inputs, module, functional, compare, ignore_index, multidim_average @@ -184,11 +184,12 @@ def test_binary_precision_recall_half_gpu(self, inputs, module, functional, comp def _sklearn_precision_recall_multiclass(preds, target, sk_fn, ignore_index, multidim_average, average): if preds.ndim == target.ndim + 1: preds = torch.argmax(preds, 1) + if multidim_average == "global": preds = preds.numpy().flatten() target = target.numpy().flatten() target, preds = remove_ignore_index(target, preds, ignore_index) - return sk_fn(target, preds, average=average) + return sk_fn(target, preds, average=average, labels=list(range(NUM_CLASSES)) if average is None else None) preds = preds.numpy() target = target.numpy() @@ -197,7 +198,9 @@ def _sklearn_precision_recall_multiclass(preds, target, sk_fn, ignore_index, mul pred = pred.flatten() true = true.flatten() true, pred = remove_ignore_index(true, pred, ignore_index) - res.append(sk_fn(true, pred, average=average, labels=list(range(NUM_CLASSES)))) + r = sk_fn(true, pred, average=average, labels=list(range(NUM_CLASSES)) if average is None else None) + res.append(0.0 if np.isnan(r).any() else r) + return np.stack(res, 0) @@ -451,7 +454,7 @@ class TestMultilabelPrecisionRecall(MetricTester): """Test class for `MultilabelPrecisionRecall` metric.""" @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) def test_multilabel_precision_recall( @@ -487,7 +490,7 @@ def test_multilabel_precision_recall( }, ) - @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) def test_multilabel_precision_recall_functional( @@ -559,3 +562,26 @@ def test_multilabel_precision_recall_half_gpu(self, inputs, module, functional, metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, dtype=dtype, ) + + +def test_corner_case(): + """Issue: https://github.com/Lightning-AI/torchmetrics/issues/1692.""" + # simulate the output of a perfect predictor (i.e. 
preds == target) + target = torch.tensor([0, 1, 2, 0, 1, 2]) + preds = target.clone() + + metric = MulticlassPrecision(num_classes=3, average="none", ignore_index=0) + res = metric(preds, target) + assert torch.allclose(res, torch.tensor([0.0, 1.0, 1.0])) + + metric = MulticlassRecall(num_classes=3, average="none", ignore_index=0) + res = metric(preds, target) + assert torch.allclose(res, torch.tensor([0.0, 1.0, 1.0])) + + metric = MulticlassPrecision(num_classes=3, average="macro", ignore_index=0) + res = metric(preds, target) + assert res == 1.0 + + metric = MulticlassRecall(num_classes=3, average="macro", ignore_index=0) + res = metric(preds, target) + assert res == 1.0 diff --git a/tests/unittests/classification/test_specificity.py b/tests/unittests/classification/test_specificity.py index a7b9fda64ef..12bdd57edb1 100644 --- a/tests/unittests/classification/test_specificity.py +++ b/tests/unittests/classification/test_specificity.py @@ -82,7 +82,7 @@ def _baseline_specificity_binary(preds, target, ignore_index, multidim_average): class TestBinarySpecificity(MetricTester): """Test class for `BinarySpecificity` metric.""" - @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) @pytest.mark.parametrize("ddp", [False, True]) def test_binary_specificity(self, ddp, inputs, ignore_index, multidim_average): @@ -106,7 +106,7 @@ def test_binary_specificity(self, ddp, inputs, ignore_index, multidim_average): metric_args={"threshold": THRESHOLD, "ignore_index": ignore_index, "multidim_average": multidim_average}, ) - @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) def test_binary_specificity_functional(self, inputs, ignore_index, multidim_average): """Test functional implementation of metric.""" @@ -190,6 +190,7 @@ def _baseline_specificity_multiclass_global(preds, target, ignore_index, average res = _calc_specificity(tn, fp) if average == "macro": + res = res[(np.bincount(preds, minlength=NUM_CLASSES) + np.bincount(target, minlength=NUM_CLASSES)) != 0.0] return res.mean(0) if average == "weighted": w = tp + fn @@ -222,7 +223,8 @@ def _baseline_specificity_multiclass_local(preds, target, ignore_index, average) r = _calc_specificity(tn, fp) if average == "macro": - res.append(r.mean(0)) + r = r[(np.bincount(pred, minlength=NUM_CLASSES) + np.bincount(true, minlength=NUM_CLASSES)) != 0.0] + res.append(r.mean(0) if len(r) > 0 else 0.0) elif average == "weighted": w = tp + fn res.append((r * (w / w.sum()).reshape(-1, 1)).sum(0)) @@ -446,7 +448,7 @@ class TestMultilabelSpecificity(MetricTester): """Test class for `MultilabelSpecificity` metric.""" @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) @pytest.mark.parametrize("average", ["micro", "macro", None]) def test_multilabel_specificity(self, ddp, inputs, ignore_index, multidim_average, average): @@ -479,7 +481,7 @@ def test_multilabel_specificity(self, ddp, inputs, ignore_index, multidim_averag }, ) - @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) @pytest.mark.parametrize("average", 
["micro", "macro", None]) def test_multilabel_specificity_functional(self, inputs, ignore_index, multidim_average, average): @@ -548,3 +550,18 @@ def test_multilabel_specificity_dtype_gpu(self, inputs, dtype): metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, dtype=dtype, ) + + +def test_corner_cases(): + """Test corner cases for specificity metric.""" + # simulate the output of a perfect predictor (i.e. preds == target) + target = torch.tensor([0, 1, 2, 0, 1, 2]) + preds = target + + metric = MulticlassSpecificity(num_classes=3, average="none", ignore_index=0) + res = metric(preds, target) + assert torch.allclose(res, torch.tensor([1.0, 1.0, 1.0])) + + metric = MulticlassSpecificity(num_classes=3, average="macro", ignore_index=0) + res = metric(preds, target) + assert res == 1.0