diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8a54695e838..41481bfc069 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -17,7 +17,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Added support for plotting of metrics through `.plot()` method (
     [#1328](https://github.com/Lightning-AI/metrics/pull/1328),
     [#1481](https://github.com/Lightning-AI/metrics/pull/1481),
-    [#1480](https://github.com/Lightning-AI/metrics/pull/1480)
+    [#1480](https://github.com/Lightning-AI/metrics/pull/1480),
+    [#1490](https://github.com/Lightning-AI/metrics/pull/1490),
 )
diff --git a/src/torchmetrics/classification/auroc.py b/src/torchmetrics/classification/auroc.py
index 0f09abbabd4..c8bb0941186 100644
--- a/src/torchmetrics/classification/auroc.py
+++ b/src/torchmetrics/classification/auroc.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, List, Optional, Union
+from typing import Any, List, Optional, Sequence, Union
 
 from torch import Tensor
 from typing_extensions import Literal
@@ -32,6 +32,11 @@
 from torchmetrics.metric import Metric
 from torchmetrics.utilities.data import dim_zero_cat
 from torchmetrics.utilities.enums import ClassificationTask
+from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
+from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val
+
+if not _MATPLOTLIB_AVAILABLE:
+    __doctest_skip__ = ["BinaryAUROC.plot", "MulticlassAUROC.plot", "MultilabelAUROC.plot"]
 
 
 class BinaryAUROC(BinaryPrecisionRecallCurve):
@@ -93,6 +98,7 @@ class BinaryAUROC(BinaryPrecisionRecallCurve):
     is_differentiable: bool = False
     higher_is_better: Optional[bool] = None
     full_state_update: bool = False
+    plot_options: dict = {"lower_bound": 0.0, "upper_bound": 1.0}
 
     def __init__(
         self,
@@ -112,6 +118,42 @@ def compute(self) -> Tensor:
         state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
         return _binary_auroc_compute(state, self.thresholds, self.max_fpr)
 
+    def plot(
+        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
+    ) -> _PLOT_OUT_TYPE:
+        """Plot a single or multiple values from the metric.
+
+        Args:
+            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
+                If no value is provided, will automatically call `metric.compute` and plot that result.
+            ax: A matplotlib axis object. If provided, the plot will be added to that axis.
+
+        Returns:
+            Figure and Axes object
+
+        Raises:
+            ModuleNotFoundError:
+                If `matplotlib` is not installed
+
+        .. plot::
+            :scale: 75
+
+            >>> from torch import randn, randint
+            >>> import torch.nn.functional as F
+            >>> # Example plotting a single value
+            >>> from torchmetrics.classification import BinaryAUROC
+            >>> preds = F.softmax(randn(20, 2), dim=1)
+            >>> target = randint(2, (20,))
+            >>> metric = BinaryAUROC()
+            >>> metric.update(preds[:, 1], target)
+            >>> fig_, ax_ = metric.plot()
+        """
+        val = val if val is not None else self.compute()
+        fig, ax = plot_single_or_multi_val(
+            val, ax=ax, higher_is_better=self.higher_is_better, **self.plot_options, name=self.__class__.__name__
+        )
+        return fig, ax
+
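The new `plot` also accepts a sequence of results, which is useful for tracking the metric over training. A minimal sketch, not part of the diff (assumes `torchmetrics` with `matplotlib` installed), collecting one value per step via `forward` and plotting them all at once:

    import torch
    from torchmetrics.classification import BinaryAUROC

    metric = BinaryAUROC()
    values = []
    for _ in range(5):  # e.g. one entry per validation epoch
        preds = torch.rand(20)  # probabilities for the positive class
        target = torch.randint(2, (20,))
        values.append(metric(preds, target))  # forward() updates state and returns the batch value
    fig, ax = metric.plot(values)  # one line over the five steps, y-axis bounded to [0, 1]
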
 class MulticlassAUROC(MulticlassPrecisionRecallCurve):
     r"""Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_) for multiclass tasks. The AUROC
@@ -189,6 +231,7 @@ class MulticlassAUROC(MulticlassPrecisionRecallCurve):
     is_differentiable: bool = False
     higher_is_better: Optional[bool] = None
     full_state_update: bool = False
+    plot_options = {"lower_bound": 0.0, "upper_bound": 1.0, "legend_name": "Class"}
 
     def __init__(
         self,
@@ -212,6 +255,39 @@ def compute(self) -> Tensor:
         state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
         return _multiclass_auroc_compute(state, self.num_classes, self.average, self.thresholds)
 
+    def plot(
+        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
+    ) -> _PLOT_OUT_TYPE:
+        """Plot a single or multiple values from the metric.
+
+        Args:
+            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
+                If no value is provided, will automatically call `metric.compute` and plot that result.
+            ax: A matplotlib axis object. If provided, the plot will be added to that axis.
+
+        Returns:
+            Figure and Axes object
+
+        Raises:
+            ModuleNotFoundError:
+                If `matplotlib` is not installed
+
+        .. plot::
+            :scale: 75
+
+            >>> from torch import randn, randint
+            >>> # Example plotting a combined value across all classes
+            >>> from torchmetrics.classification import MulticlassAUROC
+            >>> metric = MulticlassAUROC(num_classes=3, average="macro")
+            >>> metric.update(randn(20, 3), randint(3, (20,)))
+            >>> fig_, ax_ = metric.plot()
+        """
+        val = val if val is not None else self.compute()
+        fig, ax = plot_single_or_multi_val(
+            val, ax=ax, higher_is_better=self.higher_is_better, **self.plot_options, name=self.__class__.__name__
+        )
+        return fig, ax
+
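The `legend_name` entry in `plot_options` targets the per-class case. A hedged sketch, not part of the diff: with `average=None`, `compute()` returns one AUROC per class, and `plot_single_or_multi_val` is expected to build the legend labels from `legend_name` ("Class 0", "Class 1", ...):

    import torch
    from torchmetrics.classification import MulticlassAUROC

    metric = MulticlassAUROC(num_classes=3, average=None)  # one score per class
    metric.update(torch.randn(100, 3).softmax(dim=-1), torch.randint(3, (100,)))
    fig, ax = metric.plot()  # three values, legend labels built from legend_name
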
 class MultilabelAUROC(MultilabelPrecisionRecallCurve):
     r"""Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_) for multilabel tasks. The AUROC
@@ -291,6 +367,7 @@ class MultilabelAUROC(MultilabelPrecisionRecallCurve):
     is_differentiable: bool = False
     higher_is_better: Optional[bool] = None
     full_state_update: bool = False
+    plot_options = {"lower_bound": 0.0, "upper_bound": 1.0, "legend_name": "Class"}
 
     def __init__(
         self,
@@ -314,6 +391,46 @@ def compute(self) -> Tensor:
         state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
         return _multilabel_auroc_compute(state, self.num_labels, self.average, self.thresholds, self.ignore_index)
 
+    def plot(
+        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
+    ) -> _PLOT_OUT_TYPE:
+        """Plot a single or multiple values from the metric.
+
+        Args:
+            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
+                If no value is provided, will automatically call `metric.compute` and plot that result.
+            ax: A matplotlib axis object. If provided, the plot will be added to that axis.
+
+        Returns:
+            Figure and Axes object
+
+        Raises:
+            ModuleNotFoundError:
+                If `matplotlib` is not installed
+
+        .. plot::
+            :scale: 75
+
+            >>> from torch import tensor
+            >>> from torchmetrics.classification import MultilabelAUROC
+            >>> preds = tensor([[0.75, 0.05, 0.35],
+            ...                 [0.45, 0.75, 0.05],
+            ...                 [0.05, 0.55, 0.75],
+            ...                 [0.05, 0.65, 0.05]])
+            >>> target = tensor([[1, 0, 1],
+            ...                  [0, 0, 0],
+            ...                  [0, 1, 1],
+            ...                  [1, 1, 1]])
+            >>> metric = MultilabelAUROC(num_labels=3, average="macro", thresholds=None)
+            >>> metric.update(preds, target)
+            >>> fig_, ax_ = metric.plot()
+        """
+        val = val if val is not None else self.compute()
+        fig, ax = plot_single_or_multi_val(
+            val, ax=ax, higher_is_better=self.higher_is_better, **self.plot_options, name=self.__class__.__name__
+        )
+        return fig, ax
+
 
 class AUROC:
     r"""Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_). The AUROC score summarizes the
diff --git a/src/torchmetrics/classification/roc.py b/src/torchmetrics/classification/roc.py
index 12c5ae03888..8340a0f1c9d 100644
--- a/src/torchmetrics/classification/roc.py
+++ b/src/torchmetrics/classification/roc.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, List, Optional, Sequence, Tuple, Union
 
 from torch import Tensor
 from typing_extensions import Literal
@@ -21,6 +21,14 @@
     MulticlassPrecisionRecallCurve,
     MultilabelPrecisionRecallCurve,
 )
+from torchmetrics.functional.classification.auroc import (
+    _binary_auroc_arg_validation,
+    _binary_auroc_compute,
+    _multiclass_auroc_arg_validation,
+    _multiclass_auroc_compute,
+    _multilabel_auroc_arg_validation,
+    _multilabel_auroc_compute,
+)
 from torchmetrics.functional.classification.roc import (
     _binary_roc_compute,
     _multiclass_roc_compute,
@@ -29,6 +37,11 @@
 from torchmetrics.metric import Metric
 from torchmetrics.utilities.data import dim_zero_cat
 from torchmetrics.utilities.enums import ClassificationTask
+from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
+from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_binary_roc_curve
+
+if not _MATPLOTLIB_AVAILABLE:
+    __doctest_skip__ = ["BinaryROC.plot"]
 
 
 class BinaryROC(BinaryPrecisionRecallCurve):
@@ -107,6 +120,53 @@ def compute(self) -> Tuple[Tensor, Tensor, Tensor]:
         state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
         return _binary_roc_compute(state, self.thresholds)
 
+    def __compute_auroc(self) -> Tensor:
+        """Compute the area under the ROC curve (as `BinaryAUROC` would) to show next to the ROC plot."""
+        state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
+        return _binary_auroc_compute(state, self.thresholds, max_fpr=None)
+
+    def plot(
+        self,
+        fpr: Optional[Union[Tensor, Sequence[Tensor]]] = None,
+        tpr: Optional[Union[Tensor, Sequence[Tensor]]] = None,
+        ax: Optional[_AX_TYPE] = None,
+        name: Optional[str] = None,
+    ) -> _PLOT_OUT_TYPE:
+        """Plot the ROC curve.
+
+        Args:
+            fpr: False Positive Rate provided by calling `metric.forward` or `metric.compute`
+            tpr: True Positive Rate provided by calling `metric.forward` or `metric.compute`
+                If no values are provided, will automatically call `metric.compute` and plot that result.
+            ax: A matplotlib axis object. If provided, the plot will be added to that axis.
+            name: Custom name to describe the classifier
+
+        Returns:
+            Figure and Axes object
+
+        Raises:
+            ModuleNotFoundError:
+                If `matplotlib` is not installed
+
+        .. plot::
+            :scale: 75
+
+            >>> from torch import randn, randint
+            >>> import torch.nn.functional as F
+            >>> from torchmetrics.classification import BinaryROC
+            >>> preds = F.softmax(randn(20, 2), dim=1)
+            >>> target = randint(2, (20,))
+            >>> metric = BinaryROC()
+            >>> metric.update(preds[:, 1], target)
+            >>> fig_, ax_ = metric.plot()
+        """
+        if fpr is None or tpr is None:
+            fpr, tpr, _ = self.compute()
+        roc_auc = self.__compute_auroc()
+        name = self.__class__.__name__ if name is None else name
+        fig, ax = plot_binary_roc_curve(tpr, fpr, ax=ax, roc_auc=roc_auc, name=name)
+        return fig, ax
+
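A short usage sketch for the new keyword arguments, not part of the diff (the classifier name is a stand-in): the rates can be precomputed and reused, and `name` replaces the default class name in the line label:

    import torch
    from torchmetrics.classification import BinaryROC

    metric = BinaryROC(thresholds=None)
    metric.update(torch.rand(50), torch.randint(2, (50,)))
    fpr, tpr, _ = metric.compute()  # compute() returns (fpr, tpr, thresholds)
    fig, ax = metric.plot(fpr=fpr, tpr=tpr, name="my-classifier")  # label: "my-classifier (AUC = ...)"
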
 class MulticlassROC(MulticlassPrecisionRecallCurve):
     r"""Compute the Receiver Operating Characteristic (ROC) for multiclass tasks. The curve consists of multiple pairs of
diff --git a/src/torchmetrics/utilities/plot.py b/src/torchmetrics/utilities/plot.py
index 70fa000c99f..dc10667f32a 100644
--- a/src/torchmetrics/utilities/plot.py
+++ b/src/torchmetrics/utilities/plot.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 from itertools import product
 from math import ceil, floor, sqrt
-from typing import List, Optional, Sequence, Tuple, Union
+from typing import Any, List, Optional, Sequence, Tuple, Union
 
 import numpy as np
 import torch
@@ -213,3 +213,54 @@ def plot_confusion_matrix(
             ax.text(jj, ii, str(val.item()), ha="center", va="center", fontsize=15)
 
     return fig, axs
+
+
+def plot_binary_roc_curve(
+    tpr: Tensor,
+    fpr: Tensor,
+    ax: Optional[_AX_TYPE] = None,  # type: ignore[valid-type]
+    roc_auc: Optional[Union[float, Tensor]] = None,
+    name: Optional[str] = None,
+    **kwargs: Any,
+) -> _PLOT_OUT_TYPE:
+    """Plot the ROC curve.
+
+    Inspired by: https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/metrics/_plot/roc_curve.py.
+
+    Args:
+        tpr: Tensor containing the values for True Positive Rate
+        fpr: Tensor containing the values for False Positive Rate
+        ax: Axis from a figure
+        roc_auc: AUROC score (computed separately)
+        name: Custom name to describe the classifier
+        kwargs: additional keyword arguments for line drawing
+
+    Returns:
+        A tuple consisting of the figure and respective ax objects (or array of ax objects) of the generated figure
+
+    Raises:
+        ModuleNotFoundError:
+            If `matplotlib` is not installed
+    """
+    _error_on_missing_matplotlib()
+    fig, ax = plt.subplots() if ax is None else (None, ax)
+
+    if isinstance(roc_auc, Tensor):
+        assert roc_auc.numel() == 1, "roc_auc Tensor must consist of only one element"
+        roc_auc = roc_auc.item()
+
+    line_kwargs = {}
+    if roc_auc is not None and name is not None:
+        line_kwargs["label"] = f"{name} (AUC = {roc_auc:0.2f})"
+    elif roc_auc is not None:
+        line_kwargs["label"] = f"AUC = {roc_auc:0.2f}"
+    elif name is not None:
+        line_kwargs["label"] = name
+
+    line_kwargs.update(**kwargs)
+
+    ax.plot(fpr.detach().cpu(), tpr.detach().cpu(), **line_kwargs)
+    ax.set_xlabel("False Positive Rate")
+    ax.set_ylabel("True Positive Rate")
+
+    return fig, ax
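The utility can also be driven by the functional API; the extra kwargs are forwarded to `Axes.plot`, so standard matplotlib line options apply. A sketch, not part of the diff (`linestyle` is an arbitrary choice); note the argument order, `tpr` before `fpr`:

    import torch
    from torchmetrics.functional.classification.auroc import binary_auroc
    from torchmetrics.functional.classification.roc import binary_roc
    from torchmetrics.utilities.plot import plot_binary_roc_curve

    preds, target = torch.rand(100), torch.randint(2, (100,))
    fpr, tpr, _ = binary_roc(preds, target)
    auc = binary_auroc(preds, target)  # scalar tensor, rendered as "AUC = 0.xx" in the line label
    fig, ax = plot_binary_roc_curve(tpr, fpr, roc_auc=auc, name="demo", linestyle="--")
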
diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py
index f9d251c49ea..3c6969fb8ad 100644
--- a/tests/unittests/utilities/test_plot.py
+++ b/tests/unittests/utilities/test_plot.py
@@ -34,19 +34,27 @@
 from torchmetrics.functional.audio.pesq import perceptual_evaluation_speech_quality
 from torchmetrics.functional.audio.pit import permutation_invariant_training
 from torchmetrics.functional.classification.accuracy import binary_accuracy, multiclass_accuracy
+from torchmetrics.functional.classification.auroc import binary_auroc, multiclass_auroc
 from torchmetrics.functional.classification.confusion_matrix import (
     binary_confusion_matrix,
     multiclass_confusion_matrix,
     multilabel_confusion_matrix,
 )
+from torchmetrics.functional.classification.roc import binary_roc
 from torchmetrics.functional.image.d_lambda import spectral_distortion_index
 from torchmetrics.functional.image.ergas import error_relative_global_dimensionless_synthesis
-from torchmetrics.utilities.plot import plot_confusion_matrix, plot_single_or_multi_val
+from torchmetrics.utilities.plot import plot_binary_roc_curve, plot_confusion_matrix, plot_single_or_multi_val
 
 
 @pytest.mark.parametrize(
     ("metric", "preds", "target"),
     [
-        pytest.param(binary_accuracy, lambda: torch.rand(100), lambda: torch.randint(2, (100,)), id="binary"),
+        # Accuracy
+        pytest.param(
+            binary_accuracy,
+            lambda: torch.rand(100),
+            lambda: torch.randint(2, (100,)),
+            id="binary",
+        ),
         pytest.param(
             partial(multiclass_accuracy, num_classes=3),
@@ -60,6 +68,25 @@
             lambda: torch.randint(3, (100,)),
             id="multiclass and average=None",
         ),
+        # AUROC
+        pytest.param(
+            binary_auroc,
+            lambda: torch.nn.functional.softmax(torch.randn(100, 2), dim=1)[:, 1],
+            lambda: torch.randint(2, (100,)),
+            id="binary",
+        ),
+        pytest.param(
+            partial(multiclass_auroc, num_classes=3),
+            lambda: torch.randn(100, 3),
+            lambda: torch.randint(3, (100,)),
+            id="multiclass",
+        ),
+        pytest.param(
+            partial(multiclass_auroc, num_classes=3, average=None),
+            lambda: torch.randn(100, 3),
+            lambda: torch.randint(3, (100,)),
+            id="multiclass and average=None",
+        ),
         pytest.param(
             partial(spectral_distortion_index),
             lambda: torch.rand([16, 3, 16, 16]),
@@ -235,3 +262,21 @@ def test_confusion_matrix_plotter_with_labels(metric, preds, target, labels):
     cond1 = isinstance(axs, matplotlib.axes.Axes)
     cond2 = isinstance(axs, np.ndarray) and all(isinstance(a, matplotlib.axes.Axes) for a in axs)
     assert cond1 or cond2
+
+
+@pytest.mark.parametrize(
+    ("metric", "preds", "target"),
+    [
+        pytest.param(
+            binary_roc,
+            lambda: torch.nn.functional.softmax(torch.randn(100, 2), dim=1)[:, 1],
+            lambda: torch.randint(2, (100,)),
+            id="binary",
+        )
+    ],
+)
+def test_binary_roc_curve_plotter(metric, preds, target):
+    fpr, tpr, thresholds = metric(preds(), target())
+    fig, ax = plot_binary_roc_curve(tpr, fpr)
+    assert isinstance(fig, plt.Figure)
+    assert isinstance(ax, matplotlib.axes.Axes)
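To run just the new ROC-plot test locally, something along these lines should work (assuming a dev checkout with the test requirements and `matplotlib` installed):

    python -m pytest tests/unittests/utilities/test_plot.py -k binary_roc_curve -v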