diff --git a/docs/source/classification/auroc.rst b/docs/source/classification/auroc.rst index 6c0a2342b8b..4828f605229 100644 --- a/docs/source/classification/auroc.rst +++ b/docs/source/classification/auroc.rst @@ -15,8 +15,44 @@ ________________ .. autoclass:: torchmetrics.AUROC :noindex: +BinaryAUROC +^^^^^^^^^^^ + +.. autoclass:: torchmetrics.BinaryAUROC + :noindex: + +MulticlassAUROC +^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.MulticlassAUROC + :noindex: + +MultilabelAUROC +^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.MultilabelAUROC + :noindex: + Functional Interface ____________________ .. autofunction:: torchmetrics.functional.auroc :noindex: + +binary_auroc +^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.binary_auroc + :noindex: + +multiclass_auroc +^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.multiclass_auroc + :noindex: + +multilabel_auroc +^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.multilabel_auroc + :noindex: diff --git a/docs/source/classification/average_precision.rst b/docs/source/classification/average_precision.rst index 2061c284400..b241f682a86 100644 --- a/docs/source/classification/average_precision.rst +++ b/docs/source/classification/average_precision.rst @@ -13,8 +13,44 @@ ________________ .. autoclass:: torchmetrics.AveragePrecision :noindex: +BinaryAveragePrecision +^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.BinaryAveragePrecision + :noindex: + +MulticlassAveragePrecision +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.MulticlassAveragePrecision + :noindex: + +MultilabelAveragePrecision +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.MultilabelAveragePrecision + :noindex: + Functional Interface ____________________ .. autofunction:: torchmetrics.functional.average_precision :noindex: + +binary_average_precision +^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.binary_average_precision + :noindex: + +multiclass_average_precision +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.multiclass_average_precision + :noindex: + +multilabel_average_precision +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.multilabel_average_precision + :noindex: diff --git a/docs/source/classification/precision_recall_curve.rst b/docs/source/classification/precision_recall_curve.rst index bd457727374..412470ccd3d 100644 --- a/docs/source/classification/precision_recall_curve.rst +++ b/docs/source/classification/precision_recall_curve.rst @@ -13,8 +13,44 @@ ________________ .. autoclass:: torchmetrics.PrecisionRecallCurve :noindex: +BinaryPrecisionRecallCurve +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.BinaryPrecisionRecallCurve + :noindex: + +MulticlassPrecisionRecallCurve +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.MulticlassPrecisionRecallCurve + :noindex: + +MultilabelPrecisionRecallCurve +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.MultilabelPrecisionRecallCurve + :noindex: + Functional Interface ____________________ .. autofunction:: torchmetrics.functional.precision_recall_curve :noindex: + +binary_precision_recall_curve +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.binary_precision_recall_curve + :noindex: + +multiclass_precision_recall_curve +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.multiclass_precision_recall_curve + :noindex: + +multilabel_precision_recall_curve +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. 
autofunction:: torchmetrics.functional.multilabel_precision_recall_curve + :noindex: diff --git a/docs/source/classification/recall_at_fixed_precision.rst b/docs/source/classification/recall_at_fixed_precision.rst new file mode 100644 index 00000000000..f585f2abc7e --- /dev/null +++ b/docs/source/classification/recall_at_fixed_precision.rst @@ -0,0 +1,50 @@ +.. customcarditem:: + :header: Recall At Fixed Precision + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg + :tags: Classification + +######################### +Recall At Fixed Precision +######################### + +Module Interface +________________ + +BinaryRecallAtFixedPrecision +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.BinaryRecallAtFixedPrecision + :noindex: + +MulticlassRecallAtFixedPrecision +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.MulticlassRecallAtFixedPrecision + :noindex: + +MultilabelRecallAtFixedPrecision +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.MultilabelRecallAtFixedPrecision + :noindex: + +Functional Interface +____________________ + +binary_recall_at_fixed_precision +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.binary_recall_at_fixed_precision + :noindex: + +multiclass_recall_at_fixed_precision +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.multiclass_recall_at_fixed_precision + :noindex: + +multilabel_recall_at_fixed_precision +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.multilabel_recall_at_fixed_precision + :noindex: diff --git a/docs/source/classification/roc.rst b/docs/source/classification/roc.rst index 6b3aaea4add..04dae356790 100644 --- a/docs/source/classification/roc.rst +++ b/docs/source/classification/roc.rst @@ -13,8 +13,44 @@ ________________ .. autoclass:: torchmetrics.ROC :noindex: +BinaryROC +^^^^^^^^^ + +.. autoclass:: torchmetrics.BinaryROC + :noindex: + +MulticlassROC +^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.MulticlassROC + :noindex: + +MultilabelROC +^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.MultilabelROC + :noindex: + Functional Interface ____________________ .. autofunction:: torchmetrics.functional.roc :noindex: + +binary_roc +^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.binary_roc + :noindex: + +multiclass_roc +^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.multiclass_roc + :noindex: + +multilabel_roc +^^^^^^^^^^^^^^ + +.. 
autofunction:: torchmetrics.functional.multilabel_roc + :noindex: diff --git a/src/torchmetrics/__init__.py b/src/torchmetrics/__init__.py index 63d4ea4a330..122e16391f5 100644 --- a/src/torchmetrics/__init__.py +++ b/src/torchmetrics/__init__.py @@ -27,6 +27,8 @@ Accuracy, AveragePrecision, BinaryAccuracy, + BinaryAUROC, + BinaryAveragePrecision, BinaryCohenKappa, BinaryConfusionMatrix, BinaryF1Score, @@ -35,7 +37,10 @@ BinaryJaccardIndex, BinaryMatthewsCorrCoef, BinaryPrecision, + BinaryPrecisionRecallCurve, BinaryRecall, + BinaryRecallAtFixedPrecision, + BinaryROC, BinarySpecificity, BinaryStatScores, BinnedAveragePrecision, @@ -56,6 +61,8 @@ LabelRankingLoss, MatthewsCorrCoef, MulticlassAccuracy, + MulticlassAUROC, + MulticlassAveragePrecision, MulticlassCohenKappa, MulticlassConfusionMatrix, MulticlassF1Score, @@ -64,10 +71,15 @@ MulticlassJaccardIndex, MulticlassMatthewsCorrCoef, MulticlassPrecision, + MulticlassPrecisionRecallCurve, MulticlassRecall, + MulticlassRecallAtFixedPrecision, + MulticlassROC, MulticlassSpecificity, MulticlassStatScores, MultilabelAccuracy, + MultilabelAUROC, + MultilabelAveragePrecision, MultilabelConfusionMatrix, MultilabelCoverageError, MultilabelExactMatch, @@ -77,9 +89,12 @@ MultilabelJaccardIndex, MultilabelMatthewsCorrCoef, MultilabelPrecision, + MultilabelPrecisionRecallCurve, MultilabelRankingAveragePrecision, MultilabelRankingLoss, MultilabelRecall, + MultilabelRecallAtFixedPrecision, + MultilabelROC, MultilabelSpecificity, MultilabelStatScores, Precision, @@ -155,6 +170,21 @@ "MultilabelAccuracy", "AUC", "AUROC", + "BinaryAUROC", + "BinaryAveragePrecision", + "BinaryPrecisionRecallCurve", + "BinaryRecallAtFixedPrecision", + "BinaryROC", + "MultilabelROC", + "MulticlassAUROC", + "MulticlassAveragePrecision", + "MulticlassPrecisionRecallCurve", + "MulticlassRecallAtFixedPrecision", + "MulticlassROC", + "MultilabelAUROC", + "MultilabelAveragePrecision", + "MultilabelPrecisionRecallCurve", + "MultilabelRecallAtFixedPrecision", "AveragePrecision", "BinnedAveragePrecision", "BinnedPrecisionRecallCurve", diff --git a/src/torchmetrics/classification/__init__.py b/src/torchmetrics/classification/__init__.py index b1f62ea5e79..59a9d63793b 100644 --- a/src/torchmetrics/classification/__init__.py +++ b/src/torchmetrics/classification/__init__.py @@ -17,6 +17,12 @@ MulticlassConfusionMatrix, MultilabelConfusionMatrix, ) +from torchmetrics.classification.precision_recall_curve import ( # noqa: F401 isort:skip + PrecisionRecallCurve, + BinaryPrecisionRecallCurve, + MulticlassPrecisionRecallCurve, + MultilabelPrecisionRecallCurve, +) from torchmetrics.classification.stat_scores import ( # noqa: F401 isort:skip BinaryStatScores, MulticlassStatScores, @@ -31,8 +37,13 @@ MultilabelAccuracy, ) from torchmetrics.classification.auc import AUC # noqa: F401 -from torchmetrics.classification.auroc import AUROC # noqa: F401 -from torchmetrics.classification.avg_precision import AveragePrecision # noqa: F401 +from torchmetrics.classification.auroc import AUROC, BinaryAUROC, MulticlassAUROC, MultilabelAUROC # noqa: F401 +from torchmetrics.classification.average_precision import ( # noqa: F401 + AveragePrecision, + BinaryAveragePrecision, + MulticlassAveragePrecision, + MultilabelAveragePrecision, +) from torchmetrics.classification.binned_precision_recall import BinnedAveragePrecision # noqa: F401 from torchmetrics.classification.binned_precision_recall import BinnedPrecisionRecallCurve # noqa: F401 from torchmetrics.classification.binned_precision_recall import 
BinnedRecallAtFixedPrecision # noqa: F401 @@ -80,7 +91,6 @@ Precision, Recall, ) -from torchmetrics.classification.precision_recall_curve import PrecisionRecallCurve # noqa: F401 from torchmetrics.classification.ranking import ( # noqa: F401 CoverageError, LabelRankingAveragePrecision, @@ -89,7 +99,12 @@ MultilabelRankingAveragePrecision, MultilabelRankingLoss, ) -from torchmetrics.classification.roc import ROC # noqa: F401 +from torchmetrics.classification.recall_at_fixed_precision import ( # noqa: F401 + BinaryRecallAtFixedPrecision, + MulticlassRecallAtFixedPrecision, + MultilabelRecallAtFixedPrecision, +) +from torchmetrics.classification.roc import ROC, BinaryROC, MulticlassROC, MultilabelROC # noqa: F401 from torchmetrics.classification.specificity import ( # noqa: F401 BinarySpecificity, MulticlassSpecificity, diff --git a/src/torchmetrics/classification/auroc.py b/src/torchmetrics/classification/auroc.py index 4a754c7c849..7a0807de349 100644 --- a/src/torchmetrics/classification/auroc.py +++ b/src/torchmetrics/classification/auroc.py @@ -11,12 +11,27 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional +from typing import Any, List, Optional, Union import torch from torch import Tensor +from typing_extensions import Literal -from torchmetrics.functional.classification.auroc import _auroc_compute, _auroc_update +from torchmetrics.classification.precision_recall_curve import ( + BinaryPrecisionRecallCurve, + MulticlassPrecisionRecallCurve, + MultilabelPrecisionRecallCurve, +) +from torchmetrics.functional.classification.auroc import ( + _auroc_compute, + _auroc_update, + _binary_auroc_arg_validation, + _binary_auroc_compute, + _multiclass_auroc_arg_validation, + _multiclass_auroc_compute, + _multilabel_auroc_arg_validation, + _multilabel_auroc_compute, +) from torchmetrics.metric import Metric from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.data import dim_zero_cat @@ -24,6 +39,292 @@ from torchmetrics.utilities.imports import _TORCH_LOWER_1_6 +class BinaryAUROC(BinaryPrecisionRecallCurve): + r""" + Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_) for binary tasks. The AUROC score + summarizes the ROC curve into an single number that describes the performance of a model for multiple + thresholds at the same time. Notably, an AUROC score of 1 is a perfect score and an AUROC score of 0.5 + corresponds to random guessing. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. 
Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds})` (constant memory). + + Args: + max_fpr: If not ``None``, calculates standardized partial AUC over the range ``[0, max_fpr]``. + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Returns: + A single scalar with the auroc score + + Example: + >>> from torchmetrics import BinaryAUROC + >>> preds = torch.tensor([0, 0.5, 0.7, 0.8]) + >>> target = torch.tensor([0, 1, 1, 0]) + >>> metric = BinaryAUROC(thresholds=None) + >>> metric(preds, target) + tensor(0.5000) + >>> metric = BinaryAUROC(thresholds=5) + >>> metric(preds, target) + tensor(0.5000) + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def __init__( + self, + max_fpr: Optional[float] = None, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(thresholds=thresholds, ignore_index=ignore_index, validate_args=False, **kwargs) + if validate_args: + _binary_auroc_arg_validation(max_fpr, thresholds, ignore_index) + self.max_fpr = max_fpr + + def compute(self) -> Tensor: + if self.thresholds is None: + state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] + else: + state = self.confmat + return _binary_auroc_compute(state, self.thresholds, self.max_fpr) + + +class MulticlassAUROC(MulticlassPrecisionRecallCurve): + r""" + Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_) for multiclass tasks. The AUROC score + summarizes the ROC curve into an single number that describes the performance of a model for multiple + thresholds at the same time. Notably, an AUROC score of 1 is a perfect score and an AUROC score of 0.5 + corresponds to random guessing. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + softmax per sample. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. 
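A minimal usage sketch of the ``BinaryAUROC`` class shown above, assuming only the constructor arguments documented there (``max_fpr``, ``thresholds``) and the top-level export added in the ``__init__.py`` hunk; the tensor sizes and the ``max_fpr`` value are arbitrary illustrations:

import torch
from torchmetrics import BinaryAUROC

preds = torch.rand(1000)                 # probabilities in [0, 1]
target = torch.randint(0, 2, (1000,))    # binary ground truth labels

exact = BinaryAUROC(thresholds=None)     # non-binned: buffers every sample, O(n_samples) memory
binned = BinaryAUROC(thresholds=100)     # binned: keeps a (100, 2, 2) confusion-matrix state
partial = BinaryAUROC(max_fpr=0.1)       # standardized partial AUC over the FPR range [0, 0.1]

for metric in (exact, binned, partial):
    metric.update(preds, target)
    print(metric.compute())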
+ + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds} \times n_{classes})` (constant memory). + + Args: + num_classes: Integer specifing the number of classes + average: + Defines the reduction that is applied over classes. Should be one of the following: + + - ``macro``: Calculate score for each class and average them + - ``weighted``: Calculates score for each class and computes weighted average using their support + - ``"none"`` or ``None``: Calculates score for each class and applies no reduction + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Returns: + If `average=None|"none"` then a 1d tensor of shape (n_classes, ) will be returned with auroc score per class. + If `average="macro"|"weighted"` then a single scalar is returned. + + Example: + >>> from torchmetrics import MulticlassAUROC + >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], + ... [0.05, 0.75, 0.05, 0.05, 0.05], + ... [0.05, 0.05, 0.75, 0.05, 0.05], + ... 
[0.05, 0.05, 0.05, 0.75, 0.05]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> metric = MulticlassAUROC(num_classes=5, average="macro", thresholds=None) + >>> metric(preds, target) + tensor(0.5333) + >>> metric = MulticlassAUROC(num_classes=5, average=None, thresholds=None) + >>> metric(preds, target) + tensor([1.0000, 1.0000, 0.3333, 0.3333, 0.0000]) + >>> metric = MulticlassAUROC(num_classes=5, average="macro", thresholds=5) + >>> metric(preds, target) + tensor(0.5333) + >>> metric = MulticlassAUROC(num_classes=5, average=None, thresholds=5) + >>> metric(preds, target) + tensor([1.0000, 1.0000, 0.3333, 0.3333, 0.0000]) + + """ + + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def __init__( + self, + num_classes: int, + average: Optional[Literal["macro", "weighted", "none"]] = "macro", + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__( + num_classes=num_classes, thresholds=thresholds, ignore_index=ignore_index, validate_args=False, **kwargs + ) + if validate_args: + _multiclass_auroc_arg_validation(num_classes, average, thresholds, ignore_index) + self.average = average + self.validate_args = validate_args + + def compute(self) -> Tensor: + if self.thresholds is None: + state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] + else: + state = self.confmat + return _multiclass_auroc_compute(state, self.num_classes, self.average, self.thresholds) + + +class MultilabelAUROC(MultilabelPrecisionRecallCurve): + r""" + Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_) for multilabel tasks. The AUROC score + summarizes the ROC curve into an single number that describes the performance of a model for multiple + thresholds at the same time. Notably, an AUROC score of 1 is a perfect score and an AUROC score of 0.5 + corresponds to random guessing. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, C, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds} \times n_{labels})` (constant memory). + + Args: + num_labels: Integer specifing the number of labels + average: + Defines the reduction that is applied over labels. 
Should be one of the following: + + - ``micro``: Sum score over all labels + - ``macro``: Calculate score for each label and average them + - ``weighted``: Calculates score for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates score for each label and applies no reduction + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Returns: + If `average=None|"none"` then a 1d tensor of shape (n_classes, ) will be returned with auroc score per class. + If `average="micro|macro"|"weighted"` then a single scalar is returned. + + Example: + >>> from torchmetrics import MultilabelAUROC + >>> preds = torch.tensor([[0.75, 0.05, 0.35], + ... [0.45, 0.75, 0.05], + ... [0.05, 0.55, 0.75], + ... [0.05, 0.65, 0.05]]) + >>> target = torch.tensor([[1, 0, 1], + ... [0, 0, 0], + ... [0, 1, 1], + ... [1, 1, 1]]) + >>> metric = MultilabelAUROC(num_labels=3, average="macro", thresholds=None) + >>> metric(preds, target) + tensor(0.6528) + >>> metric = MultilabelAUROC(num_labels=3, average=None, thresholds=None) + >>> metric(preds, target) + tensor([0.6250, 0.5000, 0.8333]) + >>> metric = MultilabelAUROC(num_labels=3, average="macro", thresholds=5) + >>> metric(preds, target) + tensor(0.6528) + >>> metric = MultilabelAUROC(num_labels=3, average=None, thresholds=5) + >>> metric(preds, target) + tensor([0.6250, 0.5000, 0.8333]) + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def __init__( + self, + num_labels: int, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__( + num_labels=num_labels, thresholds=thresholds, ignore_index=ignore_index, validate_args=False, **kwargs + ) + if validate_args: + _multilabel_auroc_arg_validation(num_labels, average, thresholds, ignore_index) + self.average = average + self.validate_args = validate_args + + def compute(self) -> Tensor: + if self.thresholds is None: + state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] + else: + state = self.confmat + return _multilabel_auroc_compute(state, self.num_labels, self.average, self.thresholds, self.ignore_index) + + +# -------------------------- Old stuff -------------------------- + + class AUROC(Metric): r"""Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_). Works for both binary, multilabel and multiclass problems. 
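A hypothetical sketch combining the ``average="micro"`` reduction with the ``ignore_index`` argument from the ``MultilabelAUROC`` signature above; the ``-1`` sentinel and the random data are assumptions for illustration only:

import torch
from torchmetrics import MultilabelAUROC

preds = torch.rand(8, 3)                 # probabilities for 3 labels
target = torch.randint(0, 2, (8, 3))
target[0, 2] = -1                        # entry to be skipped via ignore_index

metric = MultilabelAUROC(num_labels=3, average="micro", thresholds=None, ignore_index=-1)
metric.update(preds, target)
print(metric.compute())                  # single scalar under the micro reduction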
In the case of diff --git a/src/torchmetrics/classification/average_precision.py b/src/torchmetrics/classification/average_precision.py new file mode 100644 index 00000000000..9fa309f8d16 --- /dev/null +++ b/src/torchmetrics/classification/average_precision.py @@ -0,0 +1,433 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, List, Optional, Union + +import torch +from torch import Tensor +from typing_extensions import Literal + +from torchmetrics.classification.precision_recall_curve import ( + BinaryPrecisionRecallCurve, + MulticlassPrecisionRecallCurve, + MultilabelPrecisionRecallCurve, +) +from torchmetrics.functional.classification.average_precision import ( + _average_precision_compute, + _average_precision_update, + _binary_average_precision_compute, + _multiclass_average_precision_arg_validation, + _multiclass_average_precision_compute, + _multilabel_average_precision_arg_validation, + _multilabel_average_precision_compute, +) +from torchmetrics.metric import Metric +from torchmetrics.utilities import rank_zero_warn +from torchmetrics.utilities.data import dim_zero_cat + + +class BinaryAveragePrecision(BinaryPrecisionRecallCurve): + r""" + Computes the average precision (AP) score for binary tasks. The AP score summarizes a precision-recall curve + as an weighted mean of precisions at each threshold, with the difference in recall from the previous threshold + as weight: + + .. math:: + AP = \sum{n} (R_n - R_{n-1}) P_n + + where :math:`P_n, R_n` is the respective precision and recall at threshold index :math:`n`. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds})` (constant memory). + + Args: + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. 
+ - If set to a `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to a 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Returns: + A single scalar with the average precision score + + Example: + >>> from torchmetrics import BinaryAveragePrecision + >>> preds = torch.tensor([0, 0.5, 0.7, 0.8]) + >>> target = torch.tensor([0, 1, 1, 0]) + >>> metric = BinaryAveragePrecision(thresholds=None) + >>> metric(preds, target) + tensor(0.5833) + >>> metric = BinaryAveragePrecision(thresholds=5) + >>> metric(preds, target) + tensor(0.6667) + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def compute(self) -> Tensor: + if self.thresholds is None: + state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] + else: + state = self.confmat + return _binary_average_precision_compute(state, self.thresholds) + + +class MulticlassAveragePrecision(MulticlassPrecisionRecallCurve): + r""" + Computes the average precision (AP) score for multiclass tasks. The AP score summarizes a precision-recall curve + as a weighted mean of precisions at each threshold, with the difference in recall from the previous threshold + as weight: + + .. math:: + AP = \sum_{n} (R_n - R_{n-1}) P_n + + where :math:`P_n, R_n` are the respective precision and recall at threshold index :math:`n`. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + softmax per sample. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds} \times n_{classes})` (constant memory). + + Args: + num_classes: Integer specifying the number of classes + average: + Defines the reduction that is applied over classes. Should be one of the following: + + - ``macro``: Calculate score for each class and average them + - ``weighted``: Calculates score for each class and computes weighted average using their support + - ``"none"`` or ``None``: Calculates score for each class and applies no reduction + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation.
+ - If set to a `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to a 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Returns: + If `average=None|"none"` then a 1d tensor of shape (n_classes, ) will be returned with AP score per class. + If `average="macro"|"weighted"` then a single scalar is returned. + + Example: + >>> from torchmetrics import MulticlassAveragePrecision + >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], + ... [0.05, 0.75, 0.05, 0.05, 0.05], + ... [0.05, 0.05, 0.75, 0.05, 0.05], + ... [0.05, 0.05, 0.05, 0.75, 0.05]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> metric = MulticlassAveragePrecision(num_classes=5, average="macro", thresholds=None) + >>> metric(preds, target) + tensor(0.6250) + >>> metric = MulticlassAveragePrecision(num_classes=5, average=None, thresholds=None) + >>> metric(preds, target) + tensor([1.0000, 1.0000, 0.2500, 0.2500, nan]) + >>> metric = MulticlassAveragePrecision(num_classes=5, average="macro", thresholds=5) + >>> metric(preds, target) + tensor(0.5000) + >>> metric = MulticlassAveragePrecision(num_classes=5, average=None, thresholds=5) + >>> metric(preds, target) + tensor([1.0000, 1.0000, 0.2500, 0.2500, -0.0000]) + + """ + + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def __init__( + self, + num_classes: int, + average: Optional[Literal["macro", "weighted", "none"]] = "macro", + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__( + num_classes=num_classes, thresholds=thresholds, ignore_index=ignore_index, validate_args=False, **kwargs + ) + if validate_args: + _multiclass_average_precision_arg_validation(num_classes, average, thresholds, ignore_index) + self.average = average + self.validate_args = validate_args + + def compute(self) -> Tensor: + if self.thresholds is None: + state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] + else: + state = self.confmat + return _multiclass_average_precision_compute(state, self.num_classes, self.average, self.thresholds) + + +class MultilabelAveragePrecision(MultilabelPrecisionRecallCurve): + r""" + Computes the average precision (AP) score for multilabel tasks. The AP score summarizes a precision-recall curve + as a weighted mean of precisions at each threshold, with the difference in recall from the previous threshold + as weight: + + .. math:: + AP = \sum_{n} (R_n - R_{n-1}) P_n + + where :math:`P_n, R_n` are the respective precision and recall at threshold index :math:`n`. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, C, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension.
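A minimal sketch of the functional counterpart registered in the documentation hunks above, assuming ``torchmetrics.functional.multiclass_average_precision`` mirrors the arguments of the ``MulticlassAveragePrecision`` class; the random inputs are only for illustration:

import torch
from torchmetrics.functional import multiclass_average_precision

preds = torch.randn(20, 5).softmax(dim=-1)   # per-class probabilities
target = torch.randint(0, 5, (20,))

# One-shot functional call instead of the stateful update/compute cycle.
score = multiclass_average_precision(preds, target, num_classes=5, average="macro", thresholds=None)
print(score)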
+ + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds} \times n_{labels})` (constant memory). + + Args: + num_labels: Integer specifing the number of labels + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum score over all labels + - ``macro``: Calculate score for each label and average them + - ``weighted``: Calculates score for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates score for each label and applies no reduction + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Returns: + If `average=None|"none"` then a 1d tensor of shape (n_classes, ) will be returned with AP score per class. + If `average="micro|macro"|"weighted"` then a single scalar is returned. + + Example: + >>> from torchmetrics import MultilabelAveragePrecision + >>> preds = torch.tensor([[0.75, 0.05, 0.35], + ... [0.45, 0.75, 0.05], + ... [0.05, 0.55, 0.75], + ... [0.05, 0.65, 0.05]]) + >>> target = torch.tensor([[1, 0, 1], + ... [0, 0, 0], + ... [0, 1, 1], + ... 
[1, 1, 1]]) + >>> metric = MultilabelAveragePrecision(num_labels=3, average="macro", thresholds=None) + >>> metric(preds, target) + tensor(0.7500) + >>> metric = MultilabelAveragePrecision(num_labels=3, average=None, thresholds=None) + >>> metric(preds, target) + tensor([0.7500, 0.5833, 0.9167]) + >>> metric = MultilabelAveragePrecision(num_labels=3, average="macro", thresholds=5) + >>> metric(preds, target) + tensor(0.7778) + >>> metric = MultilabelAveragePrecision(num_labels=3, average=None, thresholds=5) + >>> metric(preds, target) + tensor([0.7500, 0.6667, 0.9167]) + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def __init__( + self, + num_labels: int, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__( + num_labels=num_labels, thresholds=thresholds, ignore_index=ignore_index, validate_args=False, **kwargs + ) + if validate_args: + _multilabel_average_precision_arg_validation(num_labels, average, thresholds, ignore_index) + self.average = average + self.validate_args = validate_args + + def compute(self) -> Tensor: + if self.thresholds is None: + state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] + else: + state = self.confmat + return _multilabel_average_precision_compute( + state, self.num_labels, self.average, self.thresholds, self.ignore_index + ) + + +# -------------------------- Old stuff -------------------------- + + +class AveragePrecision(Metric): + """Computes the average precision score, which summarises the precision recall curve into one number. Works for + both binary and multiclass problems. In the case of multiclass, the values will be calculated based on a one- + vs-the-rest approach. + + Forward accepts + + - ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass) tensor + with probabilities, where C is the number of classes. + + - ``target`` (long tensor): ``(N, ...)`` with integer labels + + Args: + num_classes: integer with number of classes. Not nessesary to provide + for binary problems. + pos_label: integer determining the positive class. Default is ``None`` + which for binary problem is translated to 1. For multiclass problems + this argument should not be set as we iteratively change it in the + range ``[0, num_classes-1]`` + average: + defines the reduction that is applied in the case of multiclass and multilabel input. + Should be one of the following: + + - ``'macro'`` [default]: Calculate the metric for each class separately, and average the + metrics across classes (with equal weights for each class). + - ``'micro'``: Calculate the metric globally, across all samples and classes. Cannot be + used with multiclass input. + - ``'weighted'``: Calculate the metric for each class separately, and average the + metrics across classes, weighting each class by its support. + - ``'none'`` or ``None``: Calculate the metric for each class separately, and return + the metric for every class. + + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example (binary case): + >>> from torchmetrics import AveragePrecision + >>> pred = torch.tensor([0, 0.1, 0.8, 0.4]) + >>> target = torch.tensor([0, 1, 1, 1]) + >>> average_precision = AveragePrecision(pos_label=1) + >>> average_precision(pred, target) + tensor(1.) 
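For comparison, a rough sketch evaluating the same binary inputs with the new ``BinaryAveragePrecision`` from earlier in this diff; assuming both classes are exported from ``torchmetrics``, the non-binned variant should reproduce the legacy score, while the legacy class always buffers every prediction:

import torch
from torchmetrics import AveragePrecision, BinaryAveragePrecision

pred = torch.tensor([0.0, 0.1, 0.8, 0.4])
target = torch.tensor([0, 1, 1, 1])

legacy = AveragePrecision(pos_label=1)              # buffers all preds and targets
updated = BinaryAveragePrecision(thresholds=None)   # same non-binned computation
print(legacy(pred, target), updated(pred, target))  # expected: tensor(1.) from both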
+ + Example (multiclass case): + >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], + ... [0.05, 0.75, 0.05, 0.05, 0.05], + ... [0.05, 0.05, 0.75, 0.05, 0.05], + ... [0.05, 0.05, 0.05, 0.75, 0.05]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> average_precision = AveragePrecision(num_classes=5, average=None) + >>> average_precision(pred, target) + [tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)] + """ + + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + preds: List[Tensor] + target: List[Tensor] + + def __init__( + self, + num_classes: Optional[int] = None, + pos_label: Optional[int] = None, + average: Optional[str] = "macro", + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + + self.num_classes = num_classes + self.pos_label = pos_label + allowed_average = ("micro", "macro", "weighted", "none", None) + if average not in allowed_average: + raise ValueError(f"Expected argument `average` to be one of {allowed_average}" f" but got {average}") + self.average = average + + self.add_state("preds", default=[], dist_reduce_fx="cat") + self.add_state("target", default=[], dist_reduce_fx="cat") + + rank_zero_warn( + "Metric `AveragePrecision` will save all targets and predictions in buffer." + " For large datasets this may lead to large memory footprint." + ) + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + """Update state with predictions and targets. + + Args: + preds: Predictions from model + target: Ground truth values + """ + preds, target, num_classes, pos_label = _average_precision_update( + preds, target, self.num_classes, self.pos_label, self.average + ) + self.preds.append(preds) + self.target.append(target) + self.num_classes = num_classes + self.pos_label = pos_label + + def compute(self) -> Union[Tensor, List[Tensor]]: + """Compute the average precision score. + + Returns: + tensor with average precision. If multiclass return list of such tensors, one for each class + """ + preds = dim_zero_cat(self.preds) + target = dim_zero_cat(self.target) + if not self.num_classes: + raise ValueError(f"`num_classes` bas to be positive number, but got {self.num_classes}") + return _average_precision_compute(preds, target, self.num_classes, self.pos_label, self.average) diff --git a/src/torchmetrics/classification/avg_precision.py b/src/torchmetrics/classification/avg_precision.py deleted file mode 100644 index 6cf94d13cd4..00000000000 --- a/src/torchmetrics/classification/avg_precision.py +++ /dev/null @@ -1,136 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from typing import Any, List, Optional, Union - -import torch -from torch import Tensor - -from torchmetrics.functional.classification.average_precision import ( - _average_precision_compute, - _average_precision_update, -) -from torchmetrics.metric import Metric -from torchmetrics.utilities import rank_zero_warn -from torchmetrics.utilities.data import dim_zero_cat - - -class AveragePrecision(Metric): - """Computes the average precision score, which summarises the precision recall curve into one number. Works for - both binary and multiclass problems. In the case of multiclass, the values will be calculated based on a one- - vs-the-rest approach. - - Forward accepts - - - ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass) tensor - with probabilities, where C is the number of classes. - - - ``target`` (long tensor): ``(N, ...)`` with integer labels - - Args: - num_classes: integer with number of classes. Not nessesary to provide - for binary problems. - pos_label: integer determining the positive class. Default is ``None`` - which for binary problem is translated to 1. For multiclass problems - this argument should not be set as we iteratively change it in the - range ``[0, num_classes-1]`` - average: - defines the reduction that is applied in the case of multiclass and multilabel input. - Should be one of the following: - - - ``'macro'`` [default]: Calculate the metric for each class separately, and average the - metrics across classes (with equal weights for each class). - - ``'micro'``: Calculate the metric globally, across all samples and classes. Cannot be - used with multiclass input. - - ``'weighted'``: Calculate the metric for each class separately, and average the - metrics across classes, weighting each class by its support. - - ``'none'`` or ``None``: Calculate the metric for each class separately, and return - the metric for every class. - - kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. - - Example (binary case): - >>> from torchmetrics import AveragePrecision - >>> pred = torch.tensor([0, 0.1, 0.8, 0.4]) - >>> target = torch.tensor([0, 1, 1, 1]) - >>> average_precision = AveragePrecision(pos_label=1) - >>> average_precision(pred, target) - tensor(1.) - - Example (multiclass case): - >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], - ... [0.05, 0.75, 0.05, 0.05, 0.05], - ... [0.05, 0.05, 0.75, 0.05, 0.05], - ... 
[0.05, 0.05, 0.05, 0.75, 0.05]]) - >>> target = torch.tensor([0, 1, 3, 2]) - >>> average_precision = AveragePrecision(num_classes=5, average=None) - >>> average_precision(pred, target) - [tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)] - """ - - is_differentiable: bool = False - higher_is_better: Optional[bool] = None - full_state_update: bool = False - preds: List[Tensor] - target: List[Tensor] - - def __init__( - self, - num_classes: Optional[int] = None, - pos_label: Optional[int] = None, - average: Optional[str] = "macro", - **kwargs: Any, - ) -> None: - super().__init__(**kwargs) - - self.num_classes = num_classes - self.pos_label = pos_label - allowed_average = ("micro", "macro", "weighted", "none", None) - if average not in allowed_average: - raise ValueError(f"Expected argument `average` to be one of {allowed_average}" f" but got {average}") - self.average = average - - self.add_state("preds", default=[], dist_reduce_fx="cat") - self.add_state("target", default=[], dist_reduce_fx="cat") - - rank_zero_warn( - "Metric `AveragePrecision` will save all targets and predictions in buffer." - " For large datasets this may lead to large memory footprint." - ) - - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore - """Update state with predictions and targets. - - Args: - preds: Predictions from model - target: Ground truth values - """ - preds, target, num_classes, pos_label = _average_precision_update( - preds, target, self.num_classes, self.pos_label, self.average - ) - self.preds.append(preds) - self.target.append(target) - self.num_classes = num_classes - self.pos_label = pos_label - - def compute(self) -> Union[Tensor, List[Tensor]]: - """Compute the average precision score. - - Returns: - tensor with average precision. 
If multiclass return list of such tensors, one for each class - """ - preds = dim_zero_cat(self.preds) - target = dim_zero_cat(self.target) - if not self.num_classes: - raise ValueError(f"`num_classes` bas to be positive number, but got {self.num_classes}") - return _average_precision_compute(preds, target, self.num_classes, self.pos_label, self.average) diff --git a/src/torchmetrics/classification/precision_recall_curve.py b/src/torchmetrics/classification/precision_recall_curve.py index ee4e29aecbc..eeaf8bbc27b 100644 --- a/src/torchmetrics/classification/precision_recall_curve.py +++ b/src/torchmetrics/classification/precision_recall_curve.py @@ -17,6 +17,22 @@ from torch import Tensor from torchmetrics.functional.classification.precision_recall_curve import ( + _adjust_threshold_arg, + _binary_precision_recall_curve_arg_validation, + _binary_precision_recall_curve_compute, + _binary_precision_recall_curve_format, + _binary_precision_recall_curve_tensor_validation, + _binary_precision_recall_curve_update, + _multiclass_precision_recall_curve_arg_validation, + _multiclass_precision_recall_curve_compute, + _multiclass_precision_recall_curve_format, + _multiclass_precision_recall_curve_tensor_validation, + _multiclass_precision_recall_curve_update, + _multilabel_precision_recall_curve_arg_validation, + _multilabel_precision_recall_curve_compute, + _multilabel_precision_recall_curve_format, + _multilabel_precision_recall_curve_tensor_validation, + _multilabel_precision_recall_curve_update, _precision_recall_curve_compute, _precision_recall_curve_update, ) @@ -25,6 +41,379 @@ from torchmetrics.utilities.data import dim_zero_cat +class BinaryPrecisionRecallCurve(Metric): + r""" + Computes the precision-recall curve for binary tasks. The curve consist of multiple pairs of precision and + recall values evaluated at different thresholds, such that the tradeoff between the two values can been seen. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds})` (constant memory). + + Args: + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. 
+ - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Returns: + (tuple): a tuple of 3 tensors containing: + + - precision: an 1d tensor of size (n_thresholds+1, ) with precision values + - recall: an 1d tensor of size (n_thresholds+1, ) with recall values + - thresholds: an 1d tensor of size (n_thresholds, ) with increasing threshold values + + Example: + >>> from torchmetrics import BinaryPrecisionRecallCurve + >>> preds = torch.tensor([0, 0.5, 0.7, 0.8]) + >>> target = torch.tensor([0, 1, 1, 0]) + >>> metric = BinaryPrecisionRecallCurve(thresholds=None) + >>> metric(preds, target) # doctest: +NORMALIZE_WHITESPACE + (tensor([0.6667, 0.5000, 0.0000, 1.0000]), + tensor([1.0000, 0.5000, 0.0000, 0.0000]), + tensor([0.5000, 0.7000, 0.8000])) + >>> metric = BinaryPrecisionRecallCurve(thresholds=5) + >>> metric(preds, target) # doctest: +NORMALIZE_WHITESPACE + (tensor([0.5000, 0.6667, 0.6667, 0.0000, 0.0000, 1.0000]), + tensor([1., 1., 1., 0., 0., 0.]), + tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])) + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def __init__( + self, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if validate_args: + _binary_precision_recall_curve_arg_validation(thresholds, ignore_index) + + self.ignore_index = ignore_index + self.validate_args = validate_args + + if thresholds is None: + self.thresholds = thresholds + self.add_state("preds", default=[], dist_reduce_fx="cat") + self.add_state("target", default=[], dist_reduce_fx="cat") + else: + self.register_buffer("thresholds", _adjust_threshold_arg(thresholds)) + self.add_state("confmat", default=torch.zeros(thresholds, 2, 2, dtype=torch.long), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + if self.validate_args: + _binary_precision_recall_curve_tensor_validation(preds, target, self.ignore_index) + preds, target, _ = _binary_precision_recall_curve_format(preds, target, self.thresholds, self.ignore_index) + state = _binary_precision_recall_curve_update(preds, target, self.thresholds) + if isinstance(state, Tensor): + self.confmat += state + else: + self.preds.append(state[0]) + self.target.append(state[1]) + + def compute(self) -> Tuple[Tensor, Tensor, Tensor]: + if self.thresholds is None: + state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] + else: + state = self.confmat + return _binary_precision_recall_curve_compute(state, self.thresholds) + + +class MulticlassPrecisionRecallCurve(Metric): + r""" + Computes the precision-recall curve for multiclass tasks. The curve consist of multiple pairs of precision and + recall values evaluated at different thresholds, such that the tradeoff between the two values can been seen. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. 
If preds has values outside [0,1] range we consider the input to be logits and will auto apply + softmax per sample. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds} \times n_{classes})` (constant memory). + + Args: + num_classes: Integer specifing the number of classes + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Returns: + (tuple): a tuple of either 3 tensors or 3 lists containing + + - precision: if `thresholds=None` a list for each class is returned with an 1d tensor of size (n_thresholds+1, ) + with precision values (length may differ between classes). If `thresholds` is set to something else, + then a single 2d tensor of size (n_classes, n_thresholds+1) with precision values is returned. + - recall: if `thresholds=None` a list for each class is returned with an 1d tensor of size (n_thresholds+1, ) + with recall values (length may differ between classes). If `thresholds` is set to something else, + then a single 2d tensor of size (n_classes, n_thresholds+1) with recall values is returned. + - thresholds: if `thresholds=None` a list for each class is returned with an 1d tensor of size (n_thresholds, ) + with increasing threshold values (length may differ between classes). If `threshold` is set to something else, + then a single 1d tensor of size (n_thresholds, ) is returned with shared threshold values for all classes. + + Example: + >>> from torchmetrics import MulticlassPrecisionRecallCurve + >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], + ... [0.05, 0.75, 0.05, 0.05, 0.05], + ... [0.05, 0.05, 0.75, 0.05, 0.05], + ... 
[0.05, 0.05, 0.05, 0.75, 0.05]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> metric = MulticlassPrecisionRecallCurve(num_classes=5, thresholds=None) + >>> precision, recall, thresholds = metric(preds, target) + >>> precision # doctest: +NORMALIZE_WHITESPACE + [tensor([1., 1.]), tensor([1., 1.]), tensor([0.2500, 0.0000, 1.0000]), + tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])] + >>> recall + [tensor([1., 0.]), tensor([1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])] + >>> thresholds + [tensor(0.7500), tensor(0.7500), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor(0.0500)] + >>> metric = MulticlassPrecisionRecallCurve(num_classes=5, thresholds=5) + >>> metric(preds, target) # doctest: +NORMALIZE_WHITESPACE + (tensor([[0.2500, 1.0000, 1.0000, 1.0000, 0.0000, 1.0000], + [0.2500, 1.0000, 1.0000, 1.0000, 0.0000, 1.0000], + [0.2500, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000], + [0.2500, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000], + [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000]]), + tensor([[1., 1., 1., 1., 0., 0.], + [1., 1., 1., 1., 0., 0.], + [1., 0., 0., 0., 0., 0.], + [1., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0.]]), + tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])) + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def __init__( + self, + num_classes: int, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if validate_args: + _multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index) + + self.num_classes = num_classes + self.ignore_index = ignore_index + self.validate_args = validate_args + + if thresholds is None: + self.thresholds = thresholds + self.add_state("preds", default=[], dist_reduce_fx="cat") + self.add_state("target", default=[], dist_reduce_fx="cat") + else: + self.register_buffer("thresholds", _adjust_threshold_arg(thresholds)) + self.add_state( + "confmat", default=torch.zeros(thresholds, num_classes, 2, 2, dtype=torch.long), dist_reduce_fx="sum" + ) + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + if self.validate_args: + _multiclass_precision_recall_curve_tensor_validation(preds, target, self.num_classes, self.ignore_index) + preds, target, _ = _multiclass_precision_recall_curve_format( + preds, target, self.num_classes, self.thresholds, self.ignore_index + ) + state = _multiclass_precision_recall_curve_update(preds, target, self.num_classes, self.thresholds) + if isinstance(state, Tensor): + self.confmat += state + else: + self.preds.append(state[0]) + self.target.append(state[1]) + + def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: + if self.thresholds is None: + state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] + else: + state = self.confmat + return _multiclass_precision_recall_curve_compute(state, self.num_classes, self.thresholds) + + +class MultilabelPrecisionRecallCurve(Metric): + r""" + Computes the precision-recall curve for multilabel tasks. The curve consist of multiple pairs of precision and + recall values evaluated at different thresholds, such that the tradeoff between the two values can been seen. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. 
If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, C, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds} \times n_{labels})` (constant memory). + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_labels: Integer specifing the number of labels + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + (tuple): a tuple of either 3 tensors or 3 lists containing + + - precision: if `thresholds=None` a list for each label is returned with an 1d tensor of size (n_thresholds+1, ) + with precision values (length may differ between labels). If `thresholds` is set to something else, + then a single 2d tensor of size (n_labels, n_thresholds+1) with precision values is returned. + - recall: if `thresholds=None` a list for each label is returned with an 1d tensor of size (n_thresholds+1, ) + with recall values (length may differ between labels). If `thresholds` is set to something else, + then a single 2d tensor of size (n_labels, n_thresholds+1) with recall values is returned. + - thresholds: if `thresholds=None` a list for each label is returned with an 1d tensor of size (n_thresholds, ) + with increasing threshold values (length may differ between labels). If `threshold` is set to something else, + then a single 1d tensor of size (n_thresholds, ) is returned with shared threshold values for all labels. + + Example: + >>> from torchmetrics import MultilabelPrecisionRecallCurve + >>> preds = torch.tensor([[0.75, 0.05, 0.35], + ... [0.45, 0.75, 0.05], + ... [0.05, 0.55, 0.75], + ... [0.05, 0.65, 0.05]]) + >>> target = torch.tensor([[1, 0, 1], + ... [0, 0, 0], + ... [0, 1, 1], + ... 
[1, 1, 1]]) + >>> metric = MultilabelPrecisionRecallCurve(num_labels=3, thresholds=None) + >>> precision, recall, thresholds = metric(preds, target) + >>> precision # doctest: +NORMALIZE_WHITESPACE + [tensor([0.5000, 0.5000, 1.0000, 1.0000]), tensor([0.6667, 0.5000, 0.0000, 1.0000]), + tensor([0.7500, 1.0000, 1.0000, 1.0000])] + >>> recall # doctest: +NORMALIZE_WHITESPACE + [tensor([1.0000, 0.5000, 0.5000, 0.0000]), tensor([1.0000, 0.5000, 0.0000, 0.0000]), + tensor([1.0000, 0.6667, 0.3333, 0.0000])] + >>> thresholds # doctest: +NORMALIZE_WHITESPACE + [tensor([0.0500, 0.4500, 0.7500]), tensor([0.5500, 0.6500, 0.7500]), + tensor([0.0500, 0.3500, 0.7500])] + >>> metric = MultilabelPrecisionRecallCurve(num_labels=3, thresholds=5) + >>> metric(preds, target) # doctest: +NORMALIZE_WHITESPACE + (tensor([[0.5000, 0.5000, 1.0000, 1.0000, 0.0000, 1.0000], + [0.5000, 0.6667, 0.6667, 0.0000, 0.0000, 1.0000], + [0.7500, 1.0000, 1.0000, 1.0000, 0.0000, 1.0000]]), + tensor([[1.0000, 0.5000, 0.5000, 0.5000, 0.0000, 0.0000], + [1.0000, 1.0000, 1.0000, 0.0000, 0.0000, 0.0000], + [1.0000, 0.6667, 0.3333, 0.3333, 0.0000, 0.0000]]), + tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])) + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def __init__( + self, + num_labels: int, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if validate_args: + _multilabel_precision_recall_curve_arg_validation(num_labels, thresholds, ignore_index) + + self.num_labels = num_labels + self.ignore_index = ignore_index + self.validate_args = validate_args + + if thresholds is None: + self.thresholds = thresholds + self.add_state("preds", default=[], dist_reduce_fx="cat") + self.add_state("target", default=[], dist_reduce_fx="cat") + else: + self.register_buffer("thresholds", _adjust_threshold_arg(thresholds)) + self.add_state( + "confmat", default=torch.zeros(thresholds, num_labels, 2, 2, dtype=torch.long), dist_reduce_fx="sum" + ) + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + if self.validate_args: + _multilabel_precision_recall_curve_tensor_validation(preds, target, self.num_labels, self.ignore_index) + preds, target, _ = _multilabel_precision_recall_curve_format( + preds, target, self.num_labels, self.thresholds, self.ignore_index + ) + state = _multilabel_precision_recall_curve_update(preds, target, self.num_labels, self.thresholds) + if isinstance(state, Tensor): + self.confmat += state + else: + self.preds.append(state[0]) + self.target.append(state[1]) + + def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: + if self.thresholds is None: + state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] + else: + state = self.confmat + return _multilabel_precision_recall_curve_compute(state, self.num_labels, self.thresholds, self.ignore_index) + + +# -------------------------- Old stuff -------------------------- + + class PrecisionRecallCurve(Metric): """Computes precision-recall pairs for different thresholds. Works for both binary and multiclass problems. In the case of multiclass, the values will be calculated based on a one-vs-the-rest approach. 
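Editorial note (not part of the patch): the constant-memory claim in the docstrings above is easiest to see from the metric state. Below is a minimal usage sketch, assuming a torchmetrics build that already contains the `BinaryPrecisionRecallCurve` class added in this diff; the loop and batch sizes are made up for illustration.

```python
import torch
from torchmetrics import BinaryPrecisionRecallCurve  # class added in this diff

# Non-binned ("exact") mode: every pred/target pair is stored until compute().
exact = BinaryPrecisionRecallCurve(thresholds=None)
# Binned mode: only a (n_thresholds, 2, 2) confusion-matrix state is kept.
binned = BinaryPrecisionRecallCurve(thresholds=100)

for _ in range(10):  # simulate a stream of batches
    preds = torch.rand(1000)
    target = torch.randint(0, 2, (1000,))
    exact.update(preds, target)
    binned.update(preds, target)

p_exact, r_exact, t_exact = exact.compute()      # thresholds derived from the data
p_binned, r_binned, t_binned = binned.compute()  # fixed, linearly spaced thresholds
print(t_exact.numel(), t_binned.numel())         # data-dependent count vs. 100
```

With `thresholds=None` the number of returned thresholds depends on the data seen, while the binned variant always reports the fixed grid, trading a little resolution for constant memory.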
diff --git a/src/torchmetrics/classification/recall_at_fixed_precision.py b/src/torchmetrics/classification/recall_at_fixed_precision.py new file mode 100644 index 00000000000..81a4a534a8e --- /dev/null +++ b/src/torchmetrics/classification/recall_at_fixed_precision.py @@ -0,0 +1,299 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, List, Optional, Tuple, Union + +import torch +from torch import Tensor + +from torchmetrics.classification.precision_recall_curve import ( + BinaryPrecisionRecallCurve, + MulticlassPrecisionRecallCurve, + MultilabelPrecisionRecallCurve, +) +from torchmetrics.functional.classification.recall_at_fixed_precision import ( + _binary_recall_at_fixed_precision_arg_validation, + _binary_recall_at_fixed_precision_compute, + _multiclass_recall_at_fixed_precision_arg_compute, + _multiclass_recall_at_fixed_precision_arg_validation, + _multilabel_recall_at_fixed_precision_arg_compute, + _multilabel_recall_at_fixed_precision_arg_validation, +) +from torchmetrics.utilities.data import dim_zero_cat + + +class BinaryRecallAtFixedPrecision(BinaryPrecisionRecallCurve): + r""" + Computes the higest possible recall value given the minimum precision thresholds provided. This is done by + first calculating the precision-recall curve for different thresholds and the find the recall for a given + precision level. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds})` (constant memory). + + Args: + min_precision: float value specifying minimum precision threshold. + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. 
+ - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Returns: + (tuple): a tuple of 2 tensors containing: + + - recall: an scalar tensor with the maximum recall for the given precision level + - threshold: an scalar tensor with the corresponding threshold level + + Example: + >>> from torchmetrics import BinaryRecallAtFixedPrecision + >>> preds = torch.tensor([0, 0.5, 0.7, 0.8]) + >>> target = torch.tensor([0, 1, 1, 0]) + >>> metric = BinaryRecallAtFixedPrecision(min_precision=0.5, thresholds=None) + >>> metric(preds, target) + (tensor(1.), tensor(0.5000)) + >>> metric = BinaryRecallAtFixedPrecision(min_precision=0.5, thresholds=5) + >>> metric(preds, target) + (tensor(1.), tensor(0.5000)) + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def __init__( + self, + min_precision: float, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(thresholds, ignore_index, validate_args=False, **kwargs) + if validate_args: + _binary_recall_at_fixed_precision_arg_validation(min_precision, thresholds, ignore_index) + self.validate_args = validate_args + self.min_precision = min_precision + + def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override] + if self.thresholds is None: + state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] + else: + state = self.confmat + return _binary_recall_at_fixed_precision_compute(state, self.thresholds, self.min_precision) + + +class MulticlassRecallAtFixedPrecision(MulticlassPrecisionRecallCurve): + r""" + Computes the higest possible recall value given the minimum precision thresholds provided. This is done by + first calculating the precision-recall curve for different thresholds and the find the recall for a given + precision level. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + softmax per sample. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds} \times n_{classes})` (constant memory). + + Args: + num_classes: Integer specifing the number of classes + min_precision: float value specifying minimum precision threshold. 
+ thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Returns: + (tuple): a tuple of either 2 tensors or 2 lists containing + + - recall: an 1d tensor of size (n_classes, ) with the maximum recall for the given precision level per class + - thresholds: an 1d tensor of size (n_classes, ) with the corresponding threshold level per class + + Example: + >>> from torchmetrics import MulticlassRecallAtFixedPrecision + >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], + ... [0.05, 0.75, 0.05, 0.05, 0.05], + ... [0.05, 0.05, 0.75, 0.05, 0.05], + ... [0.05, 0.05, 0.05, 0.75, 0.05]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> metric = MulticlassRecallAtFixedPrecision(num_classes=5, min_precision=0.5, thresholds=None) + >>> metric(preds, target) + (tensor([1., 1., 0., 0., 0.]), tensor([7.5000e-01, 7.5000e-01, 1.0000e+06, 1.0000e+06, 1.0000e+06])) + >>> metric = MulticlassRecallAtFixedPrecision(num_classes=5, min_precision=0.5, thresholds=5) + >>> metric(preds, target) + (tensor([1., 1., 0., 0., 0.]), tensor([7.5000e-01, 7.5000e-01, 1.0000e+06, 1.0000e+06, 1.0000e+06])) + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def __init__( + self, + num_classes: int, + min_precision: float, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__( + num_classes=num_classes, thresholds=thresholds, ignore_index=ignore_index, validate_args=False, **kwargs + ) + if validate_args: + _multiclass_recall_at_fixed_precision_arg_validation(num_classes, min_precision, thresholds, ignore_index) + self.validate_args = validate_args + self.min_precision = min_precision + + def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override] + if self.thresholds is None: + state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] + else: + state = self.confmat + return _multiclass_recall_at_fixed_precision_arg_compute( + state, self.num_classes, self.thresholds, self.min_precision + ) + + +class MultilabelRecallAtFixedPrecision(MultilabelPrecisionRecallCurve): + r""" + Computes the higest possible recall value given the minimum precision thresholds provided. This is done by + first calculating the precision-recall curve for different thresholds and the find the recall for a given + precision level. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, C, ...)``. 
Target should be a tensor containing ground truth labels, and therefore
+      only contain {0,1} values (except if `ignore_index` is specified).
+
+    Additional dimension ``...`` will be flattened into the batch dimension.
+
+    The implementation both supports calculating the metric in a non-binned but accurate version and a binned version
+    that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the
+    non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds`
+    argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
+    size :math:`\mathcal{O}(n_{thresholds} \times n_{labels})` (constant memory).
+
+    Args:
+        num_labels: Integer specifying the number of labels
+        min_precision: float value specifying minimum precision threshold.
+        thresholds:
+            Can be one of:
+
+            - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
+              all the data. Most accurate but also most memory consuming approach.
+            - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
+              0 to 1 as bins for the calculation.
+            - If set to a `list` of floats, will use the indicated thresholds in the list as bins for the calculation
+            - If set to a 1d `tensor` of floats, will use the indicated thresholds in the tensor as
+              bins for the calculation.
+
+        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
+            Set to ``False`` for faster computations.
+        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+
+    Returns:
+        (tuple): a tuple of either 2 tensors or 2 lists containing
+
+        - recall: a 1d tensor of size (n_labels, ) with the maximum recall for the given precision level per label
+        - thresholds: a 1d tensor of size (n_labels, ) with the corresponding threshold level per label
+
+    Example:
+        >>> from torchmetrics import MultilabelRecallAtFixedPrecision
+        >>> preds = torch.tensor([[0.75, 0.05, 0.35],
+        ...                       [0.45, 0.75, 0.05],
+        ...                       [0.05, 0.55, 0.75],
+        ...                       [0.05, 0.65, 0.05]])
+        >>> target = torch.tensor([[1, 0, 1],
+        ...                        [0, 0, 0],
+        ...                        [0, 1, 1],
+        ...
[1, 1, 1]]) + >>> metric = MultilabelRecallAtFixedPrecision(num_labels=3, min_precision=0.5, thresholds=None) + >>> metric(preds, target) + (tensor([1., 1., 1.]), tensor([0.0500, 0.5500, 0.0500])) + >>> metric = MultilabelRecallAtFixedPrecision(num_labels=3, min_precision=0.5, thresholds=5) + >>> metric(preds, target) + (tensor([1., 1., 1.]), tensor([0.0000, 0.5000, 0.0000])) + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def __init__( + self, + num_labels: int, + min_precision: float, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__( + num_labels=num_labels, thresholds=thresholds, ignore_index=ignore_index, validate_args=False, **kwargs + ) + if validate_args: + _multilabel_recall_at_fixed_precision_arg_validation(num_labels, min_precision, thresholds, ignore_index) + self.validate_args = validate_args + self.min_precision = min_precision + + def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override] + if self.thresholds is None: + state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] + else: + state = self.confmat + return _multilabel_recall_at_fixed_precision_arg_compute( + state, self.num_labels, self.thresholds, self.ignore_index, self.min_precision + ) diff --git a/src/torchmetrics/classification/roc.py b/src/torchmetrics/classification/roc.py index 7682dd758ac..b8f5f96d336 100644 --- a/src/torchmetrics/classification/roc.py +++ b/src/torchmetrics/classification/roc.py @@ -16,12 +16,300 @@ import torch from torch import Tensor -from torchmetrics.functional.classification.roc import _roc_compute, _roc_update +from torchmetrics.classification.precision_recall_curve import ( + BinaryPrecisionRecallCurve, + MulticlassPrecisionRecallCurve, + MultilabelPrecisionRecallCurve, +) +from torchmetrics.functional.classification.roc import ( + _binary_roc_compute, + _multiclass_roc_compute, + _multilabel_roc_compute, + _roc_compute, + _roc_update, +) from torchmetrics.metric import Metric from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.data import dim_zero_cat +class BinaryROC(BinaryPrecisionRecallCurve): + r""" + Computes the Receiver Operating Characteristic (ROC) for binary tasks. The curve consist of multiple + pairs of true positive rate (TPR) and false positive rate (FPR) values evaluated at different thresholds, + such that the tradeoff between the two values can be seen. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. 
Setting the `thresholds` argument to `None` will activate the
+    non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds`
+    argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
+    size :math:`\mathcal{O}(n_{thresholds})` (constant memory).
+
+    Note that the outputted thresholds will be in reversed order to ensure that they correspond to both fpr and tpr
+    which are sorted in reversed order during their calculation, such that they are monotonically increasing.
+
+    Args:
+        thresholds:
+            Can be one of:
+
+            - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
+              all the data. Most accurate but also most memory consuming approach.
+            - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
+              0 to 1 as bins for the calculation.
+            - If set to a `list` of floats, will use the indicated thresholds in the list as bins for the calculation
+            - If set to a 1d `tensor` of floats, will use the indicated thresholds in the tensor as
+              bins for the calculation.
+
+        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
+            Set to ``False`` for faster computations.
+        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+
+    Returns:
+        (tuple): a tuple of 3 tensors containing:
+
+        - fpr: a 1d tensor of size (n_thresholds+1, ) with false positive rate values
+        - tpr: a 1d tensor of size (n_thresholds+1, ) with true positive rate values
+        - thresholds: a 1d tensor of size (n_thresholds, ) with decreasing threshold values
+
+    Example:
+        >>> from torchmetrics import BinaryROC
+        >>> preds = torch.tensor([0, 0.5, 0.7, 0.8])
+        >>> target = torch.tensor([0, 1, 1, 0])
+        >>> metric = BinaryROC(thresholds=None)
+        >>> metric(preds, target)  # doctest: +NORMALIZE_WHITESPACE
+        (tensor([0.0000, 0.5000, 0.5000, 0.5000, 1.0000]),
+         tensor([0.0000, 0.0000, 0.5000, 1.0000, 1.0000]),
+         tensor([1.0000, 0.8000, 0.7000, 0.5000, 0.0000]))
+        >>> metric = BinaryROC(thresholds=5)
+        >>> metric(preds, target)  # doctest: +NORMALIZE_WHITESPACE
+        (tensor([0.0000, 0.5000, 0.5000, 0.5000, 1.0000]),
+         tensor([0., 0., 1., 1., 1.]),
+         tensor([1.0000, 0.7500, 0.5000, 0.2500, 0.0000]))
+    """
+    is_differentiable: bool = False
+    higher_is_better: Optional[bool] = None
+    full_state_update: bool = False
+
+    def compute(self) -> Tuple[Tensor, Tensor, Tensor]:
+        if self.thresholds is None:
+            state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)]
+        else:
+            state = self.confmat
+        return _binary_roc_compute(state, self.thresholds)
+
+
+class MulticlassROC(MulticlassPrecisionRecallCurve):
+    r"""
+    Computes the Receiver Operating Characteristic (ROC) for multiclass tasks. The curve consists of multiple
+    pairs of true positive rate (TPR) and false positive rate (FPR) values evaluated at different thresholds,
+    such that the tradeoff between the two values can be seen.
+
+    Accepts the following input tensors:
+
+    - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each
+      observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
+      softmax per sample.
+    - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore
+      only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified).
+ + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds} \times n_{classes})` (constant memory). + + Note that outputted thresholds will be in reversed order to ensure that they corresponds to both fpr and tpr which + are sorted in reversed order during their calculation, such that they are monotome increasing. + + Args: + num_classes: Integer specifing the number of classes + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Returns: + (tuple): a tuple of either 3 tensors or 3 lists containing + + - fpr: if `thresholds=None` a list for each class is returned with an 1d tensor of size (n_thresholds+1, ) + with false positive rate values (length may differ between classes). If `thresholds` is set to something else, + then a single 2d tensor of size (n_classes, n_thresholds+1) with false positive rate values is returned. + - tpr: if `thresholds=None` a list for each class is returned with an 1d tensor of size (n_thresholds+1, ) + with true positive rate values (length may differ between classes). If `thresholds` is set to something else, + then a single 2d tensor of size (n_classes, n_thresholds+1) with true positive rate values is returned. + - thresholds: if `thresholds=None` a list for each class is returned with an 1d tensor of size (n_thresholds, ) + with decreasing threshold values (length may differ between classes). If `threshold` is set to something else, + then a single 1d tensor of size (n_thresholds, ) is returned with shared threshold values for all classes. + + Example: + >>> from torchmetrics import MulticlassROC + >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], + ... [0.05, 0.75, 0.05, 0.05, 0.05], + ... [0.05, 0.05, 0.75, 0.05, 0.05], + ... 
[0.05, 0.05, 0.05, 0.75, 0.05]])
+        >>> target = torch.tensor([0, 1, 3, 2])
+        >>> metric = MulticlassROC(num_classes=5, thresholds=None)
+        >>> fpr, tpr, thresholds = metric(preds, target)
+        >>> fpr  # doctest: +NORMALIZE_WHITESPACE
+        [tensor([0., 0., 1.]), tensor([0., 0., 1.]), tensor([0.0000, 0.3333, 1.0000]),
+         tensor([0.0000, 0.3333, 1.0000]), tensor([0., 1.])]
+        >>> tpr
+        [tensor([0., 1., 1.]), tensor([0., 1., 1.]), tensor([0., 0., 1.]), tensor([0., 0., 1.]), tensor([0., 0.])]
+        >>> thresholds  # doctest: +NORMALIZE_WHITESPACE
+        [tensor([1.0000, 0.7500, 0.0500]), tensor([1.0000, 0.7500, 0.0500]),
+         tensor([1.0000, 0.7500, 0.0500]), tensor([1.0000, 0.7500, 0.0500]), tensor([1.0000, 0.0500])]
+        >>> metric = MulticlassROC(num_classes=5, thresholds=5)
+        >>> metric(preds, target)  # doctest: +NORMALIZE_WHITESPACE
+        (tensor([[0.0000, 0.0000, 0.0000, 0.0000, 1.0000],
+                 [0.0000, 0.0000, 0.0000, 0.0000, 1.0000],
+                 [0.0000, 0.3333, 0.3333, 0.3333, 1.0000],
+                 [0.0000, 0.3333, 0.3333, 0.3333, 1.0000],
+                 [0.0000, 0.0000, 0.0000, 0.0000, 1.0000]]),
+         tensor([[0., 1., 1., 1., 1.],
+                 [0., 1., 1., 1., 1.],
+                 [0., 0., 0., 0., 1.],
+                 [0., 0., 0., 0., 1.],
+                 [0., 0., 0., 0., 0.]]),
+         tensor([1.0000, 0.7500, 0.5000, 0.2500, 0.0000]))
+    """
+    is_differentiable: bool = False
+    higher_is_better: Optional[bool] = None
+    full_state_update: bool = False
+
+    def compute(self) -> Tuple[Tensor, Tensor, Tensor]:
+        if self.thresholds is None:
+            state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)]
+        else:
+            state = self.confmat
+        return _multiclass_roc_compute(state, self.num_classes, self.thresholds)
+
+
+class MultilabelROC(MultilabelPrecisionRecallCurve):
+    r"""
+    Computes the Receiver Operating Characteristic (ROC) for multilabel tasks. The curve consists of multiple
+    pairs of true positive rate (TPR) and false positive rate (FPR) values evaluated at different thresholds,
+    such that the tradeoff between the two values can be seen.
+
+    Accepts the following input tensors:
+
+    - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each
+      observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
+      sigmoid per element.
+    - ``target`` (int tensor): ``(N, C, ...)``. Target should be a tensor containing ground truth labels, and therefore
+      only contain {0,1} values (except if `ignore_index` is specified).
+
+    Additional dimension ``...`` will be flattened into the batch dimension.
+
+    The implementation both supports calculating the metric in a non-binned but accurate version and a binned version
+    that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the
+    non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds`
+    argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
+    size :math:`\mathcal{O}(n_{thresholds} \times n_{labels})` (constant memory).
+
+    Note that the outputted thresholds will be in reversed order to ensure that they correspond to both fpr and tpr
+    which are sorted in reversed order during their calculation, such that they are monotonically increasing.
+
+    Args:
+        num_labels: Integer specifying the number of labels
+        thresholds:
+            Can be one of:
+
+            - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
+              all the data. Most accurate but also most memory consuming approach.
+ - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Returns: + (tuple): a tuple of either 3 tensors or 3 lists containing + + - fpr: if `thresholds=None` a list for each label is returned with an 1d tensor of size (n_thresholds+1, ) + with false positive rate values (length may differ between labels). If `thresholds` is set to something else, + then a single 2d tensor of size (n_labels, n_thresholds+1) with false positive rate values is returned. + - tpr: if `thresholds=None` a list for each label is returned with an 1d tensor of size (n_thresholds+1, ) + with true positive rate values (length may differ between labels). If `thresholds` is set to something else, + then a single 2d tensor of size (n_labels, n_thresholds+1) with true positive rate values is returned. + - thresholds: if `thresholds=None` a list for each label is returned with an 1d tensor of size (n_thresholds, ) + with decreasing threshold values (length may differ between labels). If `threshold` is set to something else, + then a single 1d tensor of size (n_thresholds, ) is returned with shared threshold values for all labels. + + Example: + >>> from torchmetrics import MultilabelROC + >>> preds = torch.tensor([[0.75, 0.05, 0.35], + ... [0.45, 0.75, 0.05], + ... [0.05, 0.55, 0.75], + ... [0.05, 0.65, 0.05]]) + >>> target = torch.tensor([[1, 0, 1], + ... [0, 0, 0], + ... [0, 1, 1], + ... 
[1, 1, 1]]) + >>> metric = MultilabelROC(num_labels=3, thresholds=None) + >>> fpr, tpr, thresholds = metric(preds, target) + >>> fpr # doctest: +NORMALIZE_WHITESPACE + [tensor([0.0000, 0.0000, 0.5000, 1.0000]), + tensor([0.0000, 0.5000, 0.5000, 0.5000, 1.0000]), + tensor([0., 0., 0., 1.])] + >>> tpr # doctest: +NORMALIZE_WHITESPACE + [tensor([0.0000, 0.5000, 0.5000, 1.0000]), + tensor([0.0000, 0.0000, 0.5000, 1.0000, 1.0000]), + tensor([0.0000, 0.3333, 0.6667, 1.0000])] + >>> thresholds # doctest: +NORMALIZE_WHITESPACE + [tensor([1.0000, 0.7500, 0.4500, 0.0500]), + tensor([1.0000, 0.7500, 0.6500, 0.5500, 0.0500]), + tensor([1.0000, 0.7500, 0.3500, 0.0500])] + >>> metric = MultilabelROC(num_labels=3, thresholds=5) + >>> metric(preds, target) # doctest: +NORMALIZE_WHITESPACE + (tensor([[0.0000, 0.0000, 0.0000, 0.5000, 1.0000], + [0.0000, 0.5000, 0.5000, 0.5000, 1.0000], + [0.0000, 0.0000, 0.0000, 0.0000, 1.0000]]), + tensor([[0.0000, 0.5000, 0.5000, 0.5000, 1.0000], + [0.0000, 0.0000, 1.0000, 1.0000, 1.0000], + [0.0000, 0.3333, 0.3333, 0.6667, 1.0000]]), + tensor([1.0000, 0.7500, 0.5000, 0.2500, 0.0000])) + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def compute(self) -> Tuple[Tensor, Tensor, Tensor]: + if self.thresholds is None: + state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] + else: + state = self.confmat + return _multilabel_roc_compute(state, self.num_labels, self.thresholds, self.ignore_index) + + +# -------------------------- Old stuff -------------------------- + + class ROC(Metric): """Computes the Receiver Operating Characteristic (ROC). Works for both binary, multiclass and multilabel problems. In the case of multiclass, the values will be calculated based on a one-vs-the-rest approach. 
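Editorial note (not part of the patch): for intuition on how the binned ROC path gets away with a fixed-size state, here is a from-scratch sketch of the same idea. It is not the library's internal code (the classes above accumulate a `(n_thresholds, 2, 2)` confusion-matrix state via `add_state` and hand it to `_binary_roc_compute`); the helper name `binned_binary_roc` is invented for this example and tie handling is simplified.

```python
import torch


def binned_binary_roc(preds: torch.Tensor, target: torch.Tensor, n_thresholds: int = 5):
    """Illustrative binned ROC: per-threshold confusion counts instead of sorting all samples."""
    thresholds = torch.linspace(0, 1, n_thresholds)
    pred_pos = preds.unsqueeze(0) >= thresholds.unsqueeze(1)  # (T, N) decisions per threshold
    is_pos = target.bool().unsqueeze(0)                       # (1, N)
    tp = (pred_pos & is_pos).sum(1).float()
    fp = (pred_pos & ~is_pos).sum(1).float()
    fn = (~pred_pos & is_pos).sum(1).float()
    tn = (~pred_pos & ~is_pos).sum(1).float()
    tpr = tp / (tp + fn).clamp(min=1)  # clamp avoids 0/0 when one class is absent
    fpr = fp / (fp + tn).clamp(min=1)
    # report thresholds in decreasing order so fpr/tpr come out monotonically increasing,
    # matching the ordering convention described in the docstrings above
    return fpr.flip(0), tpr.flip(0), thresholds.flip(0)


preds = torch.tensor([0.0, 0.5, 0.7, 0.8])
target = torch.tensor([0, 1, 1, 0])
print(binned_binary_roc(preds, target, n_thresholds=5))  # fpr, tpr, thresholds
```

For this particular `preds`/`target` pair the sketch reproduces the `BinaryROC(thresholds=5)` doctest values above, which is a handy sanity check when reviewing the binned code path. Note that, for clarity, it materializes a `(T, N)` mask in one go, whereas the metric classes accumulate their confusion-matrix state batch by batch, which is what keeps memory constant.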
diff --git a/src/torchmetrics/functional/__init__.py b/src/torchmetrics/functional/__init__.py index 74d6a4c9648..3d1e259301c 100644 --- a/src/torchmetrics/functional/__init__.py +++ b/src/torchmetrics/functional/__init__.py @@ -21,8 +21,13 @@ multilabel_accuracy, ) from torchmetrics.functional.classification.auc import auc -from torchmetrics.functional.classification.auroc import auroc -from torchmetrics.functional.classification.average_precision import average_precision +from torchmetrics.functional.classification.auroc import auroc, binary_auroc, multiclass_auroc, multilabel_auroc +from torchmetrics.functional.classification.average_precision import ( + average_precision, + binary_average_precision, + multiclass_average_precision, + multilabel_average_precision, +) from torchmetrics.functional.classification.calibration_error import calibration_error from torchmetrics.functional.classification.cohen_kappa import binary_cohen_kappa, cohen_kappa, multiclass_cohen_kappa from torchmetrics.functional.classification.confusion_matrix import ( @@ -74,7 +79,12 @@ precision_recall, recall, ) -from torchmetrics.functional.classification.precision_recall_curve import precision_recall_curve +from torchmetrics.functional.classification.precision_recall_curve import ( + binary_precision_recall_curve, + multiclass_precision_recall_curve, + multilabel_precision_recall_curve, + precision_recall_curve, +) from torchmetrics.functional.classification.ranking import ( coverage_error, label_ranking_average_precision, @@ -83,7 +93,12 @@ multilabel_ranking_average_precision, multilabel_ranking_loss, ) -from torchmetrics.functional.classification.roc import roc +from torchmetrics.functional.classification.recall_at_fixed_precision import ( + binary_recall_at_fixed_precision, + multiclass_recall_at_fixed_precision, + multilabel_recall_at_fixed_precision, +) +from torchmetrics.functional.classification.roc import binary_roc, multiclass_roc, multilabel_roc, roc from torchmetrics.functional.classification.specificity import ( binary_specificity, multiclass_specificity, @@ -269,4 +284,19 @@ "multiclass_recall", "multilabel_recall", "multilabel_exact_match", + "binary_auroc", + "multiclass_auroc", + "multilabel_auroc", + "binary_average_precision", + "multiclass_average_precision", + "multilabel_average_precision", + "binary_precision_recall_curve", + "multiclass_precision_recall_curve", + "multilabel_precision_recall_curve", + "binary_recall_at_fixed_precision", + "multiclass_recall_at_fixed_precision", + "multilabel_recall_at_fixed_precision", + "binary_roc", + "multiclass_roc", + "multilabel_roc", ] diff --git a/src/torchmetrics/functional/classification/__init__.py b/src/torchmetrics/functional/classification/__init__.py index 313fc5a29a8..371feb55226 100644 --- a/src/torchmetrics/functional/classification/__init__.py +++ b/src/torchmetrics/functional/classification/__init__.py @@ -18,8 +18,18 @@ multilabel_accuracy, ) from torchmetrics.functional.classification.auc import auc # noqa: F401 -from torchmetrics.functional.classification.auroc import auroc # noqa: F401 -from torchmetrics.functional.classification.average_precision import average_precision # noqa: F401 +from torchmetrics.functional.classification.auroc import ( # noqa: F401 + auroc, + binary_auroc, + multiclass_auroc, + multilabel_auroc, +) +from torchmetrics.functional.classification.average_precision import ( # noqa: F401 + average_precision, + binary_average_precision, + multiclass_average_precision, + multilabel_average_precision, +) from 
torchmetrics.functional.classification.calibration_error import calibration_error # noqa: F401 from torchmetrics.functional.classification.cohen_kappa import cohen_kappa # noqa: F401 from torchmetrics.functional.classification.confusion_matrix import ( # noqa: F401 @@ -61,7 +71,12 @@ precision_recall, recall, ) -from torchmetrics.functional.classification.precision_recall_curve import precision_recall_curve # noqa: F401 +from torchmetrics.functional.classification.precision_recall_curve import ( # noqa: F401 + binary_precision_recall_curve, + multiclass_precision_recall_curve, + multilabel_precision_recall_curve, + precision_recall_curve, +) from torchmetrics.functional.classification.ranking import ( # noqa: F401 coverage_error, label_ranking_average_precision, @@ -70,7 +85,12 @@ multilabel_ranking_average_precision, multilabel_ranking_loss, ) -from torchmetrics.functional.classification.roc import roc # noqa: F401 +from torchmetrics.functional.classification.recall_at_fixed_precision import ( # noqa: F401 + binary_recall_at_fixed_precision, + multiclass_recall_at_fixed_precision, + multilabel_recall_at_fixed_precision, +) +from torchmetrics.functional.classification.roc import binary_roc, multiclass_roc, multilabel_roc, roc # noqa: F401 from torchmetrics.functional.classification.specificity import ( # noqa: F401 binary_specificity, multiclass_specificity, diff --git a/src/torchmetrics/functional/classification/auc.py b/src/torchmetrics/functional/classification/auc.py index 7c439cdefe0..bb4faaa6a9b 100644 --- a/src/torchmetrics/functional/classification/auc.py +++ b/src/torchmetrics/functional/classification/auc.py @@ -43,7 +43,7 @@ def _auc_update(x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]: return x, y -def _auc_compute_without_check(x: Tensor, y: Tensor, direction: float) -> Tensor: +def _auc_compute_without_check(x: Tensor, y: Tensor, direction: float, axis: int = -1) -> Tensor: """Computes area under the curve using the trapezoidal rule. Assumes increasing or decreasing order of `x`. Args: @@ -60,7 +60,7 @@ def _auc_compute_without_check(x: Tensor, y: Tensor, direction: float) -> Tensor """ with torch.no_grad(): - auc_: Tensor = torch.trapz(y, x) * direction + auc_: Tensor = torch.trapz(y, x, dim=axis) * direction return auc_ diff --git a/src/torchmetrics/functional/classification/auroc.py b/src/torchmetrics/functional/classification/auroc.py index ecc1c7eac72..d435077e6ae 100644 --- a/src/torchmetrics/functional/classification/auroc.py +++ b/src/torchmetrics/functional/classification/auroc.py @@ -12,17 +12,425 @@ # See the License for the specific language governing permissions and # limitations under the License. 
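Editorial note (not part of the patch): the `axis` argument added to `_auc_compute_without_check` in the hunk above lets a whole stack of curves be integrated in a single call instead of a Python loop, which is what `_reduce_auroc` below relies on in the binned case. A small stand-alone illustration using plain `torch.trapz`, with made-up curve values:

```python
import torch

# Two ROC-like curves stacked row-wise, e.g. one per class from the binned code path.
fpr = torch.tensor([[0.0, 0.5, 1.0],
                    [0.0, 0.2, 1.0]])
tpr = torch.tensor([[0.0, 0.8, 1.0],
                    [0.0, 0.6, 1.0]])

# Equivalent of _auc_compute_without_check(fpr, tpr, 1.0, axis=1): trapezoidal rule per row.
auc_per_curve = torch.trapz(tpr, fpr, dim=1) * 1.0
print(auc_per_curve)  # tensor([0.6500, 0.7000])
```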
import warnings -from typing import Optional, Sequence, Tuple +from typing import List, Optional, Sequence, Tuple, Union import torch from torch import Tensor, tensor +from typing_extensions import Literal from torchmetrics.functional.classification.auc import _auc_compute_without_check -from torchmetrics.functional.classification.roc import roc +from torchmetrics.functional.classification.precision_recall_curve import ( + _binary_precision_recall_curve_arg_validation, + _binary_precision_recall_curve_format, + _binary_precision_recall_curve_tensor_validation, + _binary_precision_recall_curve_update, + _multiclass_precision_recall_curve_arg_validation, + _multiclass_precision_recall_curve_format, + _multiclass_precision_recall_curve_tensor_validation, + _multiclass_precision_recall_curve_update, + _multilabel_precision_recall_curve_arg_validation, + _multilabel_precision_recall_curve_format, + _multilabel_precision_recall_curve_tensor_validation, + _multilabel_precision_recall_curve_update, +) +from torchmetrics.functional.classification.roc import ( + _binary_roc_compute, + _multiclass_roc_compute, + _multilabel_roc_compute, + roc, +) from torchmetrics.utilities.checks import _input_format_classification +from torchmetrics.utilities.compute import _safe_divide from torchmetrics.utilities.data import _bincount from torchmetrics.utilities.enums import AverageMethod, DataType from torchmetrics.utilities.imports import _TORCH_LOWER_1_6 +from torchmetrics.utilities.prints import rank_zero_warn + + +def _reduce_auroc( + fpr: Union[Tensor, List[Tensor]], + tpr: Union[Tensor, List[Tensor]], + average: Optional[Literal["macro", "weighted", "none"]] = "macro", + weights: Optional[Tensor] = None, +) -> Tensor: + """Utility function for reducing multiple average precision score into one number.""" + res = [] + if isinstance(fpr, Tensor): + res = _auc_compute_without_check(fpr, tpr, 1.0, axis=1) + else: + res = [_auc_compute_without_check(x, y, 1.0) for x, y in zip(fpr, tpr)] + res = torch.stack(res) + if average is None or average == "none": + return res + if torch.isnan(res).any(): + rank_zero_warn( + f"Average precision score for one or more classes was `nan`. 
Ignoring these classes in {average}-average", + UserWarning, + ) + idx = ~torch.isnan(res) + if average == "macro": + return res[idx].mean() + elif average == "weighted" and weights is not None: + weights = _safe_divide(weights[idx], weights[idx].sum()) + return (res[idx] * weights).sum() + else: + raise ValueError("Received an incompatible combinations of inputs to make reduction.") + + +def _binary_auroc_arg_validation( + max_fpr: Optional[float] = None, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, +) -> None: + _binary_precision_recall_curve_arg_validation(thresholds, ignore_index) + if max_fpr is not None: + if not isinstance(max_fpr, float) and 0 < max_fpr <= 1: + raise ValueError(f"Arguments `max_fpr` should be a float in range (0, 1], but got: {max_fpr}") + if _TORCH_LOWER_1_6: + raise RuntimeError( + "`max_fpr` argument requires `torch.bucketize` which" " is not available below PyTorch version 1.6" + ) + + +def _binary_auroc_compute( + state: Union[Tensor, Tuple[Tensor, Tensor]], + thresholds: Optional[Tensor], + max_fpr: Optional[float] = None, + pos_label: int = 1, +) -> Tuple[Tensor, Tensor, Tensor]: + fpr, tpr, _ = _binary_roc_compute(state, thresholds, pos_label) + if max_fpr is None or max_fpr == 1: + return _auc_compute_without_check(fpr, tpr, 1.0) + + _device = fpr.device if isinstance(fpr, Tensor) else fpr[0].device + max_area: Tensor = tensor(max_fpr, device=_device) + # Add a single point at max_fpr and interpolate its tpr value + stop = torch.bucketize(max_area, fpr, out_int32=True, right=True) + weight = (max_area - fpr[stop - 1]) / (fpr[stop] - fpr[stop - 1]) + interp_tpr: Tensor = torch.lerp(tpr[stop - 1], tpr[stop], weight) + tpr = torch.cat([tpr[:stop], interp_tpr.view(1)]) + fpr = torch.cat([fpr[:stop], max_area.view(1)]) + + # Compute partial AUC + partial_auc = _auc_compute_without_check(fpr, tpr, 1.0) + + # McClish correction: standardize result to be 0.5 if non-discriminant and 1 if maximal + min_area: Tensor = 0.5 * max_area**2 + return 0.5 * (1 + (partial_auc - min_area) / (max_area - min_area)) + + +def binary_auroc( + preds: Tensor, + target: Tensor, + max_fpr: Optional[float] = None, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tuple[Tensor, Tensor, Tensor]: + r""" + Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_) for binary tasks. The AUROC score + summarizes the ROC curve into an single number that describes the performance of a model for multiple + thresholds at the same time. Notably, an AUROC score of 1 is a perfect score and an AUROC score of 0.5 + corresponds to random guessing. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. 
Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds})` (constant memory). + + Args: + preds: Tensor with predictions + target: Tensor with true labels + max_fpr: If not ``None``, calculates standardized partial AUC over the range ``[0, max_fpr]``. + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + A single scalar with the auroc score + + Example: + >>> from torchmetrics.functional import binary_auroc + >>> preds = torch.tensor([0, 0.5, 0.7, 0.8]) + >>> target = torch.tensor([0, 1, 1, 0]) + >>> binary_auroc(preds, target, thresholds=None) + tensor(0.5000) + >>> binary_auroc(preds, target, thresholds=5) + tensor(0.5000) + """ + if validate_args: + _binary_auroc_arg_validation(max_fpr, thresholds, ignore_index) + _binary_precision_recall_curve_tensor_validation(preds, target, ignore_index) + preds, target, thresholds = _binary_precision_recall_curve_format(preds, target, thresholds, ignore_index) + state = _binary_precision_recall_curve_update(preds, target, thresholds) + return _binary_auroc_compute(state, thresholds, max_fpr) + + +def _multiclass_auroc_arg_validation( + num_classes: int, + average: Optional[Literal["macro", "weighted", "none"]] = "macro", + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, +) -> None: + _multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index) + allowed_average = ("macro", "weighted", "none", None) + if average not in allowed_average: + raise ValueError(f"Expected argument `average` to be one of {allowed_average} but got {average}") + + +def _multiclass_auroc_compute( + state: Union[Tensor, Tuple[Tensor, Tensor]], + num_classes: int, + average: Optional[Literal["macro", "weighted", "none"]] = "macro", + thresholds: Optional[Tensor] = None, +) -> Tensor: + fpr, tpr, _ = _multiclass_roc_compute(state, num_classes, thresholds) + return _reduce_auroc( + fpr, + tpr, + average, + weights=_bincount(state[1], minlength=num_classes).float() if thresholds is None else state[0][:, 1, :].sum(-1), + ) + + +def multiclass_auroc( + preds: Tensor, + target: Tensor, + num_classes: int, + average: Optional[Literal["macro", "weighted", "none"]] = "macro", + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: + r""" + Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_) for multiclass tasks. 
The AUROC score + summarizes the ROC curve into an single number that describes the performance of a model for multiple + thresholds at the same time. Notably, an AUROC score of 1 is a perfect score and an AUROC score of 0.5 + corresponds to random guessing. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + softmax per sample. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds} \times n_{classes})` (constant memory). + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_classes: Integer specifing the number of classes + average: + Defines the reduction that is applied over classes. Should be one of the following: + + - ``macro``: Calculate score for each class and average them + - ``weighted``: Calculates score for each class and computes weighted average using their support + - ``"none"`` or ``None``: Calculates score for each class and applies no reduction + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + If `average=None|"none"` then a 1d tensor of shape (n_classes, ) will be returned with auroc score per class. + If `average="macro"|"weighted"` then a single scalar is returned. + + Example: + >>> from torchmetrics.functional import multiclass_auroc + >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], + ... [0.05, 0.75, 0.05, 0.05, 0.05], + ... [0.05, 0.05, 0.75, 0.05, 0.05], + ... 
[0.05, 0.05, 0.05, 0.75, 0.05]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> multiclass_auroc(preds, target, num_classes=5, average="macro", thresholds=None) + tensor(0.5333) + >>> multiclass_auroc(preds, target, num_classes=5, average=None, thresholds=None) + tensor([1.0000, 1.0000, 0.3333, 0.3333, 0.0000]) + >>> multiclass_auroc(preds, target, num_classes=5, average="macro", thresholds=5) + tensor(0.5333) + >>> multiclass_auroc(preds, target, num_classes=5, average=None, thresholds=5) + tensor([1.0000, 1.0000, 0.3333, 0.3333, 0.0000]) + + """ + if validate_args: + _multiclass_auroc_arg_validation(num_classes, average, thresholds, ignore_index) + _multiclass_precision_recall_curve_tensor_validation(preds, target, num_classes, ignore_index) + preds, target, thresholds = _multiclass_precision_recall_curve_format( + preds, target, num_classes, thresholds, ignore_index + ) + state = _multiclass_precision_recall_curve_update(preds, target, num_classes, thresholds) + return _multiclass_auroc_compute(state, num_classes, average, thresholds) + + +def _multilabel_auroc_arg_validation( + num_labels: int, + average: Optional[Literal["micro", "macro", "weighted", "none"]], + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, +) -> None: + _multilabel_precision_recall_curve_arg_validation(num_labels, thresholds, ignore_index) + allowed_average = ("micro", "macro", "weighted", "none", None) + if average not in allowed_average: + raise ValueError(f"Expected argument `average` to be one of {allowed_average} but got {average}") + + +def _multilabel_auroc_compute( + state: Union[Tensor, Tuple[Tensor, Tensor]], + num_labels: int, + average: Optional[Literal["micro", "macro", "weighted", "none"]], + thresholds: Optional[Tensor], + ignore_index: Optional[int] = None, +) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: + if average == "micro": + if isinstance(state, Tensor) and thresholds is not None: + return _binary_auroc_compute(state.sum(1), thresholds, max_fpr=None) + else: + preds = state[0].flatten() + target = state[1].flatten() + if ignore_index is not None: + idx = target == ignore_index + preds = preds[~idx] + target = target[~idx] + return _binary_auroc_compute([preds, target], thresholds, max_fpr=None) + + else: + fpr, tpr, _ = _multilabel_roc_compute(state, num_labels, thresholds, ignore_index) + return _reduce_auroc( + fpr, + tpr, + average, + weights=(state[1] == 1).sum(dim=0).float() if thresholds is None else state[0][:, 1, :].sum(-1), + ) + + +def multilabel_auroc( + preds: Tensor, + target: Tensor, + num_labels: int, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: + r""" + Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_) for multilabel tasks. The AUROC score + summarizes the ROC curve into an single number that describes the performance of a model for multiple + thresholds at the same time. Notably, an AUROC score of 1 is a perfect score and an AUROC score of 0.5 + corresponds to random guessing. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. 
If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, C, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds} \times n_{labels})` (constant memory). + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_labels: Integer specifing the number of labels + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum score over all labels + - ``macro``: Calculate score for each label and average them + - ``weighted``: Calculates score for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates score for each label and applies no reduction + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + If `average=None|"none"` then a 1d tensor of shape (n_classes, ) will be returned with auroc score per class. + If `average="micro|macro"|"weighted"` then a single scalar is returned. + + Example: + >>> from torchmetrics.functional import multilabel_auroc + >>> preds = torch.tensor([[0.75, 0.05, 0.35], + ... [0.45, 0.75, 0.05], + ... [0.05, 0.55, 0.75], + ... [0.05, 0.65, 0.05]]) + >>> target = torch.tensor([[1, 0, 1], + ... [0, 0, 0], + ... [0, 1, 1], + ... 
[1, 1, 1]]) + >>> multilabel_auroc(preds, target, num_labels=3, average="macro", thresholds=None) + tensor(0.6528) + >>> multilabel_auroc(preds, target, num_labels=3, average=None, thresholds=None) + tensor([0.6250, 0.5000, 0.8333]) + >>> multilabel_auroc(preds, target, num_labels=3, average="macro", thresholds=5) + tensor(0.6528) + >>> multilabel_auroc(preds, target, num_labels=3, average=None, thresholds=5) + tensor([0.6250, 0.5000, 0.8333]) + """ + if validate_args: + _multilabel_auroc_arg_validation(num_labels, average, thresholds, ignore_index) + _multilabel_precision_recall_curve_tensor_validation(preds, target, num_labels, ignore_index) + preds, target, thresholds = _multilabel_precision_recall_curve_format( + preds, target, num_labels, thresholds, ignore_index + ) + state = _multilabel_precision_recall_curve_update(preds, target, num_labels, thresholds) + return _multilabel_auroc_compute(state, num_labels, average, thresholds, ignore_index) + + +# -------------------------- Old stuff -------------------------- def _auroc_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor, DataType]: diff --git a/src/torchmetrics/functional/classification/average_precision.py b/src/torchmetrics/functional/classification/average_precision.py index 53df4956889..3e8ce01d2f6 100644 --- a/src/torchmetrics/functional/classification/average_precision.py +++ b/src/torchmetrics/functional/classification/average_precision.py @@ -16,12 +16,393 @@ import torch from torch import Tensor +from typing_extensions import Literal from torchmetrics.functional.classification.precision_recall_curve import ( + _binary_precision_recall_curve_arg_validation, + _binary_precision_recall_curve_compute, + _binary_precision_recall_curve_format, + _binary_precision_recall_curve_tensor_validation, + _binary_precision_recall_curve_update, + _multiclass_precision_recall_curve_arg_validation, + _multiclass_precision_recall_curve_compute, + _multiclass_precision_recall_curve_format, + _multiclass_precision_recall_curve_tensor_validation, + _multiclass_precision_recall_curve_update, + _multilabel_precision_recall_curve_arg_validation, + _multilabel_precision_recall_curve_compute, + _multilabel_precision_recall_curve_format, + _multilabel_precision_recall_curve_tensor_validation, + _multilabel_precision_recall_curve_update, _precision_recall_curve_compute, _precision_recall_curve_update, ) +from torchmetrics.utilities.compute import _safe_divide from torchmetrics.utilities.data import _bincount +from torchmetrics.utilities.prints import rank_zero_warn + + +def _reduce_average_precision( + precision: Union[Tensor, List[Tensor]], + recall: Union[Tensor, List[Tensor]], + average: Optional[Literal["macro", "weighted", "none"]] = "macro", + weights: Optional[Tensor] = None, +) -> Tensor: + """Utility function for reducing multiple average precision score into one number.""" + res = [] + if isinstance(precision, Tensor) and isinstance(recall, Tensor): + res = -torch.sum((recall[:, 1:] - recall[:, :-1]) * precision[:, :-1], 1) + else: + for p, r in zip(precision, recall): + res.append(-torch.sum((r[1:] - r[:-1]) * p[:-1])) + res = torch.stack(res) + if average is None or average == "none": + return res + if torch.isnan(res).any(): + rank_zero_warn( + f"Average precision score for one or more classes was `nan`. 
Ignoring these classes in {average}-average", + UserWarning, + ) + idx = ~torch.isnan(res) + if average == "macro": + return res[idx].mean() + elif average == "weighted" and weights is not None: + weights = _safe_divide(weights[idx], weights[idx].sum()) + return (res[idx] * weights).sum() + else: + raise ValueError("Received an incompatible combinations of inputs to make reduction.") + + +def _binary_average_precision_compute( + state: Union[Tensor, Tuple[Tensor, Tensor]], + thresholds: Optional[Tensor], +) -> Tensor: + precision, recall, _ = _binary_precision_recall_curve_compute(state, thresholds) + return -torch.sum((recall[1:] - recall[:-1]) * precision[:-1]) + + +def binary_average_precision( + preds: Tensor, + target: Tensor, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r""" + Computes the average precision (AP) score for binary tasks. The AP score summarizes a precision-recall curve + as an weighted mean of precisions at each threshold, with the difference in recall from the previous threshold + as weight: + + .. math:: + AP = \sum{n} (R_n - R_{n-1}) P_n + + where :math:`P_n, R_n` is the respective precision and recall at threshold index :math:`n`. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds})` (constant memory). + + Args: + preds: Tensor with predictions + target: Tensor with true labels + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. 
+ + Returns: + A single scalar with the average precision score + + Example: + >>> from torchmetrics.functional import binary_average_precision + >>> preds = torch.tensor([0, 0.5, 0.7, 0.8]) + >>> target = torch.tensor([0, 1, 1, 0]) + >>> binary_average_precision(preds, target, thresholds=None) + tensor(0.5833) + >>> binary_average_precision(preds, target, thresholds=5) + tensor(0.6667) + """ + + if validate_args: + _binary_precision_recall_curve_arg_validation(thresholds, ignore_index) + _binary_precision_recall_curve_tensor_validation(preds, target, ignore_index) + preds, target, thresholds = _binary_precision_recall_curve_format(preds, target, thresholds, ignore_index) + state = _binary_precision_recall_curve_update(preds, target, thresholds) + return _binary_average_precision_compute(state, thresholds) + + +def _multiclass_average_precision_arg_validation( + num_classes: int, + average: Optional[Literal["macro", "weighted", "none"]] = "macro", + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, +) -> None: + _multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index) + allowed_average = ("macro", "weighted", "none", None) + if average not in allowed_average: + raise ValueError(f"Expected argument `average` to be one of {allowed_average} but got {average}") + + +def _multiclass_average_precision_compute( + state: Union[Tensor, Tuple[Tensor, Tensor]], + num_classes: int, + average: Optional[Literal["macro", "weighted", "none"]] = "macro", + thresholds: Optional[Tensor] = None, +) -> Tensor: + precision, recall, _ = _multiclass_precision_recall_curve_compute(state, num_classes, thresholds) + return _reduce_average_precision( + precision, + recall, + average, + weights=_bincount(state[1], minlength=num_classes).float() if thresholds is None else state[0][:, 1, :].sum(-1), + ) + + +def multiclass_average_precision( + preds: Tensor, + target: Tensor, + num_classes: int, + average: Optional[Literal["macro", "weighted", "none"]] = "macro", + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r""" + Computes the average precision (AP) score for multiclass tasks. The AP score summarizes a precision-recall curve + as a weighted mean of precisions at each threshold, with the difference in recall from the previous threshold + as weight: + + .. math:: + AP = \sum_{n} (R_n - R_{n-1}) P_n + + where :math:`P_n, R_n` are the respective precision and recall at threshold index :math:`n`. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + softmax per sample. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient.
Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds} \times n_{classes})` (constant memory). + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_classes: Integer specifing the number of classes + average: + Defines the reduction that is applied over classes. Should be one of the following: + + - ``macro``: Calculate score for each class and average them + - ``weighted``: Calculates score for each class and computes weighted average using their support + - ``"none"`` or ``None``: Calculates score for each class and applies no reduction + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + If `average=None|"none"` then a 1d tensor of shape (n_classes, ) will be returned with AP score per class. + If `average="macro"|"weighted"` then a single scalar is returned. + + Example: + >>> from torchmetrics.functional import multiclass_average_precision + >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], + ... [0.05, 0.75, 0.05, 0.05, 0.05], + ... [0.05, 0.05, 0.75, 0.05, 0.05], + ... 
[0.05, 0.05, 0.05, 0.75, 0.05]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> multiclass_average_precision(preds, target, num_classes=5, average="macro", thresholds=None) + tensor(0.6250) + >>> multiclass_average_precision(preds, target, num_classes=5, average=None, thresholds=None) + tensor([1.0000, 1.0000, 0.2500, 0.2500, nan]) + >>> multiclass_average_precision(preds, target, num_classes=5, average="macro", thresholds=5) + tensor(0.5000) + >>> multiclass_average_precision(preds, target, num_classes=5, average=None, thresholds=5) + tensor([1.0000, 1.0000, 0.2500, 0.2500, -0.0000]) + + """ + if validate_args: + _multiclass_average_precision_arg_validation(num_classes, average, thresholds, ignore_index) + _multiclass_precision_recall_curve_tensor_validation(preds, target, num_classes, ignore_index) + preds, target, thresholds = _multiclass_precision_recall_curve_format( + preds, target, num_classes, thresholds, ignore_index + ) + state = _multiclass_precision_recall_curve_update(preds, target, num_classes, thresholds) + return _multiclass_average_precision_compute(state, num_classes, average, thresholds) + + +def _multilabel_average_precision_arg_validation( + num_labels: int, + average: Optional[Literal["micro", "macro", "weighted", "none"]], + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, +) -> None: + _multilabel_precision_recall_curve_arg_validation(num_labels, thresholds, ignore_index) + allowed_average = ("micro", "macro", "weighted", "none", None) + if average not in allowed_average: + raise ValueError(f"Expected argument `average` to be one of {allowed_average} but got {average}") + + +def _multilabel_average_precision_compute( + state: Union[Tensor, Tuple[Tensor, Tensor]], + num_labels: int, + average: Optional[Literal["micro", "macro", "weighted", "none"]], + thresholds: Optional[Tensor], + ignore_index: Optional[int] = None, +) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: + if average == "micro": + if isinstance(state, Tensor) and thresholds is not None: + state = state.sum(1) + else: + preds, target = state[0].flatten(), state[1].flatten() + if ignore_index is not None: + idx = target == ignore_index + preds = preds[~idx] + target = target[~idx] + state = [preds, target] + return _binary_average_precision_compute(state, thresholds) + else: + precision, recall, _ = _multilabel_precision_recall_curve_compute(state, num_labels, thresholds, ignore_index) + return _reduce_average_precision( + precision, + recall, + average, + weights=(state[1] == 1).sum(dim=0).float() if thresholds is None else state[0][:, 1, :].sum(-1), + ) + + +def multilabel_average_precision( + preds: Tensor, + target: Tensor, + num_labels: int, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r""" + Computes the average precision (AP) score for multilabel tasks. The AP score summarizes a precision-recall curve + as a weighted mean of precisions at each threshold, with the difference in recall from the previous threshold + as weight: + + .. math:: + AP = \sum_{n} (R_n - R_{n-1}) P_n + + where :math:`P_n, R_n` are the respective precision and recall at threshold index :math:`n`. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``.
Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, C, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds} \times n_{labels})` (constant memory). + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_labels: Integer specifing the number of labels + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum score over all labels + - ``macro``: Calculate score for each label and average them + - ``weighted``: Calculates score for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates score for each label and applies no reduction + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + If `average=None|"none"` then a 1d tensor of shape (n_classes, ) will be returned with AP score per class. + If `average="micro|macro"|"weighted"` then a single scalar is returned. + + Example: + >>> from torchmetrics.functional import multilabel_average_precision + >>> preds = torch.tensor([[0.75, 0.05, 0.35], + ... [0.45, 0.75, 0.05], + ... [0.05, 0.55, 0.75], + ... [0.05, 0.65, 0.05]]) + >>> target = torch.tensor([[1, 0, 1], + ... [0, 0, 0], + ... [0, 1, 1], + ... 
[1, 1, 1]]) + >>> multilabel_average_precision(preds, target, num_labels=3, average="macro", thresholds=None) + tensor(0.7500) + >>> multilabel_average_precision(preds, target, num_labels=3, average=None, thresholds=None) + tensor([0.7500, 0.5833, 0.9167]) + >>> multilabel_average_precision(preds, target, num_labels=3, average="macro", thresholds=5) + tensor(0.7778) + >>> multilabel_average_precision(preds, target, num_labels=3, average=None, thresholds=5) + tensor([0.7500, 0.6667, 0.9167]) + """ + if validate_args: + _multilabel_average_precision_arg_validation(num_labels, average, thresholds, ignore_index) + _multilabel_precision_recall_curve_tensor_validation(preds, target, num_labels, ignore_index) + preds, target, thresholds = _multilabel_precision_recall_curve_format( + preds, target, num_labels, thresholds, ignore_index + ) + state = _multilabel_precision_recall_curve_update(preds, target, num_labels, thresholds) + return _multilabel_average_precision_compute(state, num_labels, average, thresholds, ignore_index) + + +# -------------------------- Old stuff -------------------------- def _average_precision_update( diff --git a/src/torchmetrics/functional/classification/precision_recall_curve.py b/src/torchmetrics/functional/classification/precision_recall_curve.py index 67dddac607b..a49521cb685 100644 --- a/src/torchmetrics/functional/classification/precision_recall_curve.py +++ b/src/torchmetrics/functional/classification/precision_recall_curve.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from typing import List, Optional, Sequence, Tuple, Union import torch @@ -18,6 +19,9 @@ from torch.nn import functional as F from torchmetrics.utilities import rank_zero_warn +from torchmetrics.utilities.checks import _check_same_shape +from torchmetrics.utilities.compute import _safe_divide +from torchmetrics.utilities.data import _bincount def _binary_clf_curve( @@ -26,39 +30,746 @@ def _binary_clf_curve( sample_weights: Optional[Sequence] = None, pos_label: int = 1, ) -> Tuple[Tensor, Tensor, Tensor]: - """adapted from https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/_ranking.py.""" - if sample_weights is not None and not isinstance(sample_weights, Tensor): - sample_weights = tensor(sample_weights, device=preds.device, dtype=torch.float) + """Calculates the tps and false positives for all unique thresholds in the preds tensor. Adapted from + https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/metrics/_ranking.py. 
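+ + The scores in ``preds`` are sorted in descending order and cumulative sums of the (optionally weighted) targets are taken at the indices where the sorted scores change, which yields the true and false positive counts at every distinct threshold.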
+ + Args: + preds: 1d tensor with predictions + target: 1d tensor with true values + sample_weights: a 1d tensor with a weight per sample + pos_label: interger determining what the positive class in target tensor is + + Returns: + fps: 1d tensor with false positives for different thresholds + tps: 1d tensor with true positives for different thresholds + thresholds: the unique thresholds use for calculating fps and tps + """ + with torch.no_grad(): + if sample_weights is not None and not isinstance(sample_weights, Tensor): + sample_weights = tensor(sample_weights, device=preds.device, dtype=torch.float) + + # remove class dimension if necessary + if preds.ndim > target.ndim: + preds = preds[:, 0] + desc_score_indices = torch.argsort(preds, descending=True) + + preds = preds[desc_score_indices] + target = target[desc_score_indices] + + if sample_weights is not None: + weight = sample_weights[desc_score_indices] + else: + weight = 1.0 + + # pred typically has many tied values. Here we extract + # the indices associated with the distinct values. We also + # concatenate a value for the end of the curve. + distinct_value_indices = torch.where(preds[1:] - preds[:-1])[0] + threshold_idxs = F.pad(distinct_value_indices, [0, 1], value=target.size(0) - 1) + target = (target == pos_label).to(torch.long) + tps = torch.cumsum(target * weight, dim=0)[threshold_idxs] + + if sample_weights is not None: + # express fps as a cumsum to ensure fps is increasing even in + # the presence of floating point errors + fps = torch.cumsum((1 - target) * weight, dim=0)[threshold_idxs] + else: + fps = 1 + threshold_idxs - tps + + return fps, tps, preds[threshold_idxs] + + +def _adjust_threshold_arg( + thresholds: Optional[Union[int, List[float], Tensor]] = None, device: Optional[torch.device] = None +) -> Optional[Tensor]: + """Utility function for converting the threshold arg for list and int to tensor format.""" + if isinstance(thresholds, int): + thresholds = torch.linspace(0, 1, thresholds, device=device) + if isinstance(thresholds, list): + thresholds = torch.tensor(thresholds, device=device) + return thresholds + + +def _binary_precision_recall_curve_arg_validation( + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, +) -> None: + """Validate non tensor input. 
+ + - ``threshold`` has to be None | a 1d tensor | a list of floats in the [0,1] range | an int + - ``ignore_index`` has to be None or int + """ + if thresholds is not None and not isinstance(thresholds, (list, int, Tensor)): + raise ValueError( + "Expected argument `thresholds` to either be an integer, list of floats or" + f" tensor of floats, but got {thresholds}" + ) + else: + if isinstance(thresholds, int) and thresholds < 2: + raise ValueError( + f"If argument `thresholds` is an integer, expected it to be larger than 1, but got {thresholds}" + ) + if isinstance(thresholds, list) and not all(isinstance(t, float) and 0 <= t <= 1 for t in thresholds): + raise ValueError( + "If argument `thresholds` is a list, expected all elements to be floats in the [0,1] range," + f" but got {thresholds}" + ) + if isinstance(thresholds, Tensor) and not thresholds.ndim == 1: + raise ValueError("If argument `thresholds` is an tensor, expected the tensor to be 1d") + + if ignore_index is not None and not isinstance(ignore_index, int): + raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}") + + +def _binary_precision_recall_curve_tensor_validation( + preds: Tensor, target: Tensor, ignore_index: Optional[int] = None +) -> None: + """Validate tensor input. + + - tensors have to be of same shape + - all values in target tensor that are not ignored have to be in {0, 1} + - that the pred tensor is floating point + """ + _check_same_shape(preds, target) + + if not preds.is_floating_point(): + raise ValueError( + "Expected argument `preds` to be an floating tensor with probability/logit scores," + f" but got tensor with dtype {preds.dtype}" + ) + + # Check that target only contains {0,1} values or value in ignore_index + unique_values = torch.unique(target) + if ignore_index is None: + check = torch.any((unique_values != 0) & (unique_values != 1)) + else: + check = torch.any((unique_values != 0) & (unique_values != 1) & (unique_values != ignore_index)) + if check: + raise RuntimeError( + f"Detected the following values in `target`: {unique_values} but expected only" + f" the following values {[0,1] + [] if ignore_index is None else [ignore_index]}." + ) + + +def _binary_precision_recall_curve_format( + preds: Tensor, + target: Tensor, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, +) -> Tuple[Tensor, Tensor, Optional[Tensor]]: + """Convert all input to the right format. + + - flattens additional dimensions + - Remove all datapoints that should be ignored + - Applies sigmoid if pred tensor not in [0,1] range + - Format thresholds arg to be a tensor + """ + preds = preds.flatten() + target = target.flatten() + if ignore_index is not None: + idx = target != ignore_index + preds = preds[idx] + target = target[idx] + + if not torch.all((0 <= preds) * (preds <= 1)): + preds = preds.sigmoid() + + thresholds = _adjust_threshold_arg(thresholds, preds.device) + return preds, target, thresholds + + +def _binary_precision_recall_curve_update( + preds: Tensor, + target: Tensor, + thresholds: Optional[Tensor], +) -> Union[Tensor, Tuple[Tensor, Tensor]]: + """Returns the state to calculate the pr-curve with. + + If thresholds is `None` the direct preds and targets are used. If thresholds is not `None` we compute a multi + threshold confusion matrix. 
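+ + In the binned case each pair of binarized prediction and target is encoded per threshold as ``preds_t + 2 * target + 4 * threshold_idx``, so a single ``_bincount`` over the flattened encoding yields, after reshaping to ``(len(thresholds), 2, 2)``, a confusion matrix ``[[tn, fp], [fn, tp]]`` for every threshold.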
+ """ + if thresholds is None: + return preds, target + len_t = len(thresholds) + preds_t = (preds.unsqueeze(-1) >= thresholds.unsqueeze(0)).long() # num_samples x num_thresholds + unique_mapping = preds_t + 2 * target.unsqueeze(-1) + 4 * torch.arange(len_t, device=target.device) + bins = _bincount(unique_mapping.flatten(), minlength=4 * len_t) + return bins.reshape(len_t, 2, 2) + + +def _binary_precision_recall_curve_compute( + state: Union[Tensor, Tuple[Tensor, Tensor]], + thresholds: Optional[Tensor], + pos_label: int = 1, +) -> Tuple[Tensor, Tensor, Tensor]: + """Computes the final pr-curve. + + If state is a single tensor, then we calculate the pr-curve from a multi threshold confusion matrix. If state is + original input, then we dynamically compute the binary classification curve. + """ + if isinstance(state, Tensor): + tps = state[:, 1, 1] + fps = state[:, 0, 1] + fns = state[:, 1, 0] + precision = _safe_divide(tps, tps + fps) + recall = _safe_divide(tps, tps + fns) + precision = torch.cat([precision, torch.ones(1, dtype=precision.dtype, device=precision.device)]) + recall = torch.cat([recall, torch.zeros(1, dtype=recall.dtype, device=recall.device)]) + return precision, recall, thresholds + else: + fps, tps, thresholds = _binary_clf_curve(state[0], state[1], pos_label=pos_label) + precision = tps / (tps + fps) + recall = tps / tps[-1] + + # stop when full recall attained and reverse the outputs so recall is decreasing + last_ind = torch.where(tps == tps[-1])[0][0] + sl = slice(0, last_ind.item() + 1) + + # need to call reversed explicitly, since including that to slice would + # introduce negative strides that are not yet supported in pytorch + precision = torch.cat([reversed(precision[sl]), torch.ones(1, dtype=precision.dtype, device=precision.device)]) + recall = torch.cat([reversed(recall[sl]), torch.zeros(1, dtype=recall.dtype, device=recall.device)]) + thresholds = reversed(thresholds[sl]).detach().clone() # type: ignore + + return precision, recall, thresholds + + +def binary_precision_recall_curve( + preds: Tensor, + target: Tensor, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tuple[Tensor, Tensor, Tensor]: + r""" + Computes the precision-recall curve for binary tasks. The curve consist of multiple pairs of precision and + recall values evaluated at different thresholds, such that the tradeoff between the two values can been seen. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds})` (constant memory). 
+ + Args: + preds: Tensor with predictions + target: Tensor with true labels + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + (tuple): a tuple of 3 tensors containing: + + - precision: an 1d tensor of size (n_thresholds+1, ) with precision values + - recall: an 1d tensor of size (n_thresholds+1, ) with recall values + - thresholds: an 1d tensor of size (n_thresholds, ) with increasing threshold values + + Example: + >>> from torchmetrics.functional import binary_precision_recall_curve + >>> preds = torch.tensor([0, 0.5, 0.7, 0.8]) + >>> target = torch.tensor([0, 1, 1, 0]) + >>> binary_precision_recall_curve(preds, target, thresholds=None) # doctest: +NORMALIZE_WHITESPACE + (tensor([0.6667, 0.5000, 0.0000, 1.0000]), + tensor([1.0000, 0.5000, 0.0000, 0.0000]), + tensor([0.5000, 0.7000, 0.8000])) + >>> binary_precision_recall_curve(preds, target, thresholds=5) # doctest: +NORMALIZE_WHITESPACE + (tensor([0.5000, 0.6667, 0.6667, 0.0000, 0.0000, 1.0000]), + tensor([1., 1., 1., 0., 0., 0.]), + tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])) + """ + if validate_args: + _binary_precision_recall_curve_arg_validation(thresholds, ignore_index) + _binary_precision_recall_curve_tensor_validation(preds, target, ignore_index) + preds, target, thresholds = _binary_precision_recall_curve_format(preds, target, thresholds, ignore_index) + state = _binary_precision_recall_curve_update(preds, target, thresholds) + return _binary_precision_recall_curve_compute(state, thresholds) + + +def _multiclass_precision_recall_curve_arg_validation( + num_classes: int, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, +) -> None: + """Validate non tensor input. + + - ``num_classes`` has to be an int larger than 1 + - ``threshold`` has to be None | a 1d tensor | a list of floats in the [0,1] range | an int + - ``ignore_index`` has to be None or int + """ + if not isinstance(num_classes, int) or num_classes < 2: + raise ValueError(f"Expected argument `num_classes` to be an integer larger than 1, but got {num_classes}") + _binary_precision_recall_curve_arg_validation(thresholds, ignore_index) + + +def _multiclass_precision_recall_curve_tensor_validation( + preds: Tensor, target: Tensor, num_classes: int, ignore_index: Optional[int] = None +) -> None: + """Validate tensor input. + + - target should have one more dimension than preds and all dimensions except for preds.shape[1] should match + exactly. 
preds.shape[1] should have size equal to number of classes + - all values in target tensor that are not ignored have to be in {0, 1} + """ + if not preds.ndim == target.ndim + 1: + raise ValueError( + f"Expected `preds` to have one more dimension than `target` but got {preds.ndim} and {target.ndim}" + ) + if not preds.is_floating_point(): + raise ValueError(f"Expected `preds` to be a float tensor, but got {preds.dtype}") + if preds.shape[1] != num_classes: + raise ValueError( + "Expected `preds.shape[1]` to be equal to the number of classes but" + f" got {preds.shape[1]} and {num_classes}." + ) + if preds.shape[0] != target.shape[0] or preds.shape[2:] != target.shape[1:]: + raise ValueError( + "Expected the shape of `preds` should be (N, C, ...) and the shape of `target` should be (N, ...)" + f" but got {preds.shape} and {target.shape}" + ) + + num_unique_values = len(torch.unique(target)) + if ignore_index is None: + check = num_unique_values > num_classes + else: + check = num_unique_values > num_classes + 1 + if check: + raise RuntimeError( + "Detected more unique values in `target` than `num_classes`. Expected only " + f"{num_classes if ignore_index is None else num_classes + 1} but found " + f"{num_unique_values} in `target`." + ) + + +def _multiclass_precision_recall_curve_format( + preds: Tensor, + target: Tensor, + num_classes: int, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, +) -> Tuple[Tensor, Tensor, Optional[Tensor]]: + """Convert all input to the right format. + + - flattens additional dimensions + - Remove all datapoints that should be ignored + - Applies softmax if pred tensor not in [0,1] range + - Format thresholds arg to be a tensor + """ + preds = preds.transpose(0, 1).reshape(num_classes, -1).T + target = target.flatten() + + if ignore_index is not None: + idx = target != ignore_index + preds = preds[idx] + target = target[idx] + + if not torch.all((0 <= preds) * (preds <= 1)): + preds = preds.softmax(1) - # remove class dimension if necessary - if preds.ndim > target.ndim: - preds = preds[:, 0] - desc_score_indices = torch.argsort(preds, descending=True) + thresholds = _adjust_threshold_arg(thresholds, preds.device) + return preds, target, thresholds - preds = preds[desc_score_indices] - target = target[desc_score_indices] - if sample_weights is not None: - weight = sample_weights[desc_score_indices] +def _multiclass_precision_recall_curve_update( + preds: Tensor, + target: Tensor, + num_classes: int, + thresholds: Optional[Tensor], +) -> Union[Tensor, Tuple[Tensor, Tensor]]: + """Returns the state to calculate the pr-curve with. + + If thresholds is `None` the direct preds and targets are used. If thresholds is not `None` we compute a multi + threshold confusion matrix. 
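+ + In the binned case each sample is encoded per class and threshold as ``preds_t + 2 * one_hot(target) + 4 * class_idx + 4 * num_classes * threshold_idx``, so a single ``_bincount`` call yields a ``(len(thresholds), num_classes, 2, 2)`` stack of confusion matrices.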
+ """ + if thresholds is None: + return preds, target + len_t = len(thresholds) + # num_samples x num_classes x num_thresholds + preds_t = (preds.unsqueeze(-1) >= thresholds.unsqueeze(0).unsqueeze(0)).long() + target_t = torch.nn.functional.one_hot(target, num_classes=num_classes) + unique_mapping = preds_t + 2 * target_t.unsqueeze(-1) + unique_mapping += 4 * torch.arange(num_classes, device=preds.device).unsqueeze(0).unsqueeze(-1) + unique_mapping += 4 * num_classes * torch.arange(len_t, device=preds.device) + bins = _bincount(unique_mapping.flatten(), minlength=4 * num_classes * len_t) + return bins.reshape(len_t, num_classes, 2, 2) + + +def _multiclass_precision_recall_curve_compute( + state: Union[Tensor, Tuple[Tensor, Tensor]], + num_classes: int, + thresholds: Optional[Tensor], +) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: + """Computes the final pr-curve. + + If state is a single tensor, then we calculate the pr-curve from a multi threshold confusion matrix. If state is + original input, then we dynamically compute the binary classification curve in an iterative way. + """ + if isinstance(state, Tensor): + tps = state[:, :, 1, 1] + fps = state[:, :, 0, 1] + fns = state[:, :, 1, 0] + precision = _safe_divide(tps, tps + fps) + recall = _safe_divide(tps, tps + fns) + precision = torch.cat([precision, torch.ones(1, num_classes, dtype=precision.dtype, device=precision.device)]) + recall = torch.cat([recall, torch.zeros(1, num_classes, dtype=recall.dtype, device=recall.device)]) + return precision.T, recall.T, thresholds else: - weight = 1.0 - - # pred typically has many tied values. Here we extract - # the indices associated with the distinct values. We also - # concatenate a value for the end of the curve. - distinct_value_indices = torch.where(preds[1:] - preds[:-1])[0] - threshold_idxs = F.pad(distinct_value_indices, [0, 1], value=target.size(0) - 1) - target = (target == pos_label).to(torch.long) - tps = torch.cumsum(target * weight, dim=0)[threshold_idxs] - - if sample_weights is not None: - # express fps as a cumsum to ensure fps is increasing even in - # the presence of floating point errors - fps = torch.cumsum((1 - target) * weight, dim=0)[threshold_idxs] + precision, recall, thresholds = [], [], [] + for i in range(num_classes): + res = _binary_precision_recall_curve_compute([state[0][:, i], state[1]], thresholds=None, pos_label=i) + precision.append(res[0]) + recall.append(res[1]) + thresholds.append(res[2]) + return precision, recall, thresholds + + +def multiclass_precision_recall_curve( + preds: Tensor, + target: Tensor, + num_classes: int, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: + r""" + Computes the precision-recall curve for multiclass tasks. The curve consist of multiple pairs of precision and + recall values evaluated at different thresholds, such that the tradeoff between the two values can been seen. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + softmax per sample. + - ``target`` (int tensor): ``(N, ...)``. 
Target should be a tensor containing ground truth labels, and therefore + only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds} \times n_{classes})` (constant memory). + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_classes: Integer specifing the number of classes + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + (tuple): a tuple of either 3 tensors or 3 lists containing + + - precision: if `thresholds=None` a list for each class is returned with an 1d tensor of size (n_thresholds+1, ) + with precision values (length may differ between classes). If `thresholds` is set to something else, + then a single 2d tensor of size (n_classes, n_thresholds+1) with precision values is returned. + - recall: if `thresholds=None` a list for each class is returned with an 1d tensor of size (n_thresholds+1, ) + with recall values (length may differ between classes). If `thresholds` is set to something else, + then a single 2d tensor of size (n_classes, n_thresholds+1) with recall values is returned. + - thresholds: if `thresholds=None` a list for each class is returned with an 1d tensor of size (n_thresholds, ) + with increasing threshold values (length may differ between classes). If `threshold` is set to something else, + then a single 1d tensor of size (n_thresholds, ) is returned with shared threshold values for all classes. + + Example: + >>> from torchmetrics.functional import multiclass_precision_recall_curve + >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], + ... [0.05, 0.75, 0.05, 0.05, 0.05], + ... [0.05, 0.05, 0.75, 0.05, 0.05], + ... [0.05, 0.05, 0.05, 0.75, 0.05]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> precision, recall, thresholds = multiclass_precision_recall_curve( + ... preds, target, num_classes=5, thresholds=None + ... ) + >>> precision # doctest: +NORMALIZE_WHITESPACE + [tensor([1., 1.]), tensor([1., 1.]), tensor([0.2500, 0.0000, 1.0000]), + tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])] + >>> recall + [tensor([1., 0.]), tensor([1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])] + >>> thresholds + [tensor([0.7500]), tensor([0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])] + >>> multiclass_precision_recall_curve( + ... 
preds, target, num_classes=5, thresholds=5 + ... ) # doctest: +NORMALIZE_WHITESPACE + (tensor([[0.2500, 1.0000, 1.0000, 1.0000, 0.0000, 1.0000], + [0.2500, 1.0000, 1.0000, 1.0000, 0.0000, 1.0000], + [0.2500, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000], + [0.2500, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000], + [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000]]), + tensor([[1., 1., 1., 1., 0., 0.], + [1., 1., 1., 1., 0., 0.], + [1., 0., 0., 0., 0., 0.], + [1., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0.]]), + tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])) + """ + if validate_args: + _multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index) + _multiclass_precision_recall_curve_tensor_validation(preds, target, num_classes, ignore_index) + preds, target, thresholds = _multiclass_precision_recall_curve_format( + preds, target, num_classes, thresholds, ignore_index + ) + state = _multiclass_precision_recall_curve_update(preds, target, num_classes, thresholds) + return _multiclass_precision_recall_curve_compute(state, num_classes, thresholds) + + +def _multilabel_precision_recall_curve_arg_validation( + num_labels: int, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, +) -> None: + """Validate non tensor input. + + - ``num_labels`` has to be an int larger than 1 + - ``threshold`` has to be None | a 1d tensor | a list of floats in the [0,1] range | an int + - ``ignore_index`` has to be None or int + """ + _multiclass_precision_recall_curve_arg_validation(num_labels, thresholds, ignore_index) + + +def _multilabel_precision_recall_curve_tensor_validation( + preds: Tensor, target: Tensor, num_labels: int, ignore_index: Optional[int] = None +) -> None: + """Validate tensor input. + + - tensors have to be of same shape + - preds.shape[1] is equal to the number of labels + - all values in target tensor that are not ignored have to be in {0, 1} + - that the pred tensor is floating point + """ + _binary_precision_recall_curve_tensor_validation(preds, target, ignore_index) + if preds.shape[1] != num_labels: + raise ValueError( + "Expected both `target.shape[1]` and `preds.shape[1]` to be equal to the number of labels" + f" but got {preds.shape[1]} and expected {num_labels}" + ) + + +def _multilabel_precision_recall_curve_format( + preds: Tensor, + target: Tensor, + num_labels: int, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, +) -> Tuple[Tensor, Tensor, Optional[Tensor]]: + """Convert all input to the right format. 
+ + - flattens additional dimensions + - Mask all datapoints that should be ignored with negative values + - Applies sigmoid if pred tensor not in [0,1] range + - Format thresholds arg to be a tensor + """ + preds = preds.transpose(0, 1).reshape(num_labels, -1).T + target = target.transpose(0, 1).reshape(num_labels, -1).T + if not torch.all((0 <= preds) * (preds <= 1)): + preds = preds.sigmoid() + + thresholds = _adjust_threshold_arg(thresholds, preds.device) + if ignore_index is not None and thresholds is not None: + preds = preds.clone() + target = target.clone() + # Make sure that when we map, it will always result in a negative number that we can filter away + idx = target == ignore_index + preds[idx] = -4 * num_labels * (len(thresholds) if thresholds is not None else 1) + target[idx] = -4 * num_labels * (len(thresholds) if thresholds is not None else 1) + + return preds, target, thresholds + + +def _multilabel_precision_recall_curve_update( + preds: Tensor, + target: Tensor, + num_labels: int, + thresholds: Optional[Tensor], +) -> Union[Tensor, Tuple[Tensor, Tensor]]: + """Returns the state to calculate the pr-curve with. + + If thresholds is `None` the direct preds and targets are used. If thresholds is not `None` we compute a multi + threshold confusion matrix. + """ + if thresholds is None: + return preds, target + len_t = len(thresholds) + # num_samples x num_labels x num_thresholds + preds_t = (preds.unsqueeze(-1) >= thresholds.unsqueeze(0).unsqueeze(0)).long() + unique_mapping = preds_t + 2 * target.unsqueeze(-1) + unique_mapping += 4 * torch.arange(num_labels, device=preds.device).unsqueeze(0).unsqueeze(-1) + unique_mapping += 4 * num_labels * torch.arange(len_t, device=preds.device) + unique_mapping = unique_mapping[unique_mapping >= 0] + bins = _bincount(unique_mapping, minlength=4 * num_labels * len_t) + return bins.reshape(len_t, num_labels, 2, 2) + + +def _multilabel_precision_recall_curve_compute( + state: Union[Tensor, Tuple[Tensor, Tensor]], + num_labels: int, + thresholds: Optional[Tensor], + ignore_index: Optional[int] = None, +) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: + """Computes the final pr-curve. + + If state is a single tensor, then we calculate the pr-curve from a multi threshold confusion matrix. If state is + original input, then we dynamically compute the binary classification curve in an iterative way. 
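+ + In the iterative case any positions equal to ``ignore_index`` are dropped from each label's predictions and targets before the corresponding binary curve is computed.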
+ """ + if isinstance(state, Tensor): + tps = state[:, :, 1, 1] + fps = state[:, :, 0, 1] + fns = state[:, :, 1, 0] + precision = _safe_divide(tps, tps + fps) + recall = _safe_divide(tps, tps + fns) + precision = torch.cat([precision, torch.ones(1, num_labels, dtype=precision.dtype, device=precision.device)]) + recall = torch.cat([recall, torch.zeros(1, num_labels, dtype=recall.dtype, device=recall.device)]) + return precision.T, recall.T, thresholds else: - fps = 1 + threshold_idxs - tps + precision, recall, thresholds = [], [], [] + for i in range(num_labels): + preds = state[0][:, i] + target = state[1][:, i] + if ignore_index is not None: + idx = target == ignore_index + preds = preds[~idx] + target = target[~idx] + res = _binary_precision_recall_curve_compute([preds, target], thresholds=None, pos_label=1) + precision.append(res[0]) + recall.append(res[1]) + thresholds.append(res[2]) + return precision, recall, thresholds + + +def multilabel_precision_recall_curve( + preds: Tensor, + target: Tensor, + num_labels: int, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: + r""" + Computes the precision-recall curve for multilabel tasks. The curve consist of multiple pairs of precision and + recall values evaluated at different thresholds, such that the tradeoff between the two values can been seen. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, C, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds} \times n_{labels})` (constant memory). + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_labels: Integer specifing the number of labels + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. 
+ + Returns: + (tuple): a tuple of either 3 tensors or 3 lists containing + + - precision: if `thresholds=None` a list for each label is returned with an 1d tensor of size (n_thresholds+1, ) + with precision values (length may differ between labels). If `thresholds` is set to something else, + then a single 2d tensor of size (n_labels, n_thresholds+1) with precision values is returned. + - recall: if `thresholds=None` a list for each label is returned with an 1d tensor of size (n_thresholds+1, ) + with recall values (length may differ between labels). If `thresholds` is set to something else, + then a single 2d tensor of size (n_labels, n_thresholds+1) with recall values is returned. + - thresholds: if `thresholds=None` a list for each label is returned with an 1d tensor of size (n_thresholds, ) + with increasing threshold values (length may differ between labels). If `threshold` is set to something else, + then a single 1d tensor of size (n_thresholds, ) is returned with shared threshold values for all labels. + + Example: + >>> from torchmetrics.functional import multilabel_precision_recall_curve + >>> preds = torch.tensor([[0.75, 0.05, 0.35], + ... [0.45, 0.75, 0.05], + ... [0.05, 0.55, 0.75], + ... [0.05, 0.65, 0.05]]) + >>> target = torch.tensor([[1, 0, 1], + ... [0, 0, 0], + ... [0, 1, 1], + ... [1, 1, 1]]) + >>> precision, recall, thresholds = multilabel_precision_recall_curve( + ... preds, target, num_labels=3, thresholds=None + ... ) + >>> precision # doctest: +NORMALIZE_WHITESPACE + [tensor([0.5000, 0.5000, 1.0000, 1.0000]), tensor([0.6667, 0.5000, 0.0000, 1.0000]), + tensor([0.7500, 1.0000, 1.0000, 1.0000])] + >>> recall # doctest: +NORMALIZE_WHITESPACE + [tensor([1.0000, 0.5000, 0.5000, 0.0000]), tensor([1.0000, 0.5000, 0.0000, 0.0000]), + tensor([1.0000, 0.6667, 0.3333, 0.0000])] + >>> thresholds # doctest: +NORMALIZE_WHITESPACE + [tensor([0.0500, 0.4500, 0.7500]), tensor([0.5500, 0.6500, 0.7500]), + tensor([0.0500, 0.3500, 0.7500])] + >>> multilabel_precision_recall_curve( + ... preds, target, num_labels=3, thresholds=5 + ... ) # doctest: +NORMALIZE_WHITESPACE + (tensor([[0.5000, 0.5000, 1.0000, 1.0000, 0.0000, 1.0000], + [0.5000, 0.6667, 0.6667, 0.0000, 0.0000, 1.0000], + [0.7500, 1.0000, 1.0000, 1.0000, 0.0000, 1.0000]]), + tensor([[1.0000, 0.5000, 0.5000, 0.5000, 0.0000, 0.0000], + [1.0000, 1.0000, 1.0000, 0.0000, 0.0000, 0.0000], + [1.0000, 0.6667, 0.3333, 0.3333, 0.0000, 0.0000]]), + tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])) + """ + if validate_args: + _multilabel_precision_recall_curve_arg_validation(num_labels, thresholds, ignore_index) + _multilabel_precision_recall_curve_tensor_validation(preds, target, num_labels, ignore_index) + preds, target, thresholds = _multilabel_precision_recall_curve_format( + preds, target, num_labels, thresholds, ignore_index + ) + state = _multilabel_precision_recall_curve_update(preds, target, num_labels, thresholds) + return _multilabel_precision_recall_curve_compute(state, num_labels, thresholds, ignore_index) + - return fps, tps, preds[threshold_idxs] +# -------------------------- Old stuff -------------------------- def _precision_recall_curve_update( diff --git a/src/torchmetrics/functional/classification/recall_at_fixed_precision.py b/src/torchmetrics/functional/classification/recall_at_fixed_precision.py new file mode 100644 index 00000000000..81dbdc694ee --- /dev/null +++ b/src/torchmetrics/functional/classification/recall_at_fixed_precision.py @@ -0,0 +1,366 @@ +# Copyright The PyTorch Lightning team. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import List, Optional, Tuple, Union
+
+import torch
+from torch import Tensor
+
+from torchmetrics.functional.classification.precision_recall_curve import (
+    _binary_precision_recall_curve_arg_validation,
+    _binary_precision_recall_curve_compute,
+    _binary_precision_recall_curve_format,
+    _binary_precision_recall_curve_tensor_validation,
+    _binary_precision_recall_curve_update,
+    _multiclass_precision_recall_curve_arg_validation,
+    _multiclass_precision_recall_curve_compute,
+    _multiclass_precision_recall_curve_format,
+    _multiclass_precision_recall_curve_tensor_validation,
+    _multiclass_precision_recall_curve_update,
+    _multilabel_precision_recall_curve_arg_validation,
+    _multilabel_precision_recall_curve_compute,
+    _multilabel_precision_recall_curve_format,
+    _multilabel_precision_recall_curve_tensor_validation,
+    _multilabel_precision_recall_curve_update,
+)
+
+
+def _recall_at_precision(
+    precision: Tensor,
+    recall: Tensor,
+    thresholds: Tensor,
+    min_precision: float,
+) -> Tuple[Tensor, Tensor]:
+    try:
+        max_recall, _, best_threshold = max(
+            (r, p, t) for p, r, t in zip(precision, recall, thresholds) if p >= min_precision
+        )
+
+    except ValueError:
+        max_recall = torch.tensor(0.0, device=recall.device, dtype=recall.dtype)
+        best_threshold = torch.tensor(0)
+
+    if max_recall == 0.0:
+        best_threshold = torch.tensor(1e6, device=thresholds.device, dtype=thresholds.dtype)
+
+    return max_recall, best_threshold
+
+
+def _binary_recall_at_fixed_precision_arg_validation(
+    min_precision: float,
+    thresholds: Optional[Union[int, List[float], Tensor]] = None,
+    ignore_index: Optional[int] = None,
+) -> None:
+    _binary_precision_recall_curve_arg_validation(thresholds, ignore_index)
+    if not isinstance(min_precision, float) or not (0 <= min_precision <= 1):
+        raise ValueError(
+            f"Expected argument `min_precision` to be a float in the [0,1] range, but got {min_precision}"
+        )
+
+
+def _binary_recall_at_fixed_precision_compute(
+    state: Union[Tensor, Tuple[Tensor, Tensor]],
+    thresholds: Optional[Tensor],
+    min_precision: float,
+    pos_label: int = 1,
+) -> Tuple[Tensor, Tensor]:
+    precision, recall, thresholds = _binary_precision_recall_curve_compute(state, thresholds, pos_label)
+    return _recall_at_precision(precision, recall, thresholds, min_precision)
+
+
+def binary_recall_at_fixed_precision(
+    preds: Tensor,
+    target: Tensor,
+    min_precision: float,
+    thresholds: Optional[Union[int, List[float], Tensor]] = None,
+    ignore_index: Optional[int] = None,
+    validate_args: bool = True,
+) -> Tuple[Tensor, Tensor]:
+    r"""
+    Computes the highest possible recall value given the minimum precision threshold provided. This is done by
+    first calculating the precision-recall curve for different thresholds and then finding the recall for a given
+    precision level.
+
+    Accepts the following input tensors:
+
+    - ``preds`` (float tensor): ``(N, ...)``. Preds should be a tensor containing probabilities or logits for each
+      observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
+      sigmoid per element.
+    - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore
+      only contain {0,1} values (except if `ignore_index` is specified).
+
+    Additional dimension ``...`` will be flattened into the batch dimension.
+
+    The implementation both supports calculating the metric in a non-binned but accurate version and a binned version
+    that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the
+    non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds`
+    argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
+    size :math:`\mathcal{O}(n_{thresholds})` (constant memory).
+
+    Args:
+        preds: Tensor with predictions
+        target: Tensor with true labels
+        min_precision: float value specifying minimum precision threshold.
+        thresholds:
+            Can be one of:
+
+            - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
+              all the data. Most accurate but also most memory consuming approach.
+            - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
+              0 to 1 as bins for the calculation.
+            - If set to a `list` of floats, will use the indicated thresholds in the list as bins for the calculation
+            - If set to a 1d `tensor` of floats, will use the indicated thresholds in the tensor as
+              bins for the calculation.
+
+        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
+            Set to ``False`` for faster computations.
+
+    Returns:
+        (tuple): a tuple of 2 tensors containing:
+
+        - recall: a scalar tensor with the maximum recall for the given precision level
+        - threshold: a scalar tensor with the corresponding threshold level
+
+    Example:
+        >>> from torchmetrics.functional import binary_recall_at_fixed_precision
+        >>> preds = torch.tensor([0, 0.5, 0.7, 0.8])
+        >>> target = torch.tensor([0, 1, 1, 0])
+        >>> binary_recall_at_fixed_precision(preds, target, min_precision=0.5, thresholds=None)
+        (tensor(1.), tensor(0.5000))
+        >>> binary_recall_at_fixed_precision(preds, target, min_precision=0.5, thresholds=5)
+        (tensor(1.), tensor(0.5000))
+    """
+    if validate_args:
+        _binary_recall_at_fixed_precision_arg_validation(min_precision, thresholds, ignore_index)
+        _binary_precision_recall_curve_tensor_validation(preds, target, ignore_index)
+    preds, target, thresholds = _binary_precision_recall_curve_format(preds, target, thresholds, ignore_index)
+    state = _binary_precision_recall_curve_update(preds, target, thresholds)
+    return _binary_recall_at_fixed_precision_compute(state, thresholds, min_precision)
+
+
+def _multiclass_recall_at_fixed_precision_arg_validation(
+    num_classes: int,
+    min_precision: float,
+    thresholds: Optional[Union[int, List[float], Tensor]] = None,
+    ignore_index: Optional[int] = None,
+) -> None:
+    _multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index)
+    if not isinstance(min_precision, float) or not (0 <= min_precision <= 1):
+        raise ValueError(
+            f"Expected argument `min_precision` to be a float in the [0,1] range, but got {min_precision}"
+        )
+
+
+def _multiclass_recall_at_fixed_precision_arg_compute(
+    state: Union[Tensor, Tuple[Tensor, Tensor]],
+    num_classes: int,
+    thresholds: Optional[Tensor],
+    min_precision: float,
+) -> Tuple[Tensor,
Tensor]: + precision, recall, thresholds = _multiclass_precision_recall_curve_compute(state, num_classes, thresholds) + if isinstance(state, Tensor): + res = [_recall_at_precision(p, r, thresholds, min_precision) for p, r in zip(precision, recall)] + else: + res = [_recall_at_precision(p, r, t, min_precision) for p, r, t in zip(precision, recall, thresholds)] + recall = torch.stack([r[0] for r in res]) + thresholds = torch.stack([r[1] for r in res]) + return recall, thresholds + + +def multiclass_recall_at_fixed_precision( + preds: Tensor, + target: Tensor, + num_classes: int, + min_precision: float, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tuple[Tensor, Tensor]: + r""" + Computes the higest possible recall value given the minimum precision thresholds provided. This is done by + first calculating the precision-recall curve for different thresholds and the find the recall for a given + precision level. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + softmax per sample. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds} \times n_{classes})` (constant memory). + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_classes: Integer specifing the number of classes + min_precision: float value specifying minimum precision threshold. + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + (tuple): a tuple of either 2 tensors or 2 lists containing + + - recall: an 1d tensor of size (n_classes, ) with the maximum recall for the given precision level per class + - thresholds: an 1d tensor of size (n_classes, ) with the corresponding threshold level per class + + Example: + >>> from torchmetrics.functional import multiclass_recall_at_fixed_precision + >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], + ... [0.05, 0.75, 0.05, 0.05, 0.05], + ... [0.05, 0.05, 0.75, 0.05, 0.05], + ... 
[0.05, 0.05, 0.05, 0.75, 0.05]])
+        >>> target = torch.tensor([0, 1, 3, 2])
+        >>> multiclass_recall_at_fixed_precision(preds, target, num_classes=5, min_precision=0.5, thresholds=None)
+        (tensor([1., 1., 0., 0., 0.]), tensor([7.5000e-01, 7.5000e-01, 1.0000e+06, 1.0000e+06, 1.0000e+06]))
+        >>> multiclass_recall_at_fixed_precision(preds, target, num_classes=5, min_precision=0.5, thresholds=5)
+        (tensor([1., 1., 0., 0., 0.]), tensor([7.5000e-01, 7.5000e-01, 1.0000e+06, 1.0000e+06, 1.0000e+06]))
+    """
+    if validate_args:
+        _multiclass_recall_at_fixed_precision_arg_validation(num_classes, min_precision, thresholds, ignore_index)
+        _multiclass_precision_recall_curve_tensor_validation(preds, target, num_classes, ignore_index)
+    preds, target, thresholds = _multiclass_precision_recall_curve_format(
+        preds, target, num_classes, thresholds, ignore_index
+    )
+    state = _multiclass_precision_recall_curve_update(preds, target, num_classes, thresholds)
+    return _multiclass_recall_at_fixed_precision_arg_compute(state, num_classes, thresholds, min_precision)
+
+
+def _multilabel_recall_at_fixed_precision_arg_validation(
+    num_labels: int,
+    min_precision: float,
+    thresholds: Optional[Union[int, List[float], Tensor]] = None,
+    ignore_index: Optional[int] = None,
+) -> None:
+    _multilabel_precision_recall_curve_arg_validation(num_labels, thresholds, ignore_index)
+    if not isinstance(min_precision, float) or not (0 <= min_precision <= 1):
+        raise ValueError(
+            f"Expected argument `min_precision` to be a float in the [0,1] range, but got {min_precision}"
+        )
+
+
+def _multilabel_recall_at_fixed_precision_arg_compute(
+    state: Union[Tensor, Tuple[Tensor, Tensor]],
+    num_labels: int,
+    thresholds: Optional[Tensor],
+    ignore_index: Optional[int],
+    min_precision: float,
+) -> Tuple[Tensor, Tensor]:
+    precision, recall, thresholds = _multilabel_precision_recall_curve_compute(
+        state, num_labels, thresholds, ignore_index
+    )
+    if isinstance(state, Tensor):
+        res = [_recall_at_precision(p, r, thresholds, min_precision) for p, r in zip(precision, recall)]
+    else:
+        res = [_recall_at_precision(p, r, t, min_precision) for p, r, t in zip(precision, recall, thresholds)]
+    recall = torch.stack([r[0] for r in res])
+    thresholds = torch.stack([r[1] for r in res])
+    return recall, thresholds
+
+
+def multilabel_recall_at_fixed_precision(
+    preds: Tensor,
+    target: Tensor,
+    num_labels: int,
+    min_precision: float,
+    thresholds: Optional[Union[int, List[float], Tensor]] = None,
+    ignore_index: Optional[int] = None,
+    validate_args: bool = True,
+) -> Tuple[Tensor, Tensor]:
+    r"""
+    Computes the highest possible recall value given the minimum precision threshold provided. This is done by
+    first calculating the precision-recall curve for different thresholds and then finding the recall for a given
+    precision level.
+
+    Accepts the following input tensors:
+
+    - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each
+      observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
+      sigmoid per element.
+    - ``target`` (int tensor): ``(N, C, ...)``. Target should be a tensor containing ground truth labels, and therefore
+      only contain {0,1} values (except if `ignore_index` is specified).
+
+    Additional dimension ``...`` will be flattened into the batch dimension.
+
+    The implementation both supports calculating the metric in a non-binned but accurate version and a binned version
+    that is less accurate but more memory efficient.
Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds} \times n_{labels})` (constant memory). + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_labels: Integer specifing the number of labels + min_precision: float value specifying minimum precision threshold. + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + (tuple): a tuple of either 2 tensors or 2 lists containing + + - recall: an 1d tensor of size (n_classes, ) with the maximum recall for the given precision level per class + - thresholds: an 1d tensor of size (n_classes, ) with the corresponding threshold level per class + + Example: + >>> from torchmetrics.functional import multilabel_recall_at_fixed_precision + >>> preds = torch.tensor([[0.75, 0.05, 0.35], + ... [0.45, 0.75, 0.05], + ... [0.05, 0.55, 0.75], + ... [0.05, 0.65, 0.05]]) + >>> target = torch.tensor([[1, 0, 1], + ... [0, 0, 0], + ... [0, 1, 1], + ... 
[1, 1, 1]]) + >>> multilabel_recall_at_fixed_precision(preds, target, num_labels=3, min_precision=0.5, thresholds=None) + (tensor([1., 1., 1.]), tensor([0.0500, 0.5500, 0.0500])) + >>> multilabel_recall_at_fixed_precision(preds, target, num_labels=3, min_precision=0.5, thresholds=5) + (tensor([1., 1., 1.]), tensor([0.0000, 0.5000, 0.0000])) + """ + if validate_args: + _multilabel_recall_at_fixed_precision_arg_validation(num_labels, min_precision, thresholds, ignore_index) + _multilabel_precision_recall_curve_tensor_validation(preds, target, num_labels, ignore_index) + preds, target, thresholds = _multilabel_precision_recall_curve_format( + preds, target, num_labels, thresholds, ignore_index + ) + state = _multilabel_precision_recall_curve_update(preds, target, num_labels, thresholds) + return _multilabel_recall_at_fixed_precision_arg_compute(state, num_labels, thresholds, ignore_index, min_precision) diff --git a/src/torchmetrics/functional/classification/roc.py b/src/torchmetrics/functional/classification/roc.py index 0e5ce0b58f0..a9e6e1364d0 100644 --- a/src/torchmetrics/functional/classification/roc.py +++ b/src/torchmetrics/functional/classification/roc.py @@ -18,9 +18,412 @@ from torchmetrics.functional.classification.precision_recall_curve import ( _binary_clf_curve, + _binary_precision_recall_curve_arg_validation, + _binary_precision_recall_curve_format, + _binary_precision_recall_curve_tensor_validation, + _binary_precision_recall_curve_update, + _multiclass_precision_recall_curve_arg_validation, + _multiclass_precision_recall_curve_format, + _multiclass_precision_recall_curve_tensor_validation, + _multiclass_precision_recall_curve_update, + _multilabel_precision_recall_curve_arg_validation, + _multilabel_precision_recall_curve_format, + _multilabel_precision_recall_curve_tensor_validation, + _multilabel_precision_recall_curve_update, _precision_recall_curve_update, ) from torchmetrics.utilities import rank_zero_warn +from torchmetrics.utilities.compute import _safe_divide + + +def _binary_roc_compute( + state: Union[Tensor, Tuple[Tensor, Tensor]], + thresholds: Optional[Tensor], + pos_label: int = 1, +) -> Tuple[Tensor, Tensor, Tensor]: + if isinstance(state, Tensor) and thresholds is not None: + tps = state[:, 1, 1] + fps = state[:, 0, 1] + fns = state[:, 1, 0] + tns = state[:, 0, 0] + tpr = _safe_divide(tps, tps + fns).flip(0) + fpr = _safe_divide(fps, fps + tns).flip(0) + thresholds = thresholds.flip(0) + else: + fps, tps, thresholds = _binary_clf_curve(preds=state[0], target=state[1], pos_label=pos_label) + # Add an extra threshold position to make sure that the curve starts at (0, 0) + tps = torch.cat([torch.zeros(1, dtype=tps.dtype, device=tps.device), tps]) + fps = torch.cat([torch.zeros(1, dtype=fps.dtype, device=fps.device), fps]) + thresholds = torch.cat([torch.ones(1, dtype=thresholds.dtype, device=thresholds.device), thresholds]) + + if fps[-1] <= 0: + rank_zero_warn( + "No negative samples in targets, false positive value should be meaningless." + " Returning zero tensor in false positive score", + UserWarning, + ) + fpr = torch.zeros_like(thresholds) + else: + fpr = fps / fps[-1] + + if tps[-1] <= 0: + rank_zero_warn( + "No positive samples in targets, true positive value should be meaningless." 
+ " Returning zero tensor in true positive score", + UserWarning, + ) + tpr = torch.zeros_like(thresholds) + else: + tpr = tps / tps[-1] + + return fpr, tpr, thresholds + + +def binary_roc( + preds: Tensor, + target: Tensor, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tuple[Tensor, Tensor, Tensor]: + r""" + Computes the Receiver Operating Characteristic (ROC) for binary tasks. The curve consist of multiple + pairs of true positive rate (TPR) and false positive rate (FPR) values evaluated at different thresholds, + such that the tradeoff between the two values can be seen. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds})` (constant memory). + + Note that outputted thresholds will be in reversed order to ensure that they corresponds to both fpr and tpr which + are sorted in reversed order during their calculation, such that they are monotome increasing. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. 
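+        ignore_index:
+            Specifies a target value that is ignored and does not contribute to the metric calculation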
+
+    Returns:
+        (tuple): a tuple of 3 tensors containing:
+
+        - fpr: a 1d tensor of size (n_thresholds+1, ) with false positive rate values
+        - tpr: a 1d tensor of size (n_thresholds+1, ) with true positive rate values
+        - thresholds: a 1d tensor of size (n_thresholds, ) with decreasing threshold values
+
+    Example:
+        >>> from torchmetrics.functional import binary_roc
+        >>> preds = torch.tensor([0, 0.5, 0.7, 0.8])
+        >>> target = torch.tensor([0, 1, 1, 0])
+        >>> binary_roc(preds, target, thresholds=None)  # doctest: +NORMALIZE_WHITESPACE
+        (tensor([0.0000, 0.5000, 0.5000, 0.5000, 1.0000]),
+         tensor([0.0000, 0.0000, 0.5000, 1.0000, 1.0000]),
+         tensor([1.0000, 0.8000, 0.7000, 0.5000, 0.0000]))
+        >>> binary_roc(preds, target, thresholds=5)  # doctest: +NORMALIZE_WHITESPACE
+        (tensor([0.0000, 0.5000, 0.5000, 0.5000, 1.0000]),
+         tensor([0., 0., 1., 1., 1.]),
+         tensor([1.0000, 0.7500, 0.5000, 0.2500, 0.0000]))
+    """
+    if validate_args:
+        _binary_precision_recall_curve_arg_validation(thresholds, ignore_index)
+        _binary_precision_recall_curve_tensor_validation(preds, target, ignore_index)
+    preds, target, thresholds = _binary_precision_recall_curve_format(preds, target, thresholds, ignore_index)
+    state = _binary_precision_recall_curve_update(preds, target, thresholds)
+    return _binary_roc_compute(state, thresholds)
+
+
+def _multiclass_roc_compute(
+    state: Union[Tensor, Tuple[Tensor, Tensor]],
+    num_classes: int,
+    thresholds: Optional[Tensor],
+) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
+    if isinstance(state, Tensor) and thresholds is not None:
+        tps = state[:, :, 1, 1]
+        fps = state[:, :, 0, 1]
+        fns = state[:, :, 1, 0]
+        tns = state[:, :, 0, 0]
+        tpr = _safe_divide(tps, tps + fns).flip(0).T
+        fpr = _safe_divide(fps, fps + tns).flip(0).T
+        thresholds = thresholds.flip(0)
+    else:
+        fpr, tpr, thresholds = [], [], []
+        for i in range(num_classes):
+            res = _binary_roc_compute([state[0][:, i], state[1]], thresholds=None, pos_label=i)
+            fpr.append(res[0])
+            tpr.append(res[1])
+            thresholds.append(res[2])
+    return fpr, tpr, thresholds
+
+
+def multiclass_roc(
+    preds: Tensor,
+    target: Tensor,
+    num_classes: int,
+    thresholds: Optional[Union[int, List[float], Tensor]] = None,
+    ignore_index: Optional[int] = None,
+    validate_args: bool = True,
+) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
+    r"""
+    Computes the Receiver Operating Characteristic (ROC) for multiclass tasks. The curve consists of multiple
+    pairs of true positive rate (TPR) and false positive rate (FPR) values evaluated at different thresholds,
+    such that the tradeoff between the two values can be seen.
+
+    Accepts the following input tensors:
+
+    - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each
+      observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
+      softmax per sample.
+    - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore
+      only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified).
+
+    Additional dimension ``...`` will be flattened into the batch dimension.
+
+    The implementation both supports calculating the metric in a non-binned but accurate version and a binned version
+    that is less accurate but more memory efficient.
Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds} \times n_{classes})` (constant memory). + + Note that outputted thresholds will be in reversed order to ensure that they corresponds to both fpr and tpr which + are sorted in reversed order during their calculation, such that they are monotome increasing. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_classes: Integer specifing the number of classes + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + (tuple): a tuple of either 3 tensors or 3 lists containing + + - fpr: if `thresholds=None` a list for each class is returned with an 1d tensor of size (n_thresholds+1, ) + with false positive rate values (length may differ between classes). If `thresholds` is set to something else, + then a single 2d tensor of size (n_classes, n_thresholds+1) with false positive rate values is returned. + - tpr: if `thresholds=None` a list for each class is returned with an 1d tensor of size (n_thresholds+1, ) + with true positive rate values (length may differ between classes). If `thresholds` is set to something else, + then a single 2d tensor of size (n_classes, n_thresholds+1) with true positive rate values is returned. + - thresholds: if `thresholds=None` a list for each class is returned with an 1d tensor of size (n_thresholds, ) + with decreasing threshold values (length may differ between classes). If `threshold` is set to something else, + then a single 1d tensor of size (n_thresholds, ) is returned with shared threshold values for all classes. + + Example: + >>> from torchmetrics.functional import multiclass_roc + >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], + ... [0.05, 0.75, 0.05, 0.05, 0.05], + ... [0.05, 0.05, 0.75, 0.05, 0.05], + ... [0.05, 0.05, 0.05, 0.75, 0.05]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> fpr, tpr, thresholds = multiclass_roc( + ... preds, target, num_classes=5, thresholds=None + ... ) + >>> fpr # doctest: +NORMALIZE_WHITESPACE + [tensor([0., 0., 1.]), tensor([0., 0., 1.]), tensor([0.0000, 0.3333, 1.0000]), + tensor([0.0000, 0.3333, 1.0000]), tensor([0., 1.])] + >>> tpr + [tensor([0., 1., 1.]), tensor([0., 1., 1.]), tensor([0., 0., 1.]), tensor([0., 0., 1.]), tensor([0., 0.])] + >>> thresholds # doctest: +NORMALIZE_WHITESPACE + [tensor([1.0000, 0.7500, 0.0500]), tensor([1.0000, 0.7500, 0.0500]), + tensor([1.0000, 0.7500, 0.0500]), tensor([1.0000, 0.7500, 0.0500]), tensor([1.0000, 0.0500])] + >>> multiclass_roc( + ... preds, target, num_classes=5, thresholds=5 + ... 
)  # doctest: +NORMALIZE_WHITESPACE
+        (tensor([[0.0000, 0.0000, 0.0000, 0.0000, 1.0000],
+                 [0.0000, 0.0000, 0.0000, 0.0000, 1.0000],
+                 [0.0000, 0.3333, 0.3333, 0.3333, 1.0000],
+                 [0.0000, 0.3333, 0.3333, 0.3333, 1.0000],
+                 [0.0000, 0.0000, 0.0000, 0.0000, 1.0000]]),
+         tensor([[0., 1., 1., 1., 1.],
+                 [0., 1., 1., 1., 1.],
+                 [0., 0., 0., 0., 1.],
+                 [0., 0., 0., 0., 1.],
+                 [0., 0., 0., 0., 0.]]),
+         tensor([1.0000, 0.7500, 0.5000, 0.2500, 0.0000]))
+    """
+    if validate_args:
+        _multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index)
+        _multiclass_precision_recall_curve_tensor_validation(preds, target, num_classes, ignore_index)
+    preds, target, thresholds = _multiclass_precision_recall_curve_format(
+        preds, target, num_classes, thresholds, ignore_index
+    )
+    state = _multiclass_precision_recall_curve_update(preds, target, num_classes, thresholds)
+    return _multiclass_roc_compute(state, num_classes, thresholds)
+
+
+def _multilabel_roc_compute(
+    state: Union[Tensor, Tuple[Tensor, Tensor]],
+    num_labels: int,
+    thresholds: Optional[Tensor],
+    ignore_index: Optional[int] = None,
+) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
+    if isinstance(state, Tensor) and thresholds is not None:
+        tps = state[:, :, 1, 1]
+        fps = state[:, :, 0, 1]
+        fns = state[:, :, 1, 0]
+        tns = state[:, :, 0, 0]
+        tpr = _safe_divide(tps, tps + fns).flip(0).T
+        fpr = _safe_divide(fps, fps + tns).flip(0).T
+        thresholds = thresholds.flip(0)
+    else:
+        fpr, tpr, thresholds = [], [], []
+        for i in range(num_labels):
+            preds = state[0][:, i]
+            target = state[1][:, i]
+            if ignore_index is not None:
+                idx = target == ignore_index
+                preds = preds[~idx]
+                target = target[~idx]
+            res = _binary_roc_compute([preds, target], thresholds=None, pos_label=1)
+            fpr.append(res[0])
+            tpr.append(res[1])
+            thresholds.append(res[2])
+    return fpr, tpr, thresholds
+
+
+def multilabel_roc(
+    preds: Tensor,
+    target: Tensor,
+    num_labels: int,
+    thresholds: Optional[Union[int, List[float], Tensor]] = None,
+    ignore_index: Optional[int] = None,
+    validate_args: bool = True,
+) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
+    r"""
+    Computes the Receiver Operating Characteristic (ROC) for multilabel tasks. The curve consists of multiple
+    pairs of true positive rate (TPR) and false positive rate (FPR) values evaluated at different thresholds,
+    such that the tradeoff between the two values can be seen.
+
+    Accepts the following input tensors:
+
+    - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each
+      observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
+      sigmoid per element.
+    - ``target`` (int tensor): ``(N, C, ...)``. Target should be a tensor containing ground truth labels, and therefore
+      only contain {0,1} values (except if `ignore_index` is specified).
+
+    Additional dimension ``...`` will be flattened into the batch dimension.
+
+    The implementation both supports calculating the metric in a non-binned but accurate version and a binned version
+    that is less accurate but more memory efficient.
Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds} \times n_{labels})` (constant memory). + + Note that outputted thresholds will be in reversed order to ensure that they corresponds to both fpr and tpr which + are sorted in reversed order during their calculation, such that they are monotome increasing. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_labels: Integer specifing the number of labels + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + (tuple): a tuple of either 3 tensors or 3 lists containing + + - fpr: if `thresholds=None` a list for each label is returned with an 1d tensor of size (n_thresholds+1, ) + with false positive rate values (length may differ between labels). If `thresholds` is set to something else, + then a single 2d tensor of size (n_labels, n_thresholds+1) with false positive rate values is returned. + - tpr: if `thresholds=None` a list for each label is returned with an 1d tensor of size (n_thresholds+1, ) + with true positive rate values (length may differ between labels). If `thresholds` is set to something else, + then a single 2d tensor of size (n_labels, n_thresholds+1) with true positive rate values is returned. + - thresholds: if `thresholds=None` a list for each label is returned with an 1d tensor of size (n_thresholds, ) + with decreasing threshold values (length may differ between labels). If `threshold` is set to something else, + then a single 1d tensor of size (n_thresholds, ) is returned with shared threshold values for all labels. + + Example: + >>> from torchmetrics.functional import multilabel_roc + >>> preds = torch.tensor([[0.75, 0.05, 0.35], + ... [0.45, 0.75, 0.05], + ... [0.05, 0.55, 0.75], + ... [0.05, 0.65, 0.05]]) + >>> target = torch.tensor([[1, 0, 1], + ... [0, 0, 0], + ... [0, 1, 1], + ... [1, 1, 1]]) + >>> fpr, tpr, thresholds = multilabel_roc( + ... preds, target, num_labels=3, thresholds=None + ... ) + >>> fpr # doctest: +NORMALIZE_WHITESPACE + [tensor([0.0000, 0.0000, 0.5000, 1.0000]), + tensor([0.0000, 0.5000, 0.5000, 0.5000, 1.0000]), + tensor([0., 0., 0., 1.])] + >>> tpr # doctest: +NORMALIZE_WHITESPACE + [tensor([0.0000, 0.5000, 0.5000, 1.0000]), + tensor([0.0000, 0.0000, 0.5000, 1.0000, 1.0000]), + tensor([0.0000, 0.3333, 0.6667, 1.0000])] + >>> thresholds # doctest: +NORMALIZE_WHITESPACE + [tensor([1.0000, 0.7500, 0.4500, 0.0500]), + tensor([1.0000, 0.7500, 0.6500, 0.5500, 0.0500]), + tensor([1.0000, 0.7500, 0.3500, 0.0500])] + >>> multilabel_roc( + ... preds, target, num_labels=3, thresholds=5 + ... 
) # doctest: +NORMALIZE_WHITESPACE + (tensor([[0.0000, 0.0000, 0.0000, 0.5000, 1.0000], + [0.0000, 0.5000, 0.5000, 0.5000, 1.0000], + [0.0000, 0.0000, 0.0000, 0.0000, 1.0000]]), + tensor([[0.0000, 0.5000, 0.5000, 0.5000, 1.0000], + [0.0000, 0.0000, 1.0000, 1.0000, 1.0000], + [0.0000, 0.3333, 0.3333, 0.6667, 1.0000]]), + tensor([1.0000, 0.7500, 0.5000, 0.2500, 0.0000])) + """ + if validate_args: + _multilabel_precision_recall_curve_arg_validation(num_labels, thresholds, ignore_index) + _multilabel_precision_recall_curve_tensor_validation(preds, target, num_labels, ignore_index) + preds, target, thresholds = _multilabel_precision_recall_curve_format( + preds, target, num_labels, thresholds, ignore_index + ) + state = _multilabel_precision_recall_curve_update(preds, target, num_labels, thresholds) + return _multilabel_roc_compute(state, num_labels, thresholds, ignore_index) + + +# -------------------------- Old stuff -------------------------- def _roc_update( diff --git a/tests/unittests/classification/test_auroc.py b/tests/unittests/classification/test_auroc.py index 691150f4639..5b33b649802 100644 --- a/tests/unittests/classification/test_auroc.py +++ b/tests/unittests/classification/test_auroc.py @@ -13,228 +13,561 @@ # limitations under the License. from functools import partial +import numpy as np import pytest import torch +from scipy.special import expit as sigmoid +from scipy.special import softmax from sklearn.metrics import roc_auc_score as sk_roc_auc_score -from torchmetrics.classification.auroc import AUROC -from torchmetrics.functional import auroc -from torchmetrics.utilities.imports import _TORCH_LOWER_1_6 -from unittests.classification.inputs import _input_binary_prob -from unittests.classification.inputs import _input_multiclass_prob as _input_mcls_prob -from unittests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob -from unittests.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob -from unittests.classification.inputs import _input_multilabel_prob as _input_mlb_prob +from torchmetrics.classification.auroc import BinaryAUROC, MulticlassAUROC, MultilabelAUROC +from torchmetrics.functional.classification.auroc import binary_auroc, multiclass_auroc, multilabel_auroc +from torchmetrics.functional.classification.roc import binary_roc +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6, _TORCH_LOWER_1_6 +from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests.helpers import seed_all -from unittests.helpers.testers import NUM_CLASSES, MetricTester +from unittests.helpers.testers import NUM_CLASSES, MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) -def _sk_auroc_binary_prob(preds, target, num_classes, average="macro", max_fpr=None, multi_class="ovr"): - # todo: `multi_class` is unused - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - return sk_roc_auc_score(y_true=sk_target, y_score=sk_preds, average=average, max_fpr=max_fpr) - - -def _sk_auroc_multiclass_prob(preds, target, num_classes, average="macro", max_fpr=None, multi_class="ovr"): - sk_preds = preds.reshape(-1, num_classes).numpy() - sk_target = target.view(-1).numpy() - return sk_roc_auc_score( - y_true=sk_target, - y_score=sk_preds, - average=average, - max_fpr=max_fpr, - multi_class=multi_class, - ) - - -def _sk_auroc_multidim_multiclass_prob(preds, target, num_classes, average="macro", max_fpr=None, multi_class="ovr"): - sk_preds = 
preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() - sk_target = target.view(-1).numpy() - return sk_roc_auc_score( - y_true=sk_target, - y_score=sk_preds, - average=average, - max_fpr=max_fpr, - multi_class=multi_class, - ) - - -def _sk_auroc_multilabel_prob(preds, target, num_classes, average="macro", max_fpr=None, multi_class="ovr"): - sk_preds = preds.reshape(-1, num_classes).numpy() - sk_target = target.reshape(-1, num_classes).numpy() - return sk_roc_auc_score( - y_true=sk_target, - y_score=sk_preds, - average=average, - max_fpr=max_fpr, - multi_class=multi_class, - ) - - -def _sk_auroc_multilabel_multidim_prob(preds, target, num_classes, average="macro", max_fpr=None, multi_class="ovr"): - sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() - sk_target = target.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() - return sk_roc_auc_score( - y_true=sk_target, - y_score=sk_preds, - average=average, - max_fpr=max_fpr, - multi_class=multi_class, - ) - - -@pytest.mark.parametrize("max_fpr", [None, 0.8, 0.5]) -@pytest.mark.parametrize( - "preds, target, sk_metric, num_classes", - [ - (_input_binary_prob.preds, _input_binary_prob.target, _sk_auroc_binary_prob, 1), - (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_auroc_multiclass_prob, NUM_CLASSES), - (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_auroc_multidim_multiclass_prob, NUM_CLASSES), - (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_auroc_multilabel_prob, NUM_CLASSES), - (_input_mlmd_prob.preds, _input_mlmd_prob.target, _sk_auroc_multilabel_multidim_prob, NUM_CLASSES), - ], -) -class TestAUROC(MetricTester): - @pytest.mark.parametrize("average", ["macro", "weighted", "micro"]) +def _sk_auroc_binary(preds, target, max_fpr=None, ignore_index=None): + preds = preds.flatten().numpy() + target = target.flatten().numpy() + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + target, preds = remove_ignore_index(target, preds, ignore_index) + return sk_roc_auc_score(target, preds, max_fpr=max_fpr) + + +@pytest.mark.parametrize("input", (_binary_cases[1], _binary_cases[2], _binary_cases[4], _binary_cases[5])) +class TestBinaryAUROC(MetricTester): + @pytest.mark.parametrize("max_fpr", [None, 0.8, 0.5]) + @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_auroc(self, preds, target, sk_metric, num_classes, average, max_fpr, ddp, dist_sync_on_step): - # max_fpr different from None is not support in multi class - if max_fpr is not None and num_classes != 1: - pytest.skip("max_fpr parameter not support for multi class or multi label") + def test_binary_auroc(self, input, ddp, max_fpr, ignore_index): + if max_fpr is not None and _TORCH_LOWER_1_6: + pytest.skip("requires torch v1.6 or higher to test max_fpr argument") + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=BinaryAUROC, + sk_metric=partial(_sk_auroc_binary, max_fpr=max_fpr, ignore_index=ignore_index), + metric_args={ + "max_fpr": max_fpr, + "thresholds": None, + "ignore_index": ignore_index, + }, + ) - # max_fpr only supported for torch v1.6 or higher + @pytest.mark.parametrize("max_fpr", [None, 0.8, 0.5]) + @pytest.mark.parametrize("ignore_index", [None, -1]) + def test_binary_auroc_functional(self, input, max_fpr, 
ignore_index): if max_fpr is not None and _TORCH_LOWER_1_6: pytest.skip("requires torch v1.6 or higher to test max_fpr argument") + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=binary_auroc, + sk_metric=partial(_sk_auroc_binary, max_fpr=max_fpr, ignore_index=ignore_index), + metric_args={ + "max_fpr": max_fpr, + "thresholds": None, + "ignore_index": ignore_index, + }, + ) - # average='micro' only supported for multilabel - if average == "micro" and preds.ndim > 2 and preds.ndim == target.ndim + 1: - pytest.skip("micro argument only support for multilabel input") + def test_binary_auroc_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=BinaryAUROC, + metric_functional=binary_auroc, + metric_args={"thresholds": None}, + ) + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_auroc_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=BinaryAUROC, + metric_functional=binary_auroc, + metric_args={"thresholds": None}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_auroc_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=BinaryAUROC, + metric_functional=binary_auroc, + metric_args={"thresholds": None}, + dtype=dtype, + ) + + @pytest.mark.parametrize("threshold_fn", [lambda x: x, lambda x: x.numpy().tolist()], ids=["as tensor", "as list"]) + def test_binary_auroc_threshold_arg(self, input, threshold_fn): + preds, target = input + + for pred, true in zip(preds, target): + _, _, t = binary_roc(pred, true, thresholds=None) + ap1 = binary_auroc(pred, true, thresholds=None) + ap2 = binary_auroc(pred, true, thresholds=threshold_fn(t.flip(0))) + assert torch.allclose(ap1, ap2) + + +def _sk_auroc_multiclass(preds, target, average="macro", ignore_index=None): + preds = np.moveaxis(preds.numpy(), 1, -1).reshape((-1, preds.shape[1])) + target = target.numpy().flatten() + if not ((0 < preds) & (preds < 1)).all(): + preds = softmax(preds, 1) + target, preds = remove_ignore_index(target, preds, ignore_index) + return sk_roc_auc_score(target, preds, average=average, multi_class="ovr", labels=list(range(NUM_CLASSES))) + + +@pytest.mark.parametrize( + "input", (_multiclass_cases[1], _multiclass_cases[2], _multiclass_cases[4], _multiclass_cases[5]) +) +class TestMulticlassAUROC(MetricTester): + @pytest.mark.parametrize("average", ["macro", "weighted"]) + @pytest.mark.parametrize("ignore_index", [None, -1]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_multiclass_auroc(self, input, average, ddp, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) self.run_class_metric_test( ddp=ddp, preds=preds, target=target, - metric_class=AUROC, - sk_metric=partial(sk_metric, num_classes=num_classes, average=average, 
max_fpr=max_fpr), - dist_sync_on_step=dist_sync_on_step, - metric_args={"num_classes": num_classes, "average": average, "max_fpr": max_fpr}, + metric_class=MulticlassAUROC, + sk_metric=partial(_sk_auroc_multiclass, average=average, ignore_index=ignore_index), + metric_args={ + "thresholds": None, + "num_classes": NUM_CLASSES, + "average": average, + "ignore_index": ignore_index, + }, ) - @pytest.mark.parametrize("average", ["macro", "weighted", "micro"]) - def test_auroc_functional(self, preds, target, sk_metric, num_classes, average, max_fpr): - # max_fpr different from None is not support in multi class - if max_fpr is not None and num_classes != 1: - pytest.skip("max_fpr parameter not support for multi class or multi label") + @pytest.mark.parametrize("average", ["macro", "weighted"]) + @pytest.mark.parametrize("ignore_index", [None, -1]) + def test_multiclass_auroc_functional(self, input, average, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=multiclass_auroc, + sk_metric=partial(_sk_auroc_multiclass, average=average, ignore_index=ignore_index), + metric_args={ + "thresholds": None, + "num_classes": NUM_CLASSES, + "average": average, + "ignore_index": ignore_index, + }, + ) - # max_fpr only supported for torch v1.6 or higher - if max_fpr is not None and _TORCH_LOWER_1_6: - pytest.skip("requires torch v1.6 or higher to test max_fpr argument") + def test_multiclass_auroc_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=MulticlassAUROC, + metric_functional=multiclass_auroc, + metric_args={"thresholds": None, "num_classes": NUM_CLASSES}, + ) - # average='micro' only supported for multilabel - if average == "micro" and preds.ndim > 2 and preds.ndim == target.ndim + 1: - pytest.skip("micro argument only support for multilabel input") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_auroc_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if dtype == torch.half and not ((0 < preds) & (preds < 1)).all(): + pytest.xfail(reason="half support for torch.softmax on cpu not implemented") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MulticlassAUROC, + metric_functional=multiclass_auroc, + metric_args={"thresholds": None, "num_classes": NUM_CLASSES}, + dtype=dtype, + ) - self.run_functional_metric_test( - preds, - target, - metric_functional=auroc, - sk_metric=partial(sk_metric, num_classes=num_classes, average=average, max_fpr=max_fpr), - metric_args={"num_classes": num_classes, "average": average, "max_fpr": max_fpr}, + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_auroc_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MulticlassAUROC, + metric_functional=multiclass_auroc, + metric_args={"thresholds": None, "num_classes": NUM_CLASSES}, + dtype=dtype, ) - def test_auroc_differentiability(self, preds, target, sk_metric, num_classes, max_fpr): - # max_fpr different from None is not support in multi class - if max_fpr is not None and 
num_classes != 1: - pytest.skip("max_fpr parameter not support for multi class or multi label") + @pytest.mark.parametrize("average", ["macro", "weighted", None]) + def test_multiclass_auroc_threshold_arg(self, input, average): + preds, target = input + if (preds < 0).any(): + preds = preds.softmax(dim=-1) + for pred, true in zip(preds, target): + pred = torch.tensor(np.round(pred.numpy(), 2)) + 1e-6 # rounding will simulate binning + ap1 = multiclass_auroc(pred, true, num_classes=NUM_CLASSES, average=average, thresholds=None) + ap2 = multiclass_auroc( + pred, true, num_classes=NUM_CLASSES, average=average, thresholds=torch.linspace(0, 1, 100) + ) + assert torch.allclose(ap1, ap2) + + +def _sk_auroc_multilabel(preds, target, average="macro", ignore_index=None): + if ignore_index is None: + if preds.ndim > 2: + target = target.transpose(2, 1).reshape(-1, NUM_CLASSES) + preds = preds.transpose(2, 1).reshape(-1, NUM_CLASSES) + target = target.numpy() + preds = preds.numpy() + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + return sk_roc_auc_score(target, preds, average=average, max_fpr=None) + if average == "micro": + return _sk_auroc_binary(preds.flatten(), target.flatten(), max_fpr=None, ignore_index=ignore_index) + res = [] + for i in range(NUM_CLASSES): + res.append(_sk_auroc_binary(preds[:, i], target[:, i], max_fpr=None, ignore_index=ignore_index)) + if average == "macro": + return np.array(res)[~np.isnan(res)].mean() + if average == "weighted": + weights = ((target == 1).sum([0, 2]) if target.ndim == 3 else (target == 1).sum(0)).numpy() + weights = weights / sum(weights) + return (np.array(res) * weights)[~np.isnan(res)].sum() + return res - # max_fpr only supported for torch v1.6 or higher - if max_fpr is not None and _TORCH_LOWER_1_6: - pytest.skip("requires torch v1.6 or higher to test max_fpr argument") +@pytest.mark.parametrize( + "input", (_multilabel_cases[1], _multilabel_cases[2], _multilabel_cases[4], _multilabel_cases[5]) +) +class TestMultilabelAUROC(MetricTester): + @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) + @pytest.mark.parametrize("ignore_index", [None, -1]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_multilabel_auroc(self, input, ddp, average, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MultilabelAUROC, + sk_metric=partial(_sk_auroc_multilabel, average=average, ignore_index=ignore_index), + metric_args={ + "thresholds": None, + "num_labels": NUM_CLASSES, + "average": average, + "ignore_index": ignore_index, + }, + ) + + @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) + @pytest.mark.parametrize("ignore_index", [None, -1]) + def test_multilabel_auroc_functional(self, input, average, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=multilabel_auroc, + sk_metric=partial(_sk_auroc_multilabel, average=average, ignore_index=ignore_index), + metric_args={ + "thresholds": None, + "num_labels": NUM_CLASSES, + "average": average, + "ignore_index": ignore_index, + }, + ) + + def test_multiclass_auroc_differentiability(self, input): + preds, target = input self.run_differentiability_test( preds=preds, target=target, - metric_module=AUROC, - metric_functional=auroc, - 
metric_args={"num_classes": num_classes, "max_fpr": max_fpr}, + metric_module=MultilabelAUROC, + metric_functional=multilabel_auroc, + metric_args={"thresholds": None, "num_labels": NUM_CLASSES}, ) + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_auroc_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not supported before pytorch v1.6") + if dtype == torch.half and not ((0 < preds) & (preds < 1)).all(): + pytest.xfail(reason="half support for torch.softmax on cpu not implemented") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MultilabelAUROC, + metric_functional=multilabel_auroc, + metric_args={"thresholds": None, "num_labels": NUM_CLASSES}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_auroc_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MultilabelAUROC, + metric_functional=multilabel_auroc, + metric_args={"thresholds": None, "num_labels": NUM_CLASSES}, + dtype=dtype, + ) -def test_error_on_different_mode(): - """test that an error is raised if the user pass in data of different modes (binary, multi-label, multi- - class)""" - metric = AUROC() - # pass in multi-class data - metric.update(torch.randn(10, 5).softmax(dim=-1), torch.randint(0, 5, (10,))) - with pytest.raises(ValueError, match=r"The mode of data.* should be constant.*"): - # pass in multi-label data - metric.update(torch.rand(10, 5), torch.randint(0, 2, (10, 5))) - - -def test_error_multiclass_no_num_classes(): - with pytest.raises( - ValueError, match="Detected input to `multiclass` but you did not provide `num_classes` argument" - ): - _ = auroc(torch.randn(20, 3).softmax(dim=-1), torch.randint(3, (20,))) - - -@pytest.mark.parametrize("device", ["cpu", "cuda"]) -def test_weighted_with_empty_classes(device): - """Tests that weighted multiclass AUROC calculation yields the same results if a new but empty class exists.
- - Tests that the proper warnings and errors are raised - """ - if not torch.cuda.is_available() and device == "cuda": - pytest.skip("Test requires gpu to run") - - preds = torch.tensor( - [ - [0.90, 0.05, 0.05], - [0.05, 0.90, 0.05], - [0.05, 0.05, 0.90], - [0.85, 0.05, 0.10], - [0.10, 0.10, 0.80], - ] - ).to(device) - target = torch.tensor([0, 1, 1, 2, 2]).to(device) - num_classes = 3 - _auroc = auroc(preds, target, average="weighted", num_classes=num_classes) - - # Add in a class with zero observations at second to last index - preds = torch.cat( - (preds[:, : num_classes - 1], torch.rand_like(preds[:, 0:1]), preds[:, num_classes - 1 :]), axis=1 - ) - # Last class (2) gets moved to 3 - target[target == num_classes - 1] = num_classes - with pytest.warns(UserWarning, match="Class 2 had 0 observations, omitted from AUROC calculation"): - _auroc_empty_class = auroc(preds, target, average="weighted", num_classes=num_classes + 1) - assert _auroc == _auroc_empty_class - - target = torch.zeros_like(target) - with pytest.raises(ValueError, match="Found 1 non-empty class in `multiclass` AUROC calculation"): - _ = auroc(preds, target, average="weighted", num_classes=num_classes + 1) - - -def test_warnings_on_missing_class(): - """Test that a warning is given if either the positive or negative class is missing.""" - metric = AUROC() - # no positive samples - warning = ( - "No positive samples in targets, true positive value should be meaningless." - " Returning zero tensor in true positive score" - ) - with pytest.warns(UserWarning, match=warning): - score = metric(torch.randn(10).sigmoid(), torch.zeros(10).int()) - assert score == 0 - - warning = ( - "No negative samples in targets, false positive value should be meaningless." - " Returning zero tensor in false positive score" - ) - with pytest.warns(UserWarning, match=warning): - score = metric(torch.randn(10).sigmoid(), torch.ones(10).int()) - assert score == 0 + @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) + def test_multilabel_auroc_threshold_arg(self, input, average): + preds, target = input + if (preds < 0).any(): + preds = sigmoid(preds) + for pred, true in zip(preds, target): + pred = torch.tensor(np.round(pred.numpy(), 1)) + 1e-6 # rounding will simulate binning + ap1 = multilabel_auroc(pred, true, num_labels=NUM_CLASSES, average=average, thresholds=None) + ap2 = multilabel_auroc( + pred, true, num_labels=NUM_CLASSES, average=average, thresholds=torch.linspace(0, 1, 100) + ) + assert torch.allclose(ap1, ap2) + + +# -------------------------- Old stuff -------------------------- + + +# def _sk_auroc_binary_prob(preds, target, num_classes, average="macro", max_fpr=None, multi_class="ovr"): +# # todo: `multi_class` is unused +# sk_preds = preds.view(-1).numpy() +# sk_target = target.view(-1).numpy() +# return sk_roc_auc_score(y_true=sk_target, y_score=sk_preds, average=average, max_fpr=max_fpr) + + +# def _sk_auroc_multiclass_prob(preds, target, num_classes, average="macro", max_fpr=None, multi_class="ovr"): +# sk_preds = preds.reshape(-1, num_classes).numpy() +# sk_target = target.view(-1).numpy() +# return sk_roc_auc_score( +# y_true=sk_target, +# y_score=sk_preds, +# average=average, +# max_fpr=max_fpr, +# multi_class=multi_class, +# ) + + +# def _sk_auroc_multidim_multiclass_prob(preds, target, num_classes, average="macro", max_fpr=None, multi_class="ovr"): +# sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() +# sk_target = target.view(-1).numpy() +# return sk_roc_auc_score( 
+# y_true=sk_target, +# y_score=sk_preds, +# average=average, +# max_fpr=max_fpr, +# multi_class=multi_class, +# ) + + +# def _sk_auroc_multilabel_prob(preds, target, num_classes, average="macro", max_fpr=None, multi_class="ovr"): +# sk_preds = preds.reshape(-1, num_classes).numpy() +# sk_target = target.reshape(-1, num_classes).numpy() +# return sk_roc_auc_score( +# y_true=sk_target, +# y_score=sk_preds, +# average=average, +# max_fpr=max_fpr, +# multi_class=multi_class, +# ) + + +# def _sk_auroc_multilabel_multidim_prob(preds, target, num_classes, average="macro", max_fpr=None, multi_class="ovr"): +# sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() +# sk_target = target.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() +# return sk_roc_auc_score( +# y_true=sk_target, +# y_score=sk_preds, +# average=average, +# max_fpr=max_fpr, +# multi_class=multi_class, +# ) + + +# @pytest.mark.parametrize("max_fpr", [None, 0.8, 0.5]) +# @pytest.mark.parametrize( +# "preds, target, sk_metric, num_classes", +# [ +# (_input_binary_prob.preds, _input_binary_prob.target, _sk_auroc_binary_prob, 1), +# (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_auroc_multiclass_prob, NUM_CLASSES), +# (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_auroc_multidim_multiclass_prob, NUM_CLASSES), +# (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_auroc_multilabel_prob, NUM_CLASSES), +# (_input_mlmd_prob.preds, _input_mlmd_prob.target, _sk_auroc_multilabel_multidim_prob, NUM_CLASSES), +# ], +# ) +# class TestAUROC(MetricTester): +# @pytest.mark.parametrize("average", ["macro", "weighted", "micro"]) +# @pytest.mark.parametrize("ddp", [True, False]) +# @pytest.mark.parametrize("dist_sync_on_step", [True, False]) +# def test_auroc(self, preds, target, sk_metric, num_classes, average, max_fpr, ddp, dist_sync_on_step): +# # max_fpr different from None is not support in multi class +# if max_fpr is not None and num_classes != 1: +# pytest.skip("max_fpr parameter not support for multi class or multi label") + +# # max_fpr only supported for torch v1.6 or higher +# if max_fpr is not None and _TORCH_LOWER_1_6: +# pytest.skip("requires torch v1.6 or higher to test max_fpr argument") + +# # average='micro' only supported for multilabel +# if average == "micro" and preds.ndim > 2 and preds.ndim == target.ndim + 1: +# pytest.skip("micro argument only support for multilabel input") + +# self.run_class_metric_test( +# ddp=ddp, +# preds=preds, +# target=target, +# metric_class=AUROC, +# sk_metric=partial(sk_metric, num_classes=num_classes, average=average, max_fpr=max_fpr), +# dist_sync_on_step=dist_sync_on_step, +# metric_args={"num_classes": num_classes, "average": average, "max_fpr": max_fpr}, +# ) + +# @pytest.mark.parametrize("average", ["macro", "weighted", "micro"]) +# def test_auroc_functional(self, preds, target, sk_metric, num_classes, average, max_fpr): +# # max_fpr different from None is not support in multi class +# if max_fpr is not None and num_classes != 1: +# pytest.skip("max_fpr parameter not support for multi class or multi label") + +# # max_fpr only supported for torch v1.6 or higher +# if max_fpr is not None and _TORCH_LOWER_1_6: +# pytest.skip("requires torch v1.6 or higher to test max_fpr argument") + +# # average='micro' only supported for multilabel +# if average == "micro" and preds.ndim > 2 and preds.ndim == target.ndim + 1: +# pytest.skip("micro argument only support for multilabel input") + +# self.run_functional_metric_test( +# preds, +# target, +# 
metric_functional=auroc, +# sk_metric=partial(sk_metric, num_classes=num_classes, average=average, max_fpr=max_fpr), +# metric_args={"num_classes": num_classes, "average": average, "max_fpr": max_fpr}, +# ) + +# def test_auroc_differentiability(self, preds, target, sk_metric, num_classes, max_fpr): +# # max_fpr different from None is not support in multi class +# if max_fpr is not None and num_classes != 1: +# pytest.skip("max_fpr parameter not support for multi class or multi label") + +# # max_fpr only supported for torch v1.6 or higher +# if max_fpr is not None and _TORCH_LOWER_1_6: +# pytest.skip("requires torch v1.6 or higher to test max_fpr argument") + +# self.run_differentiability_test( +# preds=preds, +# target=target, +# metric_module=AUROC, +# metric_functional=auroc, +# metric_args={"num_classes": num_classes, "max_fpr": max_fpr}, +# ) + + +# def test_error_on_different_mode(): +# """test that an error is raised if the user pass in data of different modes (binary, multi-label, multi- +# class)""" +# metric = AUROC() +# # pass in multi-class data +# metric.update(torch.randn(10, 5).softmax(dim=-1), torch.randint(0, 5, (10,))) +# with pytest.raises(ValueError, match=r"The mode of data.* should be constant.*"): +# # pass in multi-label data +# metric.update(torch.rand(10, 5), torch.randint(0, 2, (10, 5))) + + +# def test_error_multiclass_no_num_classes(): +# with pytest.raises( +# ValueError, match="Detected input to `multiclass` but you did not provide `num_classes` argument" +# ): +# _ = auroc(torch.randn(20, 3).softmax(dim=-1), torch.randint(3, (20,))) + + +# @pytest.mark.parametrize("device", ["cpu", "cuda"]) +# def test_weighted_with_empty_classes(device): +# """Tests that weighted multiclass AUROC calculation yields the same results if a new but empty class exists. + +# Tests that the proper warnings and errors are raised +# """ +# if not torch.cuda.is_available() and device == "cuda": +# pytest.skip("Test requires gpu to run") + +# preds = torch.tensor( +# [ +# [0.90, 0.05, 0.05], +# [0.05, 0.90, 0.05], +# [0.05, 0.05, 0.90], +# [0.85, 0.05, 0.10], +# [0.10, 0.10, 0.80], +# ] +# ).to(device) +# target = torch.tensor([0, 1, 1, 2, 2]).to(device) +# num_classes = 3 +# _auroc = auroc(preds, target, average="weighted", num_classes=num_classes) + +# # Add in a class with zero observations at second to last index +# preds = torch.cat( +# (preds[:, : num_classes - 1], torch.rand_like(preds[:, 0:1]), preds[:, num_classes - 1 :]), axis=1 +# ) +# # Last class (2) gets moved to 3 +# target[target == num_classes - 1] = num_classes +# with pytest.warns(UserWarning, match="Class 2 had 0 observations, omitted from AUROC calculation"): +# _auroc_empty_class = auroc(preds, target, average="weighted", num_classes=num_classes + 1) +# assert _auroc == _auroc_empty_class + +# target = torch.zeros_like(target) +# with pytest.raises(ValueError, match="Found 1 non-empty class in `multiclass` AUROC calculation"): +# _ = auroc(preds, target, average="weighted", num_classes=num_classes + 1) + + +# def test_warnings_on_missing_class(): +# """Test that a warning is given if either the positive or negative class is missing.""" +# metric = AUROC() +# # no positive samples +# warning = ( +# "No positive samples in targets, true positive value should be meaningless." 
+# " Returning zero tensor in true positive score" +# ) +# with pytest.warns(UserWarning, match=warning): +# score = metric(torch.randn(10).sigmoid(), torch.zeros(10).int()) +# assert score == 0 + +# warning = ( +# "No negative samples in targets, false positive value should be meaningless." +# " Returning zero tensor in false positive score" +# ) +# with pytest.warns(UserWarning, match=warning): +# score = metric(torch.randn(10).sigmoid(), torch.ones(10).int()) +# assert score == 0 diff --git a/tests/unittests/classification/test_average_precision.py b/tests/unittests/classification/test_average_precision.py index 7938486b2c0..bcc3d6d43d6 100644 --- a/tests/unittests/classification/test_average_precision.py +++ b/tests/unittests/classification/test_average_precision.py @@ -15,157 +15,501 @@ import numpy as np import pytest +import torch +from scipy.special import expit as sigmoid +from scipy.special import softmax from sklearn.metrics import average_precision_score as sk_average_precision_score -from torch import tensor - -from torchmetrics.classification.avg_precision import AveragePrecision -from torchmetrics.functional import average_precision -from unittests.classification.inputs import _input_binary_prob -from unittests.classification.inputs import _input_multiclass_prob as _input_mcls_prob -from unittests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob -from unittests.classification.inputs import _input_multilabel + +from torchmetrics.classification.average_precision import ( + BinaryAveragePrecision, + MulticlassAveragePrecision, + MultilabelAveragePrecision, +) +from torchmetrics.functional.classification.average_precision import ( + binary_average_precision, + multiclass_average_precision, + multilabel_average_precision, +) +from torchmetrics.functional.classification.precision_recall_curve import binary_precision_recall_curve +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_8 +from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests.helpers import seed_all -from unittests.helpers.testers import NUM_CLASSES, MetricTester +from unittests.helpers.testers import NUM_CLASSES, MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) -def _sk_average_precision_score(y_true, probas_pred, num_classes=1, average=None): - if num_classes == 1: - return sk_average_precision_score(y_true, probas_pred) +def _sk_average_precision_binary(preds, target, ignore_index=None): + preds = preds.flatten().numpy() + target = target.flatten().numpy() + if np.issubdtype(preds.dtype, np.floating): + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + target, preds = remove_ignore_index(target, preds, ignore_index) + return sk_average_precision_score(target, preds) - res = [] - for i in range(num_classes): - y_true_temp = np.zeros_like(y_true) - y_true_temp[y_true == i] = 1 - res.append(sk_average_precision_score(y_true_temp, probas_pred[:, i])) +@pytest.mark.parametrize("input", (_binary_cases[1], _binary_cases[2], _binary_cases[4], _binary_cases[5])) +class TestBinaryAveragePrecision(MetricTester): + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_binary_average_precision(self, input, ddp, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + 
metric_class=BinaryAveragePrecision, + sk_metric=partial(_sk_average_precision_binary, ignore_index=ignore_index), + metric_args={ + "thresholds": None, + "ignore_index": ignore_index, + }, + ) + + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + def test_binary_average_precision_functional(self, input, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=binary_average_precision, + sk_metric=partial(_sk_average_precision_binary, ignore_index=ignore_index), + metric_args={ + "thresholds": None, + "ignore_index": ignore_index, + }, + ) + + def test_binary_average_precision_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=BinaryAveragePrecision, + metric_functional=binary_average_precision, + metric_args={"thresholds": None}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_average_precision_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_8: + pytest.xfail(reason="torch.flip not support before pytorch v1.8 for cpu + half precision") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=BinaryAveragePrecision, + metric_functional=binary_average_precision, + metric_args={"thresholds": None}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_average_precision_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=BinaryAveragePrecision, + metric_functional=binary_average_precision, + metric_args={"thresholds": None}, + dtype=dtype, + ) + + @pytest.mark.parametrize("threshold_fn", [lambda x: x, lambda x: x.numpy().tolist()], ids=["as tensor", "as list"]) + def test_binary_average_precision_threshold_arg(self, input, threshold_fn): + preds, target = input + + for pred, true in zip(preds, target): + _, _, t = binary_precision_recall_curve(pred, true, thresholds=None) + ap1 = binary_average_precision(pred, true, thresholds=None) + ap2 = binary_average_precision(pred, true, thresholds=threshold_fn(t)) + assert torch.allclose(ap1, ap2) + + +def _sk_average_precision_multiclass(preds, target, average="macro", ignore_index=None): + preds = np.moveaxis(preds.numpy(), 1, -1).reshape((-1, preds.shape[1])) + target = target.numpy().flatten() + if not ((0 < preds) & (preds < 1)).all(): + preds = softmax(preds, 1) + target, preds = remove_ignore_index(target, preds, ignore_index) + + res = [] + for i in range(NUM_CLASSES): + y_true_temp = np.zeros_like(target) + y_true_temp[target == i] = 1 + res.append(sk_average_precision_score(y_true_temp, preds[:, i])) if average == "macro": - return np.array(res).mean() + return np.array(res)[~np.isnan(res)].mean() if average == "weighted": - weights = np.bincount(y_true) if y_true.max() > 1 else y_true.sum(axis=0) + weights = np.bincount(target) weights = weights / sum(weights) - return (np.array(res) * weights).sum() - + return (np.array(res) * weights)[~np.isnan(res)].sum() return res -def _sk_avg_prec_binary_prob(preds, target, num_classes=1, 
average=None): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - - return _sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes, average=average) +@pytest.mark.parametrize( + "input", (_multiclass_cases[1], _multiclass_cases[2], _multiclass_cases[4], _multiclass_cases[5]) +) +class TestMulticlassAveragePrecision(MetricTester): + @pytest.mark.parametrize("average", ["macro", "weighted", None]) + @pytest.mark.parametrize("ignore_index", [None, -1]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_multiclass_average_precision(self, input, average, ddp, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MulticlassAveragePrecision, + sk_metric=partial(_sk_average_precision_multiclass, average=average, ignore_index=ignore_index), + metric_args={ + "thresholds": None, + "num_classes": NUM_CLASSES, + "average": average, + "ignore_index": ignore_index, + }, + ) + @pytest.mark.parametrize("average", ["macro", "weighted", None]) + @pytest.mark.parametrize("ignore_index", [None, -1]) + def test_multiclass_average_precision_functional(self, input, average, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=multiclass_average_precision, + sk_metric=partial(_sk_average_precision_multiclass, average=average, ignore_index=ignore_index), + metric_args={ + "thresholds": None, + "num_classes": NUM_CLASSES, + "average": average, + "ignore_index": ignore_index, + }, + ) -def _sk_avg_prec_multiclass_prob(preds, target, num_classes=1, average=None): - sk_preds = preds.reshape(-1, num_classes).numpy() - sk_target = target.view(-1).numpy() + def test_multiclass_average_precision_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=MulticlassAveragePrecision, + metric_functional=multiclass_average_precision, + metric_args={"thresholds": None, "num_classes": NUM_CLASSES}, + ) - return _sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes, average=average) + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_average_precision_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_8: + pytest.xfail(reason="torch.flip not support before pytorch v1.8 for cpu + half precision") + if dtype == torch.half and not ((0 < preds) & (preds < 1)).all(): + pytest.xfail(reason="half support for torch.softmax on cpu not implemented") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MulticlassAveragePrecision, + metric_functional=multiclass_average_precision, + metric_args={"thresholds": None, "num_classes": NUM_CLASSES}, + dtype=dtype, + ) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_average_precision_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MulticlassAveragePrecision, + metric_functional=multiclass_average_precision, + metric_args={"thresholds": None, "num_classes": NUM_CLASSES}, + dtype=dtype, + 
) -def _sk_avg_prec_multilabel_prob(preds, target, num_classes=1, average=None): - sk_preds = preds.reshape(-1, num_classes).numpy() - sk_target = target.view(-1, num_classes).numpy() - return sk_average_precision_score(sk_target, sk_preds, average=average) + @pytest.mark.parametrize("average", ["macro", "weighted", None]) + def test_multiclass_average_precision_threshold_arg(self, input, average): + preds, target = input + if (preds < 0).any(): + preds = preds.softmax(dim=-1) + for pred, true in zip(preds, target): + pred = torch.tensor(np.round(pred.numpy(), 2)) + 1e-6 # rounding will simulate binning + ap1 = multiclass_average_precision(pred, true, num_classes=NUM_CLASSES, average=average, thresholds=None) + ap2 = multiclass_average_precision( + pred, true, num_classes=NUM_CLASSES, average=average, thresholds=torch.linspace(0, 1, 100) + ) + assert torch.allclose(ap1, ap2) -def _sk_avg_prec_multidim_multiclass_prob(preds, target, num_classes=1, average=None): - sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() - sk_target = target.view(-1).numpy() - return _sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes, average=average) +def _sk_average_precision_multilabel(preds, target, average="macro", ignore_index=None): + if average == "micro": + return _sk_average_precision_binary(preds.flatten(), target.flatten(), ignore_index) + res = [] + for i in range(NUM_CLASSES): + res.append(_sk_average_precision_binary(preds[:, i], target[:, i], ignore_index)) + if average == "macro": + return np.array(res)[~np.isnan(res)].mean() + if average == "weighted": + weights = ((target == 1).sum([0, 2]) if target.ndim == 3 else (target == 1).sum(0)).numpy() + weights = weights / sum(weights) + return (np.array(res) * weights)[~np.isnan(res)].sum() + return res @pytest.mark.parametrize( - "preds, target, sk_metric, num_classes", - [ - (_input_binary_prob.preds, _input_binary_prob.target, _sk_avg_prec_binary_prob, 1), - (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_avg_prec_multiclass_prob, NUM_CLASSES), - (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_avg_prec_multidim_multiclass_prob, NUM_CLASSES), - (_input_multilabel.preds, _input_multilabel.target, _sk_avg_prec_multilabel_prob, NUM_CLASSES), - ], + "input", (_multilabel_cases[1], _multilabel_cases[2], _multilabel_cases[4], _multilabel_cases[5]) ) -class TestAveragePrecision(MetricTester): +class TestMultilabelAveragePrecision(MetricTester): @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) + @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_average_precision(self, preds, target, sk_metric, num_classes, average, ddp, dist_sync_on_step): - if target.max() > 1 and average == "micro": - pytest.skip("average=micro and multiclass input cannot be used together") - + def test_multilabel_average_precision(self, input, ddp, average, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) self.run_class_metric_test( ddp=ddp, preds=preds, target=target, - metric_class=AveragePrecision, - sk_metric=partial(sk_metric, num_classes=num_classes, average=average), - dist_sync_on_step=dist_sync_on_step, - metric_args={"num_classes": num_classes, "average": average}, + metric_class=MultilabelAveragePrecision, + sk_metric=partial(_sk_average_precision_multilabel, average=average, 
ignore_index=ignore_index), + metric_args={ + "thresholds": None, + "num_labels": NUM_CLASSES, + "average": average, + "ignore_index": ignore_index, + }, ) @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) - def test_average_precision_functional(self, preds, target, sk_metric, num_classes, average): - if target.max() > 1 and average == "micro": - pytest.skip("average=micro and multiclass input cannot be used together") - + @pytest.mark.parametrize("ignore_index", [None, -1]) + def test_multilabel_average_precision_functional(self, input, average, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) self.run_functional_metric_test( preds=preds, target=target, - metric_functional=average_precision, - sk_metric=partial(sk_metric, num_classes=num_classes, average=average), - metric_args={"num_classes": num_classes, "average": average}, + metric_functional=multilabel_average_precision, + sk_metric=partial(_sk_average_precision_multilabel, average=average, ignore_index=ignore_index), + metric_args={ + "thresholds": None, + "num_labels": NUM_CLASSES, + "average": average, + "ignore_index": ignore_index, + }, ) - def test_average_precision_differentiability(self, preds, sk_metric, target, num_classes): + def test_multiclass_average_precision_differentiability(self, input): + preds, target = input self.run_differentiability_test( preds=preds, target=target, - metric_module=AveragePrecision, - metric_functional=average_precision, - metric_args={"num_classes": num_classes}, + metric_module=MultilabelAveragePrecision, + metric_functional=multilabel_average_precision, + metric_args={"thresholds": None, "num_labels": NUM_CLASSES}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_average_precision_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_8: + pytest.xfail(reason="torch.flip not support before pytorch v1.8 for cpu + half precision") + if dtype == torch.half and not ((0 < preds) & (preds < 1)).all(): + pytest.xfail(reason="half support for torch.softmax on cpu not implemented") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MultilabelAveragePrecision, + metric_functional=multilabel_average_precision, + metric_args={"thresholds": None, "num_labels": NUM_CLASSES}, + dtype=dtype, ) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_average_precision_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MultilabelAveragePrecision, + metric_functional=multilabel_average_precision, + metric_args={"thresholds": None, "num_labels": NUM_CLASSES}, + dtype=dtype, + ) -@pytest.mark.parametrize( - ["scores", "target", "expected_score"], - [ - # Check the average_precision_score of a constant predictor is - # the TPR - # Generate a dataset with 25% of positives - # And a constant score - # The precision is then the fraction of positive whatever the recall - # is, as there is only one threshold: - (tensor([1, 1, 1, 1]), tensor([0, 0, 0, 1]), 0.25), - # With threshold 0.8 : 1 TP and 2 TN and one FN - (tensor([0.6, 0.7, 0.8, 9]), tensor([1, 0, 0, 1]), 0.75), - ], -) -def test_average_precision(scores, target, expected_score): - assert average_precision(scores, target) == expected_score - - -def 
test_average_precision_warnings_and_errors(): - """Test that the correct errors and warnings gets raised.""" - - # check average argument - with pytest.raises(ValueError, match="Expected argument `average` to be one .*"): - AveragePrecision(num_classes=5, average="samples") - - # check that micro average cannot be used with multilabel input - pred = tensor( - [ - [0.75, 0.05, 0.05, 0.05, 0.05], - [0.05, 0.75, 0.05, 0.05, 0.05], - [0.05, 0.05, 0.75, 0.05, 0.05], - [0.05, 0.05, 0.05, 0.75, 0.05], - ] - ) - target = tensor([0, 1, 3, 2]) - average_precision = AveragePrecision(num_classes=5, average="micro") - with pytest.raises(ValueError, match="Cannot use `micro` average with multi-class input"): - average_precision(pred, target) - - # check that warning is thrown when average=macro and nan is encoutered in individual scores - average_precision = AveragePrecision(num_classes=5, average="macro") - with pytest.warns(UserWarning, match="Average precision score for one or more classes was `nan`.*"): - average_precision(pred, target) + @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) + def test_multilabel_average_precision_threshold_arg(self, input, average): + preds, target = input + if (preds < 0).any(): + preds = sigmoid(preds) + for pred, true in zip(preds, target): + pred = torch.tensor(np.round(pred.numpy(), 1)) + 1e-6 # rounding will simulate binning + ap1 = multilabel_average_precision(pred, true, num_labels=NUM_CLASSES, average=average, thresholds=None) + ap2 = multilabel_average_precision( + pred, true, num_labels=NUM_CLASSES, average=average, thresholds=torch.linspace(0, 1, 100) + ) + assert torch.allclose(ap1, ap2) + + +# -------------------------- Old stuff -------------------------- + + +# def _sk_average_precision_score(y_true, probas_pred, num_classes=1, average=None): +# if num_classes == 1: +# return sk_average_precision_score(y_true, probas_pred) + +# res = [] +# for i in range(num_classes): +# y_true_temp = np.zeros_like(y_true) +# y_true_temp[y_true == i] = 1 +# res.append(sk_average_precision_score(y_true_temp, probas_pred[:, i])) + +# if average == "macro": +# return np.array(res).mean() +# if average == "weighted": +# weights = np.bincount(y_true) if y_true.max() > 1 else y_true.sum(axis=0) +# weights = weights / sum(weights) +# return (np.array(res) * weights).sum() + +# return res + + +# def _sk_avg_prec_binary_prob(preds, target, num_classes=1, average=None): +# sk_preds = preds.view(-1).numpy() +# sk_target = target.view(-1).numpy() + +# return _sk_average_precision_score( +# y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes, average=average +# ) + + +# def _sk_avg_prec_multiclass_prob(preds, target, num_classes=1, average=None): +# sk_preds = preds.reshape(-1, num_classes).numpy() +# sk_target = target.view(-1).numpy() + +# return _sk_average_precision_score( +# y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes, average=average +# ) + + +# def _sk_avg_prec_multilabel_prob(preds, target, num_classes=1, average=None): +# sk_preds = preds.reshape(-1, num_classes).numpy() +# sk_target = target.view(-1, num_classes).numpy() +# return sk_average_precision_score(sk_target, sk_preds, average=average) + + +# def _sk_avg_prec_multidim_multiclass_prob(preds, target, num_classes=1, average=None): +# sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() +# sk_target = target.view(-1).numpy() +# return _sk_average_precision_score( +# y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes, 
average=average +# ) + + +# @pytest.mark.parametrize( +# "preds, target, sk_metric, num_classes", +# [ +# (_input_binary_prob.preds, _input_binary_prob.target, _sk_avg_prec_binary_prob, 1), +# (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_avg_prec_multiclass_prob, NUM_CLASSES), +# (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_avg_prec_multidim_multiclass_prob, NUM_CLASSES), +# (_input_multilabel.preds, _input_multilabel.target, _sk_avg_prec_multilabel_prob, NUM_CLASSES), +# ], +# ) +# class TestAveragePrecision(MetricTester): +# @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) +# @pytest.mark.parametrize("ddp", [True, False]) +# @pytest.mark.parametrize("dist_sync_on_step", [True, False]) +# def test_average_precision(self, preds, target, sk_metric, num_classes, average, ddp, dist_sync_on_step): +# if target.max() > 1 and average == "micro": +# pytest.skip("average=micro and multiclass input cannot be used together") + +# self.run_class_metric_test( +# ddp=ddp, +# preds=preds, +# target=target, +# metric_class=AveragePrecision, +# sk_metric=partial(sk_metric, num_classes=num_classes, average=average), +# dist_sync_on_step=dist_sync_on_step, +# metric_args={"num_classes": num_classes, "average": average}, +# ) + +# @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) +# def test_average_precision_functional(self, preds, target, sk_metric, num_classes, average): +# if target.max() > 1 and average == "micro": +# pytest.skip("average=micro and multiclass input cannot be used together") + +# self.run_functional_metric_test( +# preds=preds, +# target=target, +# metric_functional=average_precision, +# sk_metric=partial(sk_metric, num_classes=num_classes, average=average), +# metric_args={"num_classes": num_classes, "average": average}, +# ) + +# def test_average_precision_differentiability(self, preds, sk_metric, target, num_classes): +# self.run_differentiability_test( +# preds=preds, +# target=target, +# metric_module=AveragePrecision, +# metric_functional=average_precision, +# metric_args={"num_classes": num_classes}, +# ) + + +# @pytest.mark.parametrize( +# ["scores", "target", "expected_score"], +# [ +# # Check the average_precision_score of a constant predictor is +# # the TPR +# # Generate a dataset with 25% of positives +# # And a constant score +# # The precision is then the fraction of positive whatever the recall +# # is, as there is only one threshold: +# (tensor([1, 1, 1, 1]), tensor([0, 0, 0, 1]), 0.25), +# # With threshold 0.8 : 1 TP and 2 TN and one FN +# (tensor([0.6, 0.7, 0.8, 9]), tensor([1, 0, 0, 1]), 0.75), +# ], +# ) +# def test_average_precision(scores, target, expected_score): +# assert average_precision(scores, target) == expected_score + + +# def test_average_precision_warnings_and_errors(): +# """Test that the correct errors and warnings gets raised.""" + +# # check average argument +# with pytest.raises(ValueError, match="Expected argument `average` to be one .*"): +# AveragePrecision(num_classes=5, average="samples") + +# # check that micro average cannot be used with multilabel input +# pred = tensor( +# [ +# [0.75, 0.05, 0.05, 0.05, 0.05], +# [0.05, 0.75, 0.05, 0.05, 0.05], +# [0.05, 0.05, 0.75, 0.05, 0.05], +# [0.05, 0.05, 0.05, 0.75, 0.05], +# ] +# ) +# target = tensor([0, 1, 3, 2]) +# average_precision = AveragePrecision(num_classes=5, average="micro") +# with pytest.raises(ValueError, match="Cannot use `micro` average with multi-class input"): +# average_precision(pred, target) + +# # check that 
warning is thrown when average=macro and nan is encoutered in individual scores +# average_precision = AveragePrecision(num_classes=5, average="macro") +# with pytest.warns(UserWarning, match="Average precision score for one or more classes was `nan`.*"): +# average_precision(pred, target) diff --git a/tests/unittests/classification/test_precision_recall_curve.py b/tests/unittests/classification/test_precision_recall_curve.py index c56441a7f4c..a9d5330aaf8 100644 --- a/tests/unittests/classification/test_precision_recall_curve.py +++ b/tests/unittests/classification/test_precision_recall_curve.py @@ -16,131 +16,456 @@ import numpy as np import pytest import torch +from scipy.special import expit as sigmoid +from scipy.special import softmax from sklearn.metrics import precision_recall_curve as sk_precision_recall_curve -from torch import Tensor, tensor - -from torchmetrics.classification.precision_recall_curve import PrecisionRecallCurve -from torchmetrics.functional import precision_recall_curve -from torchmetrics.functional.classification.precision_recall_curve import _binary_clf_curve -from unittests.classification.inputs import _input_binary_prob -from unittests.classification.inputs import _input_multiclass_prob as _input_mcls_prob -from unittests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob + +from torchmetrics.classification.precision_recall_curve import ( + BinaryPrecisionRecallCurve, + MulticlassPrecisionRecallCurve, + MultilabelPrecisionRecallCurve, +) +from torchmetrics.functional.classification.precision_recall_curve import ( + binary_precision_recall_curve, + multiclass_precision_recall_curve, + multilabel_precision_recall_curve, +) +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_8 +from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests.helpers import seed_all -from unittests.helpers.testers import NUM_CLASSES, MetricTester +from unittests.helpers.testers import NUM_CLASSES, MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) -def _sk_precision_recall_curve(y_true, probas_pred, num_classes=1): - """Adjusted comparison function that can also handles multiclass.""" - if num_classes == 1: - return sk_precision_recall_curve(y_true, probas_pred) +def _sk_precision_recall_curve_binary(preds, target, ignore_index=None): + preds = preds.flatten().numpy() + target = target.flatten().numpy() + if np.issubdtype(preds.dtype, np.floating): + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + target, preds = remove_ignore_index(target, preds, ignore_index) + return sk_precision_recall_curve(target, preds) + + +@pytest.mark.parametrize("input", (_binary_cases[1], _binary_cases[2], _binary_cases[4], _binary_cases[5])) +class TestBinaryPrecisionRecallCurve(MetricTester): + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_binary_precision_recall_curve(self, input, ddp, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=BinaryPrecisionRecallCurve, + sk_metric=partial(_sk_precision_recall_curve_binary, ignore_index=ignore_index), + metric_args={ + "thresholds": None, + "ignore_index": ignore_index, + }, + ) + + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + def test_binary_precision_recall_curve_functional(self, input, 
ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=binary_precision_recall_curve, + sk_metric=partial(_sk_precision_recall_curve_binary, ignore_index=ignore_index), + metric_args={ + "thresholds": None, + "ignore_index": ignore_index, + }, + ) + + def test_binary_precision_recall_curve_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=BinaryPrecisionRecallCurve, + metric_functional=binary_precision_recall_curve, + metric_args={"thresholds": None}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_precision_recall_curve_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_8: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=BinaryPrecisionRecallCurve, + metric_functional=binary_precision_recall_curve, + metric_args={"thresholds": None}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_precision_recall_curve_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=BinaryPrecisionRecallCurve, + metric_functional=binary_precision_recall_curve, + metric_args={"thresholds": None}, + dtype=dtype, + ) + + @pytest.mark.parametrize("threshold_fn", [lambda x: x, lambda x: x.numpy().tolist()], ids=["as tensor", "as list"]) + def test_binary_precision_recall_curve_threshold_arg(self, input, threshold_fn): + preds, target = input + + for pred, true in zip(preds, target): + p1, r1, t1 = binary_precision_recall_curve(pred, true, thresholds=None) + p2, r2, t2 = binary_precision_recall_curve(pred, true, thresholds=threshold_fn(t1)) + + assert torch.allclose(p1, p2) + assert torch.allclose(r1, r2) + assert torch.allclose(t1, t2) + + +def _sk_precision_recall_curve_multiclass(preds, target, ignore_index=None): + preds = np.moveaxis(preds.numpy(), 1, -1).reshape((-1, preds.shape[1])) + target = target.numpy().flatten() + if not ((0 < preds) & (preds < 1)).all(): + preds = softmax(preds, 1) + target, preds = remove_ignore_index(target, preds, ignore_index) precision, recall, thresholds = [], [], [] - for i in range(num_classes): - y_true_temp = np.zeros_like(y_true) - y_true_temp[y_true == i] = 1 - res = sk_precision_recall_curve(y_true_temp, probas_pred[:, i]) + for i in range(NUM_CLASSES): + target_temp = np.zeros_like(target) + target_temp[target == i] = 1 + res = sk_precision_recall_curve(target_temp, preds[:, i]) precision.append(res[0]) recall.append(res[1]) thresholds.append(res[2]) - return precision, recall, thresholds + # return precision, recall, thresholds + return [np.nan_to_num(x, nan=0.0) for x in [precision, recall, thresholds]] -def _sk_prec_rc_binary_prob(preds, target, num_classes=1): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() +@pytest.mark.parametrize( + "input", (_multiclass_cases[1], _multiclass_cases[2], _multiclass_cases[4], _multiclass_cases[5]) +) +class 
TestMulticlassPrecisionRecallCurve(MetricTester): + @pytest.mark.parametrize("ignore_index", [None, -1]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_multiclass_precision_recall_curve(self, input, ddp, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MulticlassPrecisionRecallCurve, + sk_metric=partial(_sk_precision_recall_curve_multiclass, ignore_index=ignore_index), + metric_args={ + "thresholds": None, + "num_classes": NUM_CLASSES, + "ignore_index": ignore_index, + }, + ) - return _sk_precision_recall_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) + @pytest.mark.parametrize("ignore_index", [None, -1]) + def test_multiclass_precision_recall_curve_functional(self, input, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=multiclass_precision_recall_curve, + sk_metric=partial(_sk_precision_recall_curve_multiclass, ignore_index=ignore_index), + metric_args={ + "thresholds": None, + "num_classes": NUM_CLASSES, + "ignore_index": ignore_index, + }, + ) + def test_multiclass_precision_recall_curve_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=MulticlassPrecisionRecallCurve, + metric_functional=multiclass_precision_recall_curve, + metric_args={"thresholds": None, "num_classes": NUM_CLASSES}, + ) -def _sk_prec_rc_multiclass_prob(preds, target, num_classes=1): - sk_preds = preds.reshape(-1, num_classes).numpy() - sk_target = target.view(-1).numpy() + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_precision_recall_curve_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_8: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if dtype == torch.half and not ((0 < preds) & (preds < 1)).all(): + pytest.xfail(reason="half support for torch.softmax on cpu not implemented") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MulticlassPrecisionRecallCurve, + metric_functional=multiclass_precision_recall_curve, + metric_args={"thresholds": None, "num_classes": NUM_CLASSES}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_precision_recall_curve_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MulticlassPrecisionRecallCurve, + metric_functional=multiclass_precision_recall_curve, + metric_args={"thresholds": None, "num_classes": NUM_CLASSES}, + dtype=dtype, + ) - return _sk_precision_recall_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) + @pytest.mark.parametrize("threshold_fn", [lambda x: x, lambda x: x.numpy().tolist()], ids=["as tensor", "as list"]) + def test_multiclass_precision_recall_curve_threshold_arg(self, input, threshold_fn): + preds, target = input + for pred, true in zip(preds, target): + p1, r1, t1 = multiclass_precision_recall_curve(pred, true, num_classes=NUM_CLASSES, thresholds=None) + for i, t in enumerate(t1): + p2, r2, t2 = 
multiclass_precision_recall_curve( + pred, true, num_classes=NUM_CLASSES, thresholds=threshold_fn(t) + ) + assert torch.allclose(p1[i], p2[i]) + assert torch.allclose(r1[i], r2[i]) + assert torch.allclose(t1[i], t2) -def _sk_prec_rc_multidim_multiclass_prob(preds, target, num_classes=1): - sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() - sk_target = target.view(-1).numpy() - return _sk_precision_recall_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) + +def _sk_precision_recall_curve_multilabel(preds, target, ignore_index=None): + precision, recall, thresholds = [], [], [] + for i in range(NUM_CLASSES): + res = _sk_precision_recall_curve_binary(preds[:, i], target[:, i], ignore_index) + precision.append(res[0]) + recall.append(res[1]) + thresholds.append(res[2]) + return precision, recall, thresholds @pytest.mark.parametrize( - "preds, target, sk_metric, num_classes", - [ - (_input_binary_prob.preds, _input_binary_prob.target, _sk_prec_rc_binary_prob, 1), - (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_prec_rc_multiclass_prob, NUM_CLASSES), - (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_prec_rc_multidim_multiclass_prob, NUM_CLASSES), - ], + "input", (_multilabel_cases[1], _multilabel_cases[2], _multilabel_cases[4], _multilabel_cases[5]) ) -class TestPrecisionRecallCurve(MetricTester): +class TestMultilabelPrecisionRecallCurve(MetricTester): + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_precision_recall_curve(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step): + def test_multilabel_precision_recall_curve(self, input, ddp, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) self.run_class_metric_test( ddp=ddp, preds=preds, target=target, - metric_class=PrecisionRecallCurve, - sk_metric=partial(sk_metric, num_classes=num_classes), - dist_sync_on_step=dist_sync_on_step, - metric_args={"num_classes": num_classes}, + metric_class=MultilabelPrecisionRecallCurve, + sk_metric=partial(_sk_precision_recall_curve_multilabel, ignore_index=ignore_index), + metric_args={ + "thresholds": None, + "num_labels": NUM_CLASSES, + "ignore_index": ignore_index, + }, ) - def test_precision_recall_curve_functional(self, preds, target, sk_metric, num_classes): + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + def test_multilabel_precision_recall_curve_functional(self, input, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) self.run_functional_metric_test( - preds, - target, - metric_functional=precision_recall_curve, - sk_metric=partial(sk_metric, num_classes=num_classes), - metric_args={"num_classes": num_classes}, + preds=preds, + target=target, + metric_functional=multilabel_precision_recall_curve, + sk_metric=partial(_sk_precision_recall_curve_multilabel, ignore_index=ignore_index), + metric_args={ + "thresholds": None, + "num_labels": NUM_CLASSES, + "ignore_index": ignore_index, + }, ) - def test_precision_recall_curve_differentiability(self, preds, target, sk_metric, num_classes): + def test_multiclass_precision_recall_curve_differentiability(self, input): + preds, target = input self.run_differentiability_test( - preds, - target, - metric_module=PrecisionRecallCurve, - metric_functional=precision_recall_curve, - 
metric_args={"num_classes": num_classes}, + preds=preds, + target=target, + metric_module=MultilabelPrecisionRecallCurve, + metric_functional=multilabel_precision_recall_curve, + metric_args={"thresholds": None, "num_labels": NUM_CLASSES}, ) + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_precision_recall_curve_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_8: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if dtype == torch.half and not ((0 < preds) & (preds < 1)).all(): + pytest.xfail(reason="half support for torch.softmax on cpu not implemented") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MultilabelPrecisionRecallCurve, + metric_functional=multilabel_precision_recall_curve, + metric_args={"thresholds": None, "num_labels": NUM_CLASSES}, + dtype=dtype, + ) -@pytest.mark.parametrize( - ["pred", "target", "expected_p", "expected_r", "expected_t"], - [([1, 2, 3, 4], [1, 0, 0, 1], [0.5, 1 / 3, 0.5, 1.0, 1.0], [1, 0.5, 0.5, 0.5, 0.0], [1, 2, 3, 4])], -) -def test_pr_curve(pred, target, expected_p, expected_r, expected_t): - p, r, t = precision_recall_curve(tensor(pred), tensor(target)) - assert p.size() == r.size() - assert p.size(0) == t.size(0) + 1 + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_precision_recall_curve_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MultilabelPrecisionRecallCurve, + metric_functional=multilabel_precision_recall_curve, + metric_args={"thresholds": None, "num_labels": NUM_CLASSES}, + dtype=dtype, + ) - assert torch.allclose(p, tensor(expected_p).to(p)) - assert torch.allclose(r, tensor(expected_r).to(r)) - assert torch.allclose(t, tensor(expected_t).to(t)) + @pytest.mark.parametrize("threshold_fn", [lambda x: x, lambda x: x.numpy().tolist()], ids=["as tensor", "as list"]) + def test_multilabel_precision_recall_curve_threshold_arg(self, input, threshold_fn): + preds, target = input + for pred, true in zip(preds, target): + p1, r1, t1 = multilabel_precision_recall_curve(pred, true, num_labels=NUM_CLASSES, thresholds=None) + for i, t in enumerate(t1): + p2, r2, t2 = multilabel_precision_recall_curve( + pred, true, num_labels=NUM_CLASSES, thresholds=threshold_fn(t) + ) + assert torch.allclose(p1[i], p2[i]) + assert torch.allclose(r1[i], r2[i]) + assert torch.allclose(t1[i], t2) -@pytest.mark.parametrize( - "sample_weight, pos_label, exp_shape", - [(1, 1.0, 42), (None, 1.0, 42)], -) -def test_binary_clf_curve(sample_weight, pos_label, exp_shape): - # TODO: move back the pred and target to test func arguments - # if you fix the array inside the function, you'd also have fix the shape, - # because when the array changes, you also have to fix the shape - seed_all(0) - pred = torch.randint(low=51, high=99, size=(100,), dtype=torch.float) / 100 - target = tensor([0, 1] * 50, dtype=torch.int) - if sample_weight is not None: - sample_weight = torch.ones_like(pred) * sample_weight - - fps, tps, thresh = _binary_clf_curve(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label) - - assert isinstance(tps, Tensor) - assert isinstance(fps, Tensor) - assert isinstance(thresh, Tensor) - assert tps.shape == (exp_shape,) - assert fps.shape == (exp_shape,) - assert thresh.shape == 
(exp_shape,) + +# -------------------------- Old stuff -------------------------- + + +# def _sk_precision_recall_curve(y_true, probas_pred, num_classes=1): +# """Adjusted comparison function that can also handles multiclass.""" +# if num_classes == 1: +# return sk_precision_recall_curve(y_true, probas_pred) + +# precision, recall, thresholds = [], [], [] +# for i in range(num_classes): +# y_true_temp = np.zeros_like(y_true) +# y_true_temp[y_true == i] = 1 +# res = sk_precision_recall_curve(y_true_temp, probas_pred[:, i]) +# precision.append(res[0]) +# recall.append(res[1]) +# thresholds.append(res[2]) +# return precision, recall, thresholds + + +# def _sk_prec_rc_binary_prob(preds, target, num_classes=1): +# sk_preds = preds.view(-1).numpy() +# sk_target = target.view(-1).numpy() + +# return _sk_precision_recall_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) + + +# def _sk_prec_rc_multiclass_prob(preds, target, num_classes=1): +# sk_preds = preds.reshape(-1, num_classes).numpy() +# sk_target = target.view(-1).numpy() + +# return _sk_precision_recall_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) + + +# def _sk_prec_rc_multidim_multiclass_prob(preds, target, num_classes=1): +# sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() +# sk_target = target.view(-1).numpy() +# return _sk_precision_recall_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) + + +# @pytest.mark.parametrize( +# "preds, target, sk_metric, num_classes", +# [ +# (_input_binary_prob.preds, _input_binary_prob.target, _sk_prec_rc_binary_prob, 1), +# (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_prec_rc_multiclass_prob, NUM_CLASSES), +# (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_prec_rc_multidim_multiclass_prob, NUM_CLASSES), +# ], +# ) +# class TestPrecisionRecallCurve(MetricTester): +# @pytest.mark.parametrize("ddp", [True, False]) +# @pytest.mark.parametrize("dist_sync_on_step", [True, False]) +# def test_precision_recall_curve(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step): +# self.run_class_metric_test( +# ddp=ddp, +# preds=preds, +# target=target, +# metric_class=PrecisionRecallCurve, +# sk_metric=partial(sk_metric, num_classes=num_classes), +# dist_sync_on_step=dist_sync_on_step, +# metric_args={"num_classes": num_classes}, +# ) + +# def test_precision_recall_curve_functional(self, preds, target, sk_metric, num_classes): +# self.run_functional_metric_test( +# preds, +# target, +# metric_functional=precision_recall_curve, +# sk_metric=partial(sk_metric, num_classes=num_classes), +# metric_args={"num_classes": num_classes}, +# ) + +# def test_precision_recall_curve_differentiability(self, preds, target, sk_metric, num_classes): +# self.run_differentiability_test( +# preds, +# target, +# metric_module=PrecisionRecallCurve, +# metric_functional=precision_recall_curve, +# metric_args={"num_classes": num_classes}, +# ) + + +# @pytest.mark.parametrize( +# ["pred", "target", "expected_p", "expected_r", "expected_t"], +# [([1, 2, 3, 4], [1, 0, 0, 1], [0.5, 1 / 3, 0.5, 1.0, 1.0], [1, 0.5, 0.5, 0.5, 0.0], [1, 2, 3, 4])], +# ) +# def test_pr_curve(pred, target, expected_p, expected_r, expected_t): +# p, r, t = precision_recall_curve(tensor(pred), tensor(target)) +# assert p.size() == r.size() +# assert p.size(0) == t.size(0) + 1 + +# assert torch.allclose(p, tensor(expected_p).to(p)) +# assert torch.allclose(r, tensor(expected_r).to(r)) +# assert torch.allclose(t, tensor(expected_t).to(t)) + + 
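
Note on the helpers used throughout the new tests: every test class above (and its counterparts in the other classification test files in this diff) imports inject_ignore_index and remove_ignore_index from unittests.helpers.testers, which are not shown in this changeset. The sketch below is only an illustration of their assumed behaviour; the real implementations may differ.

import torch


def inject_ignore_index(target, ignore_index):
    # Assumed behaviour: overwrite a random subset of target entries with `ignore_index`
    # so that the tests exercise the ignore_index code path.
    target = target.clone()
    mask = torch.rand(target.shape) < 0.1
    target[mask] = ignore_index
    return target


def remove_ignore_index(target, preds, ignore_index):
    # Assumed behaviour: drop every position whose target equals `ignore_index` before the
    # data is handed to the sklearn reference implementation. Only boolean indexing is used,
    # so this works for torch tensors and numpy arrays alike.
    if ignore_index is not None:
        keep = target != ignore_index
        target, preds = target[keep], preds[keep]
    return target, preds

This is also why each sklearn reference function above calls remove_ignore_index before scoring: the torchmetrics implementations are expected to skip those positions internally, so the sklearn baseline has to do the same for the comparison to be fair.
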
+# @pytest.mark.parametrize( +# "sample_weight, pos_label, exp_shape", +# [(1, 1.0, 42), (None, 1.0, 42)], +# ) +# def test_binary_clf_curve(sample_weight, pos_label, exp_shape): +# # TODO: move back the pred and target to test func arguments +# # if you fix the array inside the function, you'd also have fix the shape, +# # because when the array changes, you also have to fix the shape +# seed_all(0) +# pred = torch.randint(low=51, high=99, size=(100,), dtype=torch.float) / 100 +# target = tensor([0, 1] * 50, dtype=torch.int) +# if sample_weight is not None: +# sample_weight = torch.ones_like(pred) * sample_weight + +# fps, tps, thresh = _binary_clf_curve(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label) + +# assert isinstance(tps, Tensor) +# assert isinstance(fps, Tensor) +# assert isinstance(thresh, Tensor) +# assert tps.shape == (exp_shape,) +# assert fps.shape == (exp_shape,) +# assert thresh.shape == (exp_shape,) diff --git a/tests/unittests/classification/test_recall_at_fixed_precision.py b/tests/unittests/classification/test_recall_at_fixed_precision.py new file mode 100644 index 00000000000..3032f757cde --- /dev/null +++ b/tests/unittests/classification/test_recall_at_fixed_precision.py @@ -0,0 +1,389 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from functools import partial + +import numpy as np +import pytest +import torch +from scipy.special import expit as sigmoid +from scipy.special import softmax +from sklearn.metrics import precision_recall_curve as _sk_precision_recall_curve + +from torchmetrics.classification.recall_at_fixed_precision import ( + BinaryRecallAtFixedPrecision, + MulticlassRecallAtFixedPrecision, + MultilabelRecallAtFixedPrecision, +) +from torchmetrics.functional.classification.recall_at_fixed_precision import ( + binary_recall_at_fixed_precision, + multiclass_recall_at_fixed_precision, + multilabel_recall_at_fixed_precision, +) +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_8 +from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases +from unittests.helpers import seed_all +from unittests.helpers.testers import NUM_CLASSES, MetricTester, inject_ignore_index, remove_ignore_index + +seed_all(42) + + +def recall_at_precision_x_multilabel(predictions, targets, min_precision): + precision, recall, thresholds = _sk_precision_recall_curve(targets, predictions) + + try: + tuple_all = [(r, p, t) for p, r, t in zip(precision, recall, thresholds) if p >= min_precision] + max_recall, _, best_threshold = max(tuple_all) + except ValueError: + max_recall, best_threshold = 0, 1e6 + + return float(max_recall), float(best_threshold) + + +def _sk_recall_at_fixed_precision_binary(preds, target, min_precision, ignore_index=None): + preds = preds.flatten().numpy() + target = target.flatten().numpy() + if np.issubdtype(preds.dtype, np.floating): + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + target, preds = remove_ignore_index(target, preds, ignore_index) + return recall_at_precision_x_multilabel(preds, target, min_precision) + + +@pytest.mark.parametrize("input", (_binary_cases[1], _binary_cases[2], _binary_cases[4], _binary_cases[5])) +class TestBinaryRecallAtFixedPrecision(MetricTester): + @pytest.mark.parametrize("min_precision", [0.05, 0.1, 0.3, 0.5, 0.85]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_binary_recall_at_fixed_precision(self, input, ddp, min_precision, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=BinaryRecallAtFixedPrecision, + sk_metric=partial( + _sk_recall_at_fixed_precision_binary, min_precision=min_precision, ignore_index=ignore_index + ), + metric_args={ + "min_precision": min_precision, + "thresholds": None, + "ignore_index": ignore_index, + }, + ) + + @pytest.mark.parametrize("min_precision", [0.05, 0.1, 0.3, 0.5, 0.8]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + def test_binary_recall_at_fixed_precision_functional(self, input, min_precision, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=binary_recall_at_fixed_precision, + sk_metric=partial( + _sk_recall_at_fixed_precision_binary, min_precision=min_precision, ignore_index=ignore_index + ), + metric_args={ + "min_precision": min_precision, + "thresholds": None, + "ignore_index": ignore_index, + }, + ) + + def test_binary_recall_at_fixed_precision_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + 
target=target, + metric_module=BinaryRecallAtFixedPrecision, + metric_functional=binary_recall_at_fixed_precision, + metric_args={"min_precision": 0.5, "thresholds": None}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_recall_at_fixed_precision_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_8: + pytest.xfail(reason="torch.flip not support before pytorch v1.8 for cpu + half precision") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=BinaryRecallAtFixedPrecision, + metric_functional=binary_recall_at_fixed_precision, + metric_args={"min_precision": 0.5, "thresholds": None}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_recall_at_fixed_precision_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=BinaryRecallAtFixedPrecision, + metric_functional=binary_recall_at_fixed_precision, + metric_args={"min_precision": 0.5, "thresholds": None}, + dtype=dtype, + ) + + @pytest.mark.parametrize("min_precision", [0.05, 0.1, 0.3, 0.5, 0.8]) + def test_binary_recall_at_fixed_precision_threshold_arg(self, input, min_precision): + preds, target = input + + for pred, true in zip(preds, target): + pred = torch.tensor(np.round(pred.numpy(), 1)) + 1e-6 # rounding will simulate binning + r1, _ = binary_recall_at_fixed_precision(pred, true, min_precision=min_precision, thresholds=None) + r2, _ = binary_recall_at_fixed_precision( + pred, true, min_precision=min_precision, thresholds=torch.linspace(0, 1, 100) + ) + assert torch.allclose(r1, r2) + + +def _sk_recall_at_fixed_precision_multiclass(preds, target, min_precision, ignore_index=None): + preds = np.moveaxis(preds.numpy(), 1, -1).reshape((-1, preds.shape[1])) + target = target.numpy().flatten() + if not ((0 < preds) & (preds < 1)).all(): + preds = softmax(preds, 1) + target, preds = remove_ignore_index(target, preds, ignore_index) + + recall, thresholds = [], [] + for i in range(NUM_CLASSES): + target_temp = np.zeros_like(target) + target_temp[target == i] = 1 + res = recall_at_precision_x_multilabel(preds[:, i], target_temp, min_precision) + recall.append(res[0]) + thresholds.append(res[1]) + return recall, thresholds + + +@pytest.mark.parametrize( + "input", (_multiclass_cases[1], _multiclass_cases[2], _multiclass_cases[4], _multiclass_cases[5]) +) +class TestMulticlassRecallAtFixedPrecision(MetricTester): + @pytest.mark.parametrize("min_precision", [0.05, 0.1, 0.3, 0.5, 0.8]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_multiclass_recall_at_fixed_precision(self, input, ddp, min_precision, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MulticlassRecallAtFixedPrecision, + sk_metric=partial( + _sk_recall_at_fixed_precision_multiclass, min_precision=min_precision, ignore_index=ignore_index + ), + metric_args={ + "min_precision": min_precision, + "thresholds": None, + "num_classes": NUM_CLASSES, + "ignore_index": ignore_index, + }, + ) + + 
@pytest.mark.parametrize("min_precision", [0.05, 0.1, 0.3, 0.5, 0.8]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + def test_multiclass_recall_at_fixed_precision_functional(self, input, min_precision, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=multiclass_recall_at_fixed_precision, + sk_metric=partial( + _sk_recall_at_fixed_precision_multiclass, min_precision=min_precision, ignore_index=ignore_index + ), + metric_args={ + "min_precision": min_precision, + "thresholds": None, + "num_classes": NUM_CLASSES, + "ignore_index": ignore_index, + }, + ) + + def test_multiclass_recall_at_fixed_precision_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=MulticlassRecallAtFixedPrecision, + metric_functional=multiclass_recall_at_fixed_precision, + metric_args={"min_precision": 0.5, "thresholds": None, "num_classes": NUM_CLASSES}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_recall_at_fixed_precision_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_8: + pytest.xfail(reason="torch.flip not support before pytorch v1.8 for cpu + half precision") + if dtype == torch.half and not ((0 < preds) & (preds < 1)).all(): + pytest.xfail(reason="half support for torch.softmax on cpu not implemented") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MulticlassRecallAtFixedPrecision, + metric_functional=multiclass_recall_at_fixed_precision, + metric_args={"min_precision": 0.5, "thresholds": None, "num_classes": NUM_CLASSES}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_recall_at_fixed_precision_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MulticlassRecallAtFixedPrecision, + metric_functional=multiclass_recall_at_fixed_precision, + metric_args={"min_precision": 0.5, "thresholds": None, "num_classes": NUM_CLASSES}, + dtype=dtype, + ) + + @pytest.mark.parametrize("min_precision", [0.05, 0.1, 0.3, 0.5, 0.8]) + def test_multiclass_recall_at_fixed_precision_threshold_arg(self, input, min_precision): + preds, target = input + if (preds < 0).any(): + preds = preds.softmax(dim=-1) + for pred, true in zip(preds, target): + pred = torch.tensor(np.round(pred.numpy(), 1)) + 1e-6 # rounding will simulate binning + r1, _ = multiclass_recall_at_fixed_precision( + pred, true, num_classes=NUM_CLASSES, min_precision=min_precision, thresholds=None + ) + r2, _ = multiclass_recall_at_fixed_precision( + pred, true, num_classes=NUM_CLASSES, min_precision=min_precision, thresholds=torch.linspace(0, 1, 100) + ) + assert all(torch.allclose(r1[i], r2[i]) for i in range(len(r1))) + + +def _sk_recall_at_fixed_precision_multilabel(preds, target, min_precision, ignore_index=None): + recall, thresholds = [], [] + for i in range(NUM_CLASSES): + res = _sk_recall_at_fixed_precision_binary(preds[:, i], target[:, i], min_precision, ignore_index) + recall.append(res[0]) + thresholds.append(res[1]) + return recall, thresholds + + +@pytest.mark.parametrize( + "input", (_multilabel_cases[1], _multilabel_cases[2], 
_multilabel_cases[4], _multilabel_cases[5])
+)
+class TestMultilabelRecallAtFixedPrecision(MetricTester):
+    @pytest.mark.parametrize("min_precision", [0.05, 0.1, 0.3, 0.5, 0.8])
+    @pytest.mark.parametrize("ignore_index", [None, -1, 0])
+    @pytest.mark.parametrize("ddp", [True, False])
+    def test_multilabel_recall_at_fixed_precision(self, input, ddp, min_precision, ignore_index):
+        preds, target = input
+        if ignore_index is not None:
+            target = inject_ignore_index(target, ignore_index)
+        self.run_class_metric_test(
+            ddp=ddp,
+            preds=preds,
+            target=target,
+            metric_class=MultilabelRecallAtFixedPrecision,
+            sk_metric=partial(
+                _sk_recall_at_fixed_precision_multilabel, min_precision=min_precision, ignore_index=ignore_index
+            ),
+            metric_args={
+                "min_precision": min_precision,
+                "thresholds": None,
+                "num_labels": NUM_CLASSES,
+                "ignore_index": ignore_index,
+            },
+        )
+
+    @pytest.mark.parametrize("min_precision", [0.05, 0.1, 0.3, 0.5, 0.8])
+    @pytest.mark.parametrize("ignore_index", [None, -1, 0])
+    def test_multilabel_recall_at_fixed_precision_functional(self, input, min_precision, ignore_index):
+        preds, target = input
+        if ignore_index is not None:
+            target = inject_ignore_index(target, ignore_index)
+        self.run_functional_metric_test(
+            preds=preds,
+            target=target,
+            metric_functional=multilabel_recall_at_fixed_precision,
+            sk_metric=partial(
+                _sk_recall_at_fixed_precision_multilabel, min_precision=min_precision, ignore_index=ignore_index
+            ),
+            metric_args={
+                "min_precision": min_precision,
+                "thresholds": None,
+                "num_labels": NUM_CLASSES,
+                "ignore_index": ignore_index,
+            },
+        )
+
+    def test_multilabel_recall_at_fixed_precision_differentiability(self, input):
+        preds, target = input
+        self.run_differentiability_test(
+            preds=preds,
+            target=target,
+            metric_module=MultilabelRecallAtFixedPrecision,
+            metric_functional=multilabel_recall_at_fixed_precision,
+            metric_args={"min_precision": 0.5, "thresholds": None, "num_labels": NUM_CLASSES},
+        )
+
+    @pytest.mark.parametrize("dtype", [torch.half, torch.double])
+    def test_multilabel_recall_at_fixed_precision_dtype_cpu(self, input, dtype):
+        preds, target = input
+        if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_8:
+            pytest.xfail(reason="torch.flip not supported before pytorch v1.8 for cpu + half precision")
+        if dtype == torch.half and not ((0 < preds) & (preds < 1)).all():
+            pytest.xfail(reason="half support for torch.softmax on cpu not implemented")
+        self.run_precision_test_cpu(
+            preds=preds,
+            target=target,
+            metric_module=MultilabelRecallAtFixedPrecision,
+            metric_functional=multilabel_recall_at_fixed_precision,
+            metric_args={"min_precision": 0.5, "thresholds": None, "num_labels": NUM_CLASSES},
+            dtype=dtype,
+        )
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
+    @pytest.mark.parametrize("dtype", [torch.half, torch.double])
+    def test_multilabel_recall_at_fixed_precision_dtype_gpu(self, input, dtype):
+        preds, target = input
+        self.run_precision_test_gpu(
+            preds=preds,
+            target=target,
+            metric_module=MultilabelRecallAtFixedPrecision,
+            metric_functional=multilabel_recall_at_fixed_precision,
+            metric_args={"min_precision": 0.5, "thresholds": None, "num_labels": NUM_CLASSES},
+            dtype=dtype,
+        )
+
+    @pytest.mark.parametrize("min_precision", [0.05, 0.1, 0.3, 0.5, 0.8])
+    def test_multilabel_recall_at_fixed_precision_threshold_arg(self, input, min_precision):
+        preds, target = input
+        if (preds < 0).any():
+            preds = sigmoid(preds)
+        for pred, true in zip(preds, target):
+            pred = 
torch.tensor(np.round(pred.numpy(), 1)) + 1e-6 # rounding will simulate binning + r1, _ = multilabel_recall_at_fixed_precision( + pred, true, num_labels=NUM_CLASSES, min_precision=min_precision, thresholds=None + ) + r2, _ = multilabel_recall_at_fixed_precision( + pred, true, num_labels=NUM_CLASSES, min_precision=min_precision, thresholds=torch.linspace(0, 1, 100) + ) + assert all(torch.allclose(r1[i], r2[i]) for i in range(len(r1))) diff --git a/tests/unittests/classification/test_roc.py b/tests/unittests/classification/test_roc.py index b5b787e9221..30e66bd042c 100644 --- a/tests/unittests/classification/test_roc.py +++ b/tests/unittests/classification/test_roc.py @@ -16,152 +16,464 @@ import numpy as np import pytest import torch +from scipy.special import expit as sigmoid +from scipy.special import softmax from sklearn.metrics import roc_curve as sk_roc_curve -from torch import tensor - -from torchmetrics.classification.roc import ROC -from torchmetrics.functional import roc -from unittests.classification.inputs import _input_binary_prob -from unittests.classification.inputs import _input_multiclass_prob as _input_mcls_prob -from unittests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob -from unittests.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob -from unittests.classification.inputs import _input_multilabel_prob as _input_mlb_prob + +from torchmetrics.classification.roc import BinaryROC, MulticlassROC, MultilabelROC +from torchmetrics.functional.classification.roc import binary_roc, multiclass_roc, multilabel_roc +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 +from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests.helpers import seed_all -from unittests.helpers.testers import NUM_CLASSES, MetricTester +from unittests.helpers.testers import NUM_CLASSES, MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) -def _sk_roc_curve(y_true, probas_pred, num_classes: int = 1, multilabel: bool = False): - """Adjusted comparison function that can also handles multiclass.""" - if num_classes == 1: - return sk_roc_curve(y_true, probas_pred, drop_intermediate=False) +def _sk_roc_binary(preds, target, ignore_index=None): + preds = preds.flatten().numpy() + target = target.flatten().numpy() + if np.issubdtype(preds.dtype, np.floating): + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + target, preds = remove_ignore_index(target, preds, ignore_index) + fpr, tpr, thresholds = sk_roc_curve(target, preds, drop_intermediate=False) + thresholds[0] = 1.0 + return [np.nan_to_num(x, nan=0.0) for x in [fpr, tpr, thresholds]] + + +@pytest.mark.parametrize("input", (_binary_cases[1], _binary_cases[2], _binary_cases[4], _binary_cases[5])) +class TestBinaryROC(MetricTester): + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_binary_roc(self, input, ddp, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=BinaryROC, + sk_metric=partial(_sk_roc_binary, ignore_index=ignore_index), + metric_args={ + "thresholds": None, + "ignore_index": ignore_index, + }, + ) + + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + def test_binary_roc_functional(self, input, ignore_index): + preds, target = input + if ignore_index 
is not None: + target = inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=binary_roc, + sk_metric=partial(_sk_roc_binary, ignore_index=ignore_index), + metric_args={ + "thresholds": None, + "ignore_index": ignore_index, + }, + ) + + def test_binary_roc_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=BinaryROC, + metric_functional=binary_roc, + metric_args={"thresholds": None}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_roc_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=BinaryROC, + metric_functional=binary_roc, + metric_args={"thresholds": None}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_roc_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=BinaryROC, + metric_functional=binary_roc, + metric_args={"thresholds": None}, + dtype=dtype, + ) + + @pytest.mark.parametrize("threshold_fn", [lambda x: x, lambda x: x.numpy().tolist()], ids=["as tensor", "as list"]) + def test_binary_roc_threshold_arg(self, input, threshold_fn): + preds, target = input + for pred, true in zip(preds, target): + p1, r1, t1 = binary_roc(pred, true, thresholds=None) + p2, r2, t2 = binary_roc(pred, true, thresholds=threshold_fn(t1.flip(0))) + assert torch.allclose(p1, p2) + assert torch.allclose(r1, r2) + assert torch.allclose(t1, t2) + + +def _sk_roc_multiclass(preds, target, ignore_index=None): + preds = np.moveaxis(preds.numpy(), 1, -1).reshape((-1, preds.shape[1])) + target = target.numpy().flatten() + if not ((0 < preds) & (preds < 1)).all(): + preds = softmax(preds, 1) + target, preds = remove_ignore_index(target, preds, ignore_index) fpr, tpr, thresholds = [], [], [] - for i in range(num_classes): - if multilabel: - y_true_temp = y_true[:, i] - else: - y_true_temp = np.zeros_like(y_true) - y_true_temp[y_true == i] = 1 - - res = sk_roc_curve(y_true_temp, probas_pred[:, i], drop_intermediate=False) + for i in range(NUM_CLASSES): + target_temp = np.zeros_like(target) + target_temp[target == i] = 1 + res = sk_roc_curve(target_temp, preds[:, i], drop_intermediate=False) + res[2][0] = 1.0 + fpr.append(res[0]) tpr.append(res[1]) thresholds.append(res[2]) - return fpr, tpr, thresholds - + return [np.nan_to_num(x, nan=0.0) for x in [fpr, tpr, thresholds]] -def _sk_roc_binary_prob(preds, target, num_classes=1): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - - return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) +@pytest.mark.parametrize( + "input", (_multiclass_cases[1], _multiclass_cases[2], _multiclass_cases[4], _multiclass_cases[5]) +) +class TestMulticlassROC(MetricTester): + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_multiclass_roc(self, input, ddp, ignore_index): + preds, target = input + if 
ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MulticlassROC, + sk_metric=partial(_sk_roc_multiclass, ignore_index=ignore_index), + metric_args={ + "thresholds": None, + "num_classes": NUM_CLASSES, + "ignore_index": ignore_index, + }, + ) -def _sk_roc_multiclass_prob(preds, target, num_classes=1): - sk_preds = preds.reshape(-1, num_classes).numpy() - sk_target = target.view(-1).numpy() + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + def test_multiclass_roc_functional(self, input, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=multiclass_roc, + sk_metric=partial(_sk_roc_multiclass, ignore_index=ignore_index), + metric_args={ + "thresholds": None, + "num_classes": NUM_CLASSES, + "ignore_index": ignore_index, + }, + ) - return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) + def test_multiclass_roc_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=MulticlassROC, + metric_functional=multiclass_roc, + metric_args={"thresholds": None, "num_classes": NUM_CLASSES}, + ) + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_roc_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if dtype == torch.half and not ((0 < preds) & (preds < 1)).all(): + pytest.xfail(reason="half support for torch.softmax on cpu not implemented") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MulticlassROC, + metric_functional=multiclass_roc, + metric_args={"thresholds": None, "num_classes": NUM_CLASSES}, + dtype=dtype, + ) -def _sk_roc_multidim_multiclass_prob(preds, target, num_classes=1): - sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() - sk_target = target.view(-1).numpy() - return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_roc_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MulticlassROC, + metric_functional=multiclass_roc, + metric_args={"thresholds": None, "num_classes": NUM_CLASSES}, + dtype=dtype, + ) + @pytest.mark.parametrize("threshold_fn", [lambda x: x, lambda x: x.numpy().tolist()], ids=["as tensor", "as list"]) + def test_multiclass_roc_threshold_arg(self, input, threshold_fn): + preds, target = input + for pred, true in zip(preds, target): + p1, r1, t1 = multiclass_roc(pred, true, num_classes=NUM_CLASSES, thresholds=None) + for i, t in enumerate(t1): + p2, r2, t2 = multiclass_roc(pred, true, num_classes=NUM_CLASSES, thresholds=threshold_fn(t.flip(0))) -def _sk_roc_multilabel_prob(preds, target, num_classes=1): - sk_preds = preds.numpy() - sk_target = target.numpy() - return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes, multilabel=True) + assert torch.allclose(p1[i], p2[i]) + assert torch.allclose(r1[i], r2[i]) + assert torch.allclose(t1[i], t2) 
-def _sk_roc_multilabel_multidim_prob(preds, target, num_classes=1):
-    sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy()
-    sk_target = target.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy()
-    return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes, multilabel=True)
+def _sk_roc_multilabel(preds, target, ignore_index=None):
+    fpr, tpr, thresholds = [], [], []
+    for i in range(NUM_CLASSES):
+        res = _sk_roc_binary(preds[:, i], target[:, i], ignore_index)
+        fpr.append(res[0])
+        tpr.append(res[1])
+        thresholds.append(res[2])
+    return fpr, tpr, thresholds


 @pytest.mark.parametrize(
-    "preds, target, sk_metric, num_classes",
-    [
-        (_input_binary_prob.preds, _input_binary_prob.target, _sk_roc_binary_prob, 1),
-        (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_roc_multiclass_prob, NUM_CLASSES),
-        (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_roc_multidim_multiclass_prob, NUM_CLASSES),
-        (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_roc_multilabel_prob, NUM_CLASSES),
-        (_input_mlmd_prob.preds, _input_mlmd_prob.target, _sk_roc_multilabel_multidim_prob, NUM_CLASSES),
-    ],
+    "input", (_multilabel_cases[1], _multilabel_cases[2], _multilabel_cases[4], _multilabel_cases[5])
 )
-class TestROC(MetricTester):
+class TestMultilabelROC(MetricTester):
+    @pytest.mark.parametrize("ignore_index", [None, -1, 0])
     @pytest.mark.parametrize("ddp", [True, False])
-    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
-    def test_roc(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step):
+    def test_multilabel_roc(self, input, ddp, ignore_index):
+        preds, target = input
+        if ignore_index is not None:
+            target = inject_ignore_index(target, ignore_index)
         self.run_class_metric_test(
             ddp=ddp,
             preds=preds,
             target=target,
-            metric_class=ROC,
-            sk_metric=partial(sk_metric, num_classes=num_classes),
-            dist_sync_on_step=dist_sync_on_step,
-            metric_args={"num_classes": num_classes},
+            metric_class=MultilabelROC,
+            sk_metric=partial(_sk_roc_multilabel, ignore_index=ignore_index),
+            metric_args={
+                "thresholds": None,
+                "num_labels": NUM_CLASSES,
+                "ignore_index": ignore_index,
+            },
         )

-    def test_roc_functional(self, preds, target, sk_metric, num_classes):
+    @pytest.mark.parametrize("ignore_index", [None, -1, 0])
+    def test_multilabel_roc_functional(self, input, ignore_index):
+        preds, target = input
+        if ignore_index is not None:
+            target = inject_ignore_index(target, ignore_index)
         self.run_functional_metric_test(
-            preds,
-            target,
-            metric_functional=roc,
-            sk_metric=partial(sk_metric, num_classes=num_classes),
-            metric_args={"num_classes": num_classes},
+            preds=preds,
+            target=target,
+            metric_functional=multilabel_roc,
+            sk_metric=partial(_sk_roc_multilabel, ignore_index=ignore_index),
+            metric_args={
+                "thresholds": None,
+                "num_labels": NUM_CLASSES,
+                "ignore_index": ignore_index,
+            },
         )

-    def test_roc_differentiability(self, preds, target, sk_metric, num_classes):
+    def test_multilabel_roc_differentiability(self, input):
+        preds, target = input
         self.run_differentiability_test(
-            preds,
-            target,
-            metric_module=ROC,
-            metric_functional=roc,
-            metric_args={"num_classes": num_classes},
+            preds=preds,
+            target=target,
+            metric_module=MultilabelROC,
+            metric_functional=multilabel_roc,
+            metric_args={"thresholds": None, "num_labels": NUM_CLASSES},
         )

+    @pytest.mark.parametrize("dtype", [torch.half, torch.double])
+    def test_multilabel_roc_dtype_cpu(self, input, dtype):
+        preds, target = input
+        if dtype == torch.half 
and not _TORCH_GREATER_EQUAL_1_6:
+            pytest.xfail(reason="half support of core ops not supported before pytorch v1.6")
+        if dtype == torch.half and not ((0 < preds) & (preds < 1)).all():
+            pytest.xfail(reason="half support for torch.softmax on cpu not implemented")
+        self.run_precision_test_cpu(
+            preds=preds,
+            target=target,
+            metric_module=MultilabelROC,
+            metric_functional=multilabel_roc,
+            metric_args={"thresholds": None, "num_labels": NUM_CLASSES},
+            dtype=dtype,
+        )

-@pytest.mark.parametrize(
-    ["pred", "target", "expected_tpr", "expected_fpr"],
-    [
-        ([0, 1], [0, 1], [0, 1, 1], [0, 0, 1]),
-        ([1, 0], [0, 1], [0, 0, 1], [0, 1, 1]),
-        ([1, 1], [1, 0], [0, 1], [0, 1]),
-        ([1, 0], [1, 0], [0, 1, 1], [0, 0, 1]),
-        ([0.5, 0.5], [0, 1], [0, 1], [0, 1]),
-    ],
-)
-def test_roc_curve(pred, target, expected_tpr, expected_fpr):
-    fpr, tpr, thresh = roc(tensor(pred), tensor(target))
-
-    assert fpr.shape == tpr.shape
-    assert fpr.size(0) == thresh.size(0)
-    assert torch.allclose(fpr, tensor(expected_fpr).to(fpr))
-    assert torch.allclose(tpr, tensor(expected_tpr).to(tpr))
-
-
-def test_warnings_on_missing_class():
-    """Test that a warning is given if either the positive or negative class is missing."""
-    metric = ROC()
-    # no positive samples
-    warning = (
-        "No positive samples in targets, true positive value should be meaningless."
-        " Returning zero tensor in true positive score"
-    )
-    with pytest.warns(UserWarning, match=warning):
-        _, tpr, _ = metric(torch.randn(10).sigmoid(), torch.zeros(10))
-    assert all(tpr == 0)
-
-    warning = (
-        "No negative samples in targets, false positive value should be meaningless."
-        " Returning zero tensor in false positive score"
-    )
-    with pytest.warns(UserWarning, match=warning):
-        fpr, _, _ = metric(torch.randn(10).sigmoid(), torch.ones(10))
-    assert all(fpr == 0)
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
+    @pytest.mark.parametrize("dtype", [torch.half, torch.double])
+    def test_multilabel_roc_dtype_gpu(self, input, dtype):
+        preds, target = input
+        self.run_precision_test_gpu(
+            preds=preds,
+            target=target,
+            metric_module=MultilabelROC,
+            metric_functional=multilabel_roc,
+            metric_args={"thresholds": None, "num_labels": NUM_CLASSES},
+            dtype=dtype,
+        )
+
+    @pytest.mark.parametrize("threshold_fn", [lambda x: x, lambda x: x.numpy().tolist()], ids=["as tensor", "as list"])
+    def test_multilabel_roc_threshold_arg(self, input, threshold_fn):
+        preds, target = input
+        for pred, true in zip(preds, target):
+            p1, r1, t1 = multilabel_roc(pred, true, num_labels=NUM_CLASSES, thresholds=None)
+            for i, t in enumerate(t1):
+                p2, r2, t2 = multilabel_roc(pred, true, num_labels=NUM_CLASSES, thresholds=threshold_fn(t.flip(0)))
+
+                assert torch.allclose(p1[i], p2[i])
+                assert torch.allclose(r1[i], r2[i])
+                assert torch.allclose(t1[i], t2)
+
+
+# -------------------------- Old stuff --------------------------
+
+# def _sk_roc_curve(y_true, probas_pred, num_classes: int = 1, multilabel: bool = False):
+#     """Adjusted comparison function that can also handle multiclass."""
+#     if num_classes == 1:
+#         return sk_roc_curve(y_true, probas_pred, drop_intermediate=False)
+
+#     fpr, tpr, thresholds = [], [], []
+#     for i in range(num_classes):
+#         if multilabel:
+#             y_true_temp = y_true[:, i]
+#         else:
+#             y_true_temp = np.zeros_like(y_true)
+#             y_true_temp[y_true == i] = 1
+
+#         res = sk_roc_curve(y_true_temp, probas_pred[:, i], drop_intermediate=False)
+#         fpr.append(res[0])
+#         tpr.append(res[1])
+#         thresholds.append(res[2])
+#     return fpr, tpr, thresholds
+ + +# def _sk_roc_binary_prob(preds, target, num_classes=1): +# sk_preds = preds.view(-1).numpy() +# sk_target = target.view(-1).numpy() + +# return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) + + +# def _sk_roc_multiclass_prob(preds, target, num_classes=1): +# sk_preds = preds.reshape(-1, num_classes).numpy() +# sk_target = target.view(-1).numpy() + +# return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) + + +# def _sk_roc_multidim_multiclass_prob(preds, target, num_classes=1): +# sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() +# sk_target = target.view(-1).numpy() +# return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) + + +# def _sk_roc_multilabel_prob(preds, target, num_classes=1): +# sk_preds = preds.numpy() +# sk_target = target.numpy() +# return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes, multilabel=True) + + +# def _sk_roc_multilabel_multidim_prob(preds, target, num_classes=1): +# sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() +# sk_target = target.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() +# return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes, multilabel=True) + + +# @pytest.mark.parametrize( +# "preds, target, sk_metric, num_classes", +# [ +# (_input_binary_prob.preds, _input_binary_prob.target, _sk_roc_binary_prob, 1), +# (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_roc_multiclass_prob, NUM_CLASSES), +# (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_roc_multidim_multiclass_prob, NUM_CLASSES), +# (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_roc_multilabel_prob, NUM_CLASSES), +# (_input_mlmd_prob.preds, _input_mlmd_prob.target, _sk_roc_multilabel_multidim_prob, NUM_CLASSES), +# ], +# ) +# class TestROC(MetricTester): +# @pytest.mark.parametrize("ddp", [True, False]) +# @pytest.mark.parametrize("dist_sync_on_step", [True, False]) +# def test_roc(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step): +# self.run_class_metric_test( +# ddp=ddp, +# preds=preds, +# target=target, +# metric_class=ROC, +# sk_metric=partial(sk_metric, num_classes=num_classes), +# dist_sync_on_step=dist_sync_on_step, +# metric_args={"num_classes": num_classes}, +# ) + +# def test_roc_functional(self, preds, target, sk_metric, num_classes): +# self.run_functional_metric_test( +# preds, +# target, +# metric_functional=roc, +# sk_metric=partial(sk_metric, num_classes=num_classes), +# metric_args={"num_classes": num_classes}, +# ) + +# def test_roc_differentiability(self, preds, target, sk_metric, num_classes): +# self.run_differentiability_test( +# preds, +# target, +# metric_module=ROC, +# metric_functional=roc, +# metric_args={"num_classes": num_classes}, +# ) + + +# @pytest.mark.parametrize( +# ["pred", "target", "expected_tpr", "expected_fpr"], +# [ +# ([0, 1], [0, 1], [0, 1, 1], [0, 0, 1]), +# ([1, 0], [0, 1], [0, 0, 1], [0, 1, 1]), +# ([1, 1], [1, 0], [0, 1], [0, 1]), +# ([1, 0], [1, 0], [0, 1, 1], [0, 0, 1]), +# ([0.5, 0.5], [0, 1], [0, 1], [0, 1]), +# ], +# ) +# def test_roc_curve(pred, target, expected_tpr, expected_fpr): +# fpr, tpr, thresh = roc(tensor(pred), tensor(target)) + +# assert fpr.shape == tpr.shape +# assert fpr.size(0) == thresh.size(0) +# assert torch.allclose(fpr, tensor(expected_fpr).to(fpr)) +# assert torch.allclose(tpr, tensor(expected_tpr).to(tpr)) + + +# def 
test_warnings_on_missing_class():
+#     """Test that a warning is given if either the positive or negative class is missing."""
+#     metric = ROC()
+#     # no positive samples
+#     warning = (
+#         "No positive samples in targets, true positive value should be meaningless."
+#         " Returning zero tensor in true positive score"
+#     )
+#     with pytest.warns(UserWarning, match=warning):
+#         _, tpr, _ = metric(torch.randn(10).sigmoid(), torch.zeros(10))
+#     assert all(tpr == 0)
+
+#     warning = (
+#         "No negative samples in targets, false positive value should be meaningless."
+#         " Returning zero tensor in false positive score"
+#     )
+#     with pytest.warns(UserWarning, match=warning):
+#         fpr, _, _ = metric(torch.randn(10).sigmoid(), torch.ones(10))
+#     assert all(fpr == 0)
diff --git a/tests/unittests/helpers/testers.py b/tests/unittests/helpers/testers.py
index 1517064f5ab..d7fa26b571d 100644
--- a/tests/unittests/helpers/testers.py
+++ b/tests/unittests/helpers/testers.py
@@ -631,11 +631,19 @@ def inject_ignore_index(x: Tensor, ignore_index: int) -> Tensor:
     """Utility function for injecting the ignore index value into a tensor randomly."""
     if any(x.flatten() == ignore_index):  # ignore index is a class label
         return x
+    classes = torch.unique(x)
     idx = torch.randperm(x.numel())
     x = deepcopy(x)
     # randomly set either element {9, 10} to the ignore index value
     skip = torch.randint(9, 11, (1,)).item()
     x.view(-1)[idx[::skip]] = ignore_index
+    # if we accidentally removed a class completely from a batch, reintroduce it
+    for batch in x:
+        new_classes = torch.unique(batch)
+        class_not_in = [c not in new_classes for c in classes]
+        if any(class_not_in):
+            missing_class = int(np.where(class_not_in)[0][0])
+            batch[torch.where(batch == ignore_index)[0][0]] = missing_class
     return x