Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Metrics] PrecisionRecallCurve, ROC and AveragePrecision class interface #4549

Merged
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added optimizer refactors ([#4658](https://github.com/PyTorchLightning/pytorch-lightning/pull/4658))


- Added `PrecisionRecallCurve, ROC, AveragePrecision` class metric ([#4549](https://github.com/PyTorchLightning/pytorch-lightning/pull/4549))


### Changed

Borda marked this conversation as resolved.
Show resolved Hide resolved
- Added custom `Apex` and `NativeAMP` as `Precision plugins` ([#4355](https://github.com/PyTorchLightning/pytorch-lightning/pull/4355))


Expand All @@ -72,9 +77,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Removed `multiclass_roc` and `multiclass_precision_recall_curve`, use `roc` and `precision_recall_curve` instead ([#4549](https://github.com/PyTorchLightning/pytorch-lightning/pull/4549))



- Tuner algorithms will be skipped if `fast_dev_run=True` ([#3903](https://github.com/PyTorchLightning/pytorch-lightning/pull/3903))



- WandbLogger does not force wandb `reinit` arg to True anymore and creates a run only when needed ([#4648](https://github.com/PyTorchLightning/pytorch-lightning/pull/4648))


Expand Down
33 changes: 22 additions & 11 deletions docs/source/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,24 @@ ConfusionMatrix
.. autoclass:: pytorch_lightning.metrics.classification.ConfusionMatrix
:noindex:

PrecisionRecallCurve
~~~~~~~~~~~~~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.classification.PrecisionRecallCurve
:noindex:

AveragePrecision
~~~~~~~~~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.classification.AveragePrecision
:noindex:

ROC
~~~

.. autoclass:: pytorch_lightning.metrics.classification.ROC
:noindex:

Regression Metrics
------------------

Expand Down Expand Up @@ -326,7 +344,7 @@ multiclass_auroc [func]
average_precision [func]
~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.classification.average_precision
.. autofunction:: pytorch_lightning.metrics.functional.average_precision
:noindex:


Expand Down Expand Up @@ -365,10 +383,10 @@ iou [func]
:noindex:


multiclass_roc [func]
roc [func]
~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.classification.multiclass_roc
.. autofunction:: pytorch_lightning.metrics.functional.roc
:noindex:


Expand All @@ -389,7 +407,7 @@ precision_recall [func]
precision_recall_curve [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.classification.precision_recall_curve
.. autofunction:: pytorch_lightning.metrics.functional.precision_recall_curve
:noindex:


Expand All @@ -400,13 +418,6 @@ recall [func]
:noindex:


roc [func]
~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.classification.roc
:noindex:


stat_scores [func]
~~~~~~~~~~~~~~~~~~

Expand Down
5 changes: 4 additions & 1 deletion pytorch_lightning/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@
Accuracy,
Precision,
Recall,
ConfusionMatrix,
PrecisionRecallCurve,
AveragePrecision,
ROC,
FBeta,
F1,
ConfusionMatrix
)

from pytorch_lightning.metrics.regression import (
Expand Down
3 changes: 3 additions & 0 deletions pytorch_lightning/metrics/classification/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,6 @@
from pytorch_lightning.metrics.classification.precision_recall import Precision, Recall
from pytorch_lightning.metrics.classification.f_beta import FBeta, F1
from pytorch_lightning.metrics.classification.confusion_matrix import ConfusionMatrix
from pytorch_lightning.metrics.classification.precision_recall_curve import PrecisionRecallCurve
from pytorch_lightning.metrics.classification.average_precision import AveragePrecision
from pytorch_lightning.metrics.classification.roc import ROC
130 changes: 130 additions & 0 deletions pytorch_lightning/metrics/classification/average_precision.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Any, Union, List

import torch

from pytorch_lightning.metrics import Metric
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.metrics.functional.average_precision import (
_average_precision_update,
_average_precision_compute
)


class AveragePrecision(Metric):
"""
Computes the average precision score, which summarises the precision recall
curve into one number. Works for both binary and multiclass problems.
In the case of multiclass, the values will be calculated based on a one-vs-the-rest approach.

Forward accepts

- ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass)
where C is the number of classes

- ``target`` (long tensor): ``(N, ...)``

Args:
num_classes: integer with number of classes. Not nessesary to provide
for binary problems.
pos_label: integer determining the positive class. Default is ``None``
which for binary problem is translate to 1. For multiclass problems
this argument should not be set as we iteratively change it in the
range [0,num_classes-1]
compute_on_step:
Forward only calls ``update()`` and return None if this is set to False. default: True
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step. default: False
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)

Example (binary case):

>>> pred = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 1, 1])
>>> average_precision = AveragePrecision(pos_label=1)
>>> average_precision(pred, target)
tensor(1.)

Example (multiclass case):

>>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
... [0.05, 0.85, 0.05, 0.05],
... [0.05, 0.05, 0.85, 0.05],
... [0.05, 0.05, 0.05, 0.85]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> average_precision = AveragePrecision(num_classes=4)
>>> average_precision(pred, target)
[tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500)]

"""
def __init__(
self,
num_classes: Optional[int] = None,
pos_label: Optional[int] = None,
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
):
super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
)

self.num_classes = num_classes
self.pos_label = pos_label

self.add_state("preds", default=[], dist_reduce_fx=None)
self.add_state("target", default=[], dist_reduce_fx=None)

rank_zero_warn(
'Metric `AveragePrecision` will save all targets and'
' predictions in buffer. For large datasets this may lead'
' to large memory footprint.'
)

def update(self, preds: torch.Tensor, target: torch.Tensor):
"""
Update state with predictions and targets.

Args:
preds: Predictions from model
target: Ground truth values
"""
preds, target, num_classes, pos_label = _average_precision_update(
preds,
target,
self.num_classes,
self.pos_label
)
self.preds.append(preds)
self.target.append(target)
self.num_classes = num_classes
self.pos_label = pos_label

def compute(self) -> Union[torch.Tensor, List[torch.Tensor]]:
"""
Compute the average precision score

Returns:
tensor with average precision. If multiclass will return list
of such tensors, one for each class

"""
preds = torch.cat(self.preds, dim=0)
target = torch.cat(self.target, dim=0)
return _average_precision_compute(preds, target, self.num_classes, self.pos_label)
150 changes: 150 additions & 0 deletions pytorch_lightning/metrics/classification/precision_recall_curve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Any, Union, Tuple, List

import torch

from pytorch_lightning.metrics import Metric
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.metrics.functional.precision_recall_curve import (
_precision_recall_curve_update,
_precision_recall_curve_compute
)


class PrecisionRecallCurve(Metric):
"""
Computes precision-recall pairs for different thresholds. Works for both
binary and multiclass problems. In the case of multiclass, the values will
be calculated based on a one-vs-the-rest approach.

Forward accepts

- ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass)
where C is the number of classes

- ``target`` (long tensor): ``(N, ...)``

Args:
num_classes: integer with number of classes. Not nessesary to provide
for binary problems.
pos_label: integer determining the positive class. Default is ``None``
which for binary problem is translate to 1. For multiclass problems
this argument should not be set as we iteratively change it in the
range [0,num_classes-1]
compute_on_step:
Forward only calls ``update()`` and return None if this is set to False. default: True
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step. default: False
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)

Example (binary case):

>>> pred = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 1, 0])
>>> pr_curve = PrecisionRecallCurve(pos_label=1)
>>> precision, recall, thresholds = pr_curve(pred, target)
>>> precision
tensor([0.6667, 0.5000, 0.0000, 1.0000])
>>> recall
tensor([1.0000, 0.5000, 0.0000, 0.0000])
>>> thresholds
tensor([1, 2, 3])

Example (multiclass case):

>>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
... [0.05, 0.85, 0.05, 0.05],
... [0.05, 0.05, 0.85, 0.05],
... [0.05, 0.05, 0.05, 0.85]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> pr_curve = PrecisionRecallCurve(num_classes=4)
>>> precision, recall, thresholds = pr_curve(pred, target)
>>> precision
[tensor([1., 1.]), tensor([1., 1.]), tensor([0.2500, 0.0000, 1.0000]), tensor([0.2500, 0.0000, 1.0000])]
>>> recall
[tensor([1., 0.]), tensor([1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.])]
>>> thresholds # doctest: +NORMALIZE_WHITESPACE
[tensor([0.8500]), tensor([0.8500]), tensor([0.0500, 0.8500]), tensor([0.0500, 0.8500])]

"""
def __init__(
self,
num_classes: Optional[int] = None,
pos_label: Optional[int] = None,
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
):
super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
)

self.num_classes = num_classes
self.pos_label = pos_label

self.add_state("preds", default=[], dist_reduce_fx=None)
self.add_state("target", default=[], dist_reduce_fx=None)

rank_zero_warn(
'Metric `PrecisionRecallCurve` will save all targets and'
' predictions in buffer. For large datasets this may lead'
' to large memory footprint.'
)

def update(self, preds: torch.Tensor, target: torch.Tensor):
"""
Update state with predictions and targets.

Args:
preds: Predictions from model
target: Ground truth values
"""
preds, target, num_classes, pos_label = _precision_recall_curve_update(
preds,
target,
self.num_classes,
self.pos_label
)
self.preds.append(preds)
self.target.append(target)
self.num_classes = num_classes
self.pos_label = pos_label

def compute(self) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]]:
"""
Compute the precision-recall curve

Returns: 3-element tuple containing

precision:
tensor where element i is the precision of predictions with
score >= thresholds[i] and the last element is 1.
If multiclass, this is a list of such tensors, one for each class.
recall:
tensor where element i is the recall of predictions with
score >= thresholds[i] and the last element is 0.
If multiclass, this is a list of such tensors, one for each class.
thresholds:
Thresholds used for computing precision/recall scores

"""
preds = torch.cat(self.preds, dim=0)
target = torch.cat(self.target, dim=0)
return _precision_recall_curve_compute(preds, target, self.num_classes, self.pos_label)
Loading