Add plot functionality to AUROC and ROC metrics #1490

Merged (30 commits) on Feb 24, 2023
Changes from all commits (30 commits):
d72b078  Changes from executing make test (alexkrz, Feb 5, 2023)
8686126  Added plot option for MulticlassAUROC (alexkrz, Feb 6, 2023)
8b094c6  Added plot option for BinaryAUROC and BinaryROC (alexkrz, Feb 6, 2023)
3b6bdcc  Fixed wrong paranthesis for building docs (alexkrz, Feb 6, 2023)
2627972  Removed comment from ROC docstring (alexkrz, Feb 7, 2023)
cd81f63  Merge branch 'master' into feature/roc-curve (alexkrz, Feb 7, 2023)
74d2128  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Feb 7, 2023)
ca0c688  Apply suggestions from code review (Borda, Feb 8, 2023)
11ea836  Update requirements/image.txt (alexkrz, Feb 12, 2023)
eaba765  Turned compute_auroc() in BinaryROC into private method (alexkrz, Feb 12, 2023)
8a9fe19  Added plot function to MultilabelAUROC (alexkrz, Feb 12, 2023)
8c2356a  Removed notebooks/ directory from gitignore (alexkrz, Feb 12, 2023)
15a5d9d  Added tests for AUROC (alexkrz, Feb 12, 2023)
ea6d2ce  Added test for binary roc curve (alexkrz, Feb 12, 2023)
bd2bf1a  Merge branch 'master' into feature/roc-curve (Borda, Feb 20, 2023)
2deb649  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Feb 20, 2023)
807dcdf  __doctest_skip__ (Borda, Feb 20, 2023)
ceb13e4  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Feb 20, 2023)
61b7354  precommit (Borda, Feb 20, 2023)
f985712  Merge branch 'feature/roc-curve' of https://github.com/alexkrz/torchm… (Borda, Feb 20, 2023)
ca0ffd4  Merge branch 'master' into feature/roc-curve (Borda, Feb 21, 2023)
54cb981  Merge branch 'master' into feature/roc-curve (Borda, Feb 22, 2023)
0170110  Merge branch 'master' into feature/roc-curve (Borda, Feb 22, 2023)
aa80cd1  Merge branch 'master' into feature/roc-curve (Borda, Feb 22, 2023)
60d14de  Merge branch 'master' into feature/roc-curve (Borda, Feb 22, 2023)
7ea2e80  Any (Borda, Feb 22, 2023)
33041d6  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Feb 22, 2023)
bfd01a7  Merge branch 'master' into feature/roc-curve (SkafteNicki, Feb 24, 2023)
df76f1f  changelog (SkafteNicki, Feb 24, 2023)
fc4b1c9  Merge branch 'master' into feature/roc-curve (mergify[bot], Feb 24, 2023)
CHANGELOG.md (3 changes: 2 additions & 1 deletion)

@@ -17,7 +17,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added support for plotting of metrics through `.plot()` method (
[#1328](https://github.com/Lightning-AI/metrics/pull/1328),
[#1481](https://github.com/Lightning-AI/metrics/pull/1481),
-    [#1480](https://github.com/Lightning-AI/metrics/pull/1480)
+    [#1480](https://github.com/Lightning-AI/metrics/pull/1480),
+    [#1490](https://github.com/Lightning-AI/metrics/pull/1490),
)


src/torchmetrics/classification/auroc.py (119 changes: 118 additions & 1 deletion)

@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, List, Optional, Union
+from typing import Any, List, Optional, Sequence, Union

from torch import Tensor
from typing_extensions import Literal
@@ -32,6 +32,11 @@
from torchmetrics.metric import Metric
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.enums import ClassificationTask
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["BinaryAUROC.plot", "MulticlassAUROC.plot", "MultilabelAUROC.plot"]


class BinaryAUROC(BinaryPrecisionRecallCurve):
@@ -93,6 +98,7 @@ class BinaryAUROC(BinaryPrecisionRecallCurve):
is_differentiable: bool = False
higher_is_better: Optional[bool] = None
full_state_update: bool = False
plot_options: dict = {"lower_bound": 0.0, "upper_bound": 1.0}

def __init__(
self,
@@ -112,6 +118,42 @@ def compute(self) -> Tensor:
state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
return _binary_auroc_compute(state, self.thresholds, self.max_fpr)

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.

Args:
val: Either a single result from calling `metric.forward` or `metric.compute`, or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: A matplotlib axis object. If provided, the plot will be added to that axis.

Returns:
Figure and Axes object

Raises:
ModuleNotFoundError:
If `matplotlib` is not installed

.. plot::
:scale: 75

>>> from torch import randn, randint
>>> import torch.nn.functional as F
>>> # Example plotting a combined value across all classes
>>> from torchmetrics.classification import BinaryAUROC
>>> preds = F.softmax(randn(20, 2), dim=1)
>>> target = randint(2, (20,))
>>> metric = BinaryAUROC()
>>> metric.update(preds[:, 1], target)
>>> fig_, ax_ = metric.plot()
"""
val = val or self.compute()
fig, ax = plot_single_or_multi_val(
val, ax=ax, higher_is_better=self.higher_is_better, **self.plot_options, name=self.__class__.__name__
)
return fig, ax
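
A usage sketch, not part of the diff: `val` also accepts a list of results, so AUROC can be tracked across validation steps. This assumes only the signature above plus a matplotlib install; the data is made up.

```python
import torch
from torchmetrics.classification import BinaryAUROC

metric = BinaryAUROC()
values = []
for _ in range(5):  # e.g. one entry per validation epoch (toy data)
    preds = torch.rand(20)
    target = torch.randint(2, (20,))
    values.append(metric(preds, target))  # forward returns the batch AUROC
    metric.reset()
fig, ax = metric.plot(values)  # a list of scalar results is plotted as a sequence
```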


class MulticlassAUROC(MulticlassPrecisionRecallCurve):
r"""Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_) for multiclass tasks. The AUROC
@@ -189,6 +231,7 @@ class MulticlassAUROC(MulticlassPrecisionRecallCurve):
is_differentiable: bool = False
higher_is_better: Optional[bool] = None
full_state_update: bool = False
plot_options = {"lower_bound": 0.0, "upper_bound": 1.0, "legend_name": "Class"}

def __init__(
self,
@@ -212,6 +255,39 @@ def compute(self) -> Tensor:
state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
return _multiclass_auroc_compute(state, self.num_classes, self.average, self.thresholds)

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.

Args:
val: Either a single result from calling `metric.forward` or `metric.compute`, or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: A matplotlib axis object. If provided, the plot will be added to that axis.

Returns:
Figure and Axes object

Raises:
ModuleNotFoundError:
If `matplotlib` is not installed

.. plot::
:scale: 75

>>> from torch import randn, randint
>>> # Example plotting a combined value across all classes
>>> from torchmetrics.classification import MulticlassAUROC
>>> metric = MulticlassAUROC(num_classes=3, average="macro")
>>> metric.update(randn(20, 3), randint(3, (20,)))
>>> fig_, ax_ = metric.plot()
"""
val = val or self.compute()
fig, ax = plot_single_or_multi_val(
val, ax=ax, higher_is_better=self.higher_is_better, **self.plot_options, name=self.__class__.__name__
)
return fig, ax
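
A sketch of the per-class case, hedged on `average=None` returning one AUROC per class (as the functional API does); the `plot_options` above then supply the [0, 1] bounds and the "Class" legend name. Toy data:

```python
import torch
from torchmetrics.classification import MulticlassAUROC

metric = MulticlassAUROC(num_classes=3, average=None)  # one value per class
metric.update(torch.randn(20, 3).softmax(dim=-1), torch.randint(3, (20,)))
fig, ax = metric.plot()  # per-class values share a single bounded axis
```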


class MultilabelAUROC(MultilabelPrecisionRecallCurve):
r"""Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_) for multilabel tasks. The AUROC
@@ -291,6 +367,7 @@ class MultilabelAUROC(MultilabelPrecisionRecallCurve):
is_differentiable: bool = False
higher_is_better: Optional[bool] = None
full_state_update: bool = False
plot_options = {"lower_bound": 0.0, "upper_bound": 1.0, "legend_name": "Class"}

def __init__(
self,
@@ -314,6 +391,46 @@ def compute(self) -> Tensor:
state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
return _multilabel_auroc_compute(state, self.num_labels, self.average, self.thresholds, self.ignore_index)

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.

Args:
val: Either a single result from calling `metric.forward` or `metric.compute`, or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: A matplotlib axis object. If provided, the plot will be added to that axis.

Returns:
Figure and Axes object

Raises:
ModuleNotFoundError:
If `matplotlib` is not installed

.. plot::
:scale: 75

>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelAUROC
>>> preds = tensor([[0.75, 0.05, 0.35],
... [0.45, 0.75, 0.05],
... [0.05, 0.55, 0.75],
... [0.05, 0.65, 0.05]])
>>> target = tensor([[1, 0, 1],
... [0, 0, 0],
... [0, 1, 1],
... [1, 1, 1]])
>>> metric = MultilabelAUROC(num_labels=3, average="macro", thresholds=None)
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
"""
val = val or self.compute()
fig, ax = plot_single_or_multi_val(
val, ax=ax, higher_is_better=self.higher_is_better, **self.plot_options, name=self.__class__.__name__
)
return fig, ax
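
An explicit value can also be handed to `plot()` instead of deferring to `compute()`. A minimal sketch with made-up data, using macro averaging so the result is a single scalar:

```python
import torch
from torchmetrics.classification import MultilabelAUROC

metric = MultilabelAUROC(num_labels=3, average="macro", thresholds=None)
score = metric(torch.rand(10, 3), torch.randint(2, (10, 3)))  # forward updates state and returns
fig, ax = metric.plot(score)  # equivalent to metric.plot() after the update
```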


class AUROC:
r"""Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_). The AUROC score summarizes the
src/torchmetrics/classification/roc.py (62 changes: 61 additions & 1 deletion)

@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, List, Optional, Sequence, Tuple, Union

from torch import Tensor
from typing_extensions import Literal
@@ -21,6 +21,14 @@
MulticlassPrecisionRecallCurve,
MultilabelPrecisionRecallCurve,
)
from torchmetrics.functional.classification.auroc import (
_binary_auroc_arg_validation,
_binary_auroc_compute,
_multiclass_auroc_arg_validation,
_multiclass_auroc_compute,
_multilabel_auroc_arg_validation,
_multilabel_auroc_compute,
)
from torchmetrics.functional.classification.roc import (
_binary_roc_compute,
_multiclass_roc_compute,
@@ -29,6 +37,11 @@
from torchmetrics.metric import Metric
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.enums import ClassificationTask
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_binary_roc_curve

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["BinaryROC.plot"]


class BinaryROC(BinaryPrecisionRecallCurve):
@@ -107,6 +120,53 @@ def compute(self) -> Tuple[Tensor, Tensor, Tensor]:
state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
return _binary_roc_compute(state, self.thresholds)

def __compute_auroc(self) -> Tensor:
"""Computes Area under ROC curve from BinaryAUROC metric to show AUROC value together with ROC plot."""
state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
return _binary_auroc_compute(state, self.thresholds, max_fpr=None)

def plot(
self,
fpr: Optional[Union[Tensor, Sequence[Tensor]]] = None,
tpr: Optional[Union[Tensor, Sequence[Tensor]]] = None,
ax: Optional[_AX_TYPE] = None,
name: Optional[str] = None,
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.

Args:
fpr: False Positive Rate provided by calling `metric.forward` or `metric.compute`
tpr: True Positive Rate provided by calling `metric.forward` or `metric.compute`
If no values are provided, will automatically call `metric.compute` and plot that result.
ax: A matplotlib axis object. If provided, the plot will be added to that axis.
name: Custom name to describe the classifier

Returns:
Figure and Axes object

Raises:
ModuleNotFoundError:
If `matplotlib` is not installed

.. plot::
:scale: 75

>>> from torch import randn, randint
>>> import torch.nn.functional as F
>>> from torchmetrics.classification import BinaryROC
>>> preds = F.softmax(randn(20, 2), dim=1)
>>> target = randint(2, (20,))
>>> metric = BinaryROC()
>>> metric.update(preds[:, 1], target)
>>> fig_, ax_ = metric.plot()
"""
if fpr is None or tpr is None:
fpr, tpr, _ = self.compute()
roc_auc = self.__compute_auroc()
name = self.__class__.__name__ if name is None else name
fig, ax = plot_binary_roc_curve(tpr, fpr, ax=ax, roc_auc=roc_auc, name=name)
return fig, ax
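
Because `plot()` takes an existing axis and a classifier name, two ROC curves can share one figure. A sketch with toy data, not from the PR; the helper only attaches line labels, so `ax.legend()` is called manually here:

```python
import torch
from torchmetrics.classification import BinaryROC

target = torch.randint(2, (200,))
roc_a, roc_b = BinaryROC(), BinaryROC()
roc_a.update(torch.rand(200), target)
roc_b.update(torch.rand(200), target)
fig, ax = roc_a.plot(name="model A")
roc_b.plot(ax=ax, name="model B")  # returns (None, ax) when the axis is reused
ax.legend(loc="lower right")       # shows the AUC-annotated labels
```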


class MulticlassROC(MulticlassPrecisionRecallCurve):
r"""Compute the Receiver Operating Characteristic (ROC) for binary tasks. The curve consist of multiple pairs of
src/torchmetrics/utilities/plot.py (53 changes: 52 additions & 1 deletion)

@@ -13,7 +13,7 @@
# limitations under the License.
from itertools import product
from math import ceil, floor, sqrt
-from typing import List, Optional, Sequence, Tuple, Union
+from typing import Any, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
@@ -213,3 +213,54 @@ def plot_confusion_matrix(
ax.text(jj, ii, str(val.item()), ha="center", va="center", fontsize=15)

return fig, axs


def plot_binary_roc_curve(
tpr: Tensor,
fpr: Tensor,
ax: Optional[_AX_TYPE] = None, # type: ignore[valid-type]
roc_auc: Optional[Union[float, Tensor]] = None,
name: Optional[str] = None,
**kwargs: Any,
) -> _PLOT_OUT_TYPE:
"""Inspired by: https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/metrics/_plot/roc_curve.py.

Plots the ROC curve.

Args:
tpr: Tensor containing the values for True Positive Rate
fpr: Tensor containing the values for False Positive Rate
ax: Axis from a figure
roc_auc: AUROC score (computed separately)
name: Custom name to describe the classifier
kwargs: additional keyword arguments for line drawing

Returns:
A tuple consisting of the figure and respective ax objects (or array of ax objects) of the generated figure

Raises:
ModuleNotFoundError:
If `matplotlib` is not installed
"""
_error_on_missing_matplotlib()
fig, ax = plt.subplots() if ax is None else (None, ax)

if isinstance(roc_auc, Tensor):
assert roc_auc.numel() == 1, "roc_auc Tensor must consist of only one element"
roc_auc = roc_auc.item()

line_kwargs = {}
if roc_auc is not None and name is not None:
line_kwargs["label"] = f"{name} (AUC = {roc_auc:0.2f})"
elif roc_auc is not None:
line_kwargs["label"] = f"AUC = {roc_auc:0.2f}"
elif name is not None:
line_kwargs["label"] = name

line_kwargs.update(**kwargs)

ax.plot(fpr.detach().cpu(), tpr.detach().cpu(), **line_kwargs)
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")

return fig, ax
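
The helper can also be driven directly with precomputed rates; extra keyword arguments fall through to matplotlib's `plot()`. A minimal sketch with a toy curve (the `roc_auc` value passed here is just the analytic area under that curve):

```python
import torch
from torchmetrics.utilities.plot import plot_binary_roc_curve

fpr = torch.linspace(0, 1, 50)
tpr = fpr.sqrt()  # toy curve above the diagonal; area = 2/3
fig, ax = plot_binary_roc_curve(tpr, fpr, roc_auc=0.67, name="toy", linestyle="--")
```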