From 6959ea03c553a6753cc973abb58f240a096bc1c1 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 24 Nov 2020 20:48:57 +0100 Subject: [PATCH 01/61] Add stuff --- .../metrics/classification/utils.py | 385 ++++++++++++++++++ pytorch_lightning/metrics/utils.py | 57 ++- tests/metrics/classification/inputs.py | 18 +- tests/metrics/classification/test_inputs.py | 301 ++++++++++++++ 4 files changed, 738 insertions(+), 23 deletions(-) create mode 100644 pytorch_lightning/metrics/classification/utils.py create mode 100644 tests/metrics/classification/test_inputs.py diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py new file mode 100644 index 0000000000000..b8f5af2e988d8 --- /dev/null +++ b/pytorch_lightning/metrics/classification/utils.py @@ -0,0 +1,385 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Tuple, Optional + +import numpy as np +import torch + +from pytorch_lightning.metrics.utils import to_onehot, select_topk + + +def _check_classification_inputs( + preds: torch.Tensor, + target: torch.Tensor, + threshold: float, + num_classes: Optional[int] = None, + is_multiclass: bool = False, + top_k: int = 1, +) -> None: + """Performs error checking on inputs for classification. + + This ensures that preds and target take one of the shape/type combinations that are + specified in ``_input_format_classification`` docstring. It also checks the cases of + over-rides with ``is_multiclass`` by checking (for multi-class and multi-dim multi-class + cases) that there are only up to 2 distinct labels. + + In case where preds are floats (probabilities), it is checked whether they are in [0,1] interval. + + When ``num_classes`` is given, it is checked that it is consitent with input cases (binary, + multi-label, ...), and that, if availible, the implied number of classes in the ``C`` + dimension is consistent with it (as well as that max label in target is smaller than it). + + When ``num_classes`` is not specified in these cases, consistency of the highest target + value against ``C`` dimension is checked for (multi-dimensional) multi-class cases. + + If ``top_k`` is larger than one, then an error is raised if the inputs are not (multi-dim) + multi-class with probability predictions. + + Preds and target tensors are expected to be squeezed already - all dimensions should be + greater than 1, except perhaps the first one (N). + + Args: + preds: tensor with predictions + target: tensor with ground truth labels, always integers + threshold: + Threshold probability value for transforming probability predictions to binary + (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + num_classes: number of classes + is_multiclass: if True, treat binary and multi-label inputs as multi-class or multi-dim + multi-class with 2 classes, respectively. If False, treat multi-class and multi-dim + multi-class inputs with 1 or 2 classes as binary and multi-label, respectively. + Defaults to None, which treats inputs as they appear. + """ + + if target.is_floating_point(): + raise ValueError("target has to be an integer tensor") + elif target.min() < 0: + raise ValueError("target has to be a non-negative tensor") + + preds_float = preds.is_floating_point() + if not preds_float and preds.min() < 0: + raise ValueError("if preds are integers, they have to be non-negative") + + if not preds.shape[0] == target.shape[0]: + raise ValueError("preds and target should have the same first dimension.") + + if preds_float: + if preds.min() < 0 or preds.max() > 1: + raise ValueError( + "preds should be probabilities, but values were detected outside of [0,1] range" + ) + + if threshold > 1 or threshold < 0: + raise ValueError("Threshold should be a probability in [0,1]") + + if is_multiclass is False and target.max() > 1: + raise ValueError("If you set is_multiclass=False, then target should not exceed 1.") + + if is_multiclass is False and not preds_float and preds.max() > 1: + raise ValueError("If you set is_multiclass=False and preds are integers, then preds should not exceed 1.") + + # Check that shape/types fall into one of the cases + if len(preds.shape) == len(target.shape): + if preds.shape != target.shape: + raise ValueError("if preds and target have the same number of dimensions, they should have the same shape") + if preds_float and target.max() > 1: + raise ValueError("if preds and target are of shape (N, ...) and preds are floats, target should be binary") + + # Get the case + if len(preds.shape) == 1 and preds_float: + case = "binary" + elif len(preds.shape) == 1 and not preds_float: + case = "multi-class" + elif len(preds.shape) > 1 and preds_float: + case = "multi-label" + else: + case = "multi-dim multi-class" + + implied_classes = torch.prod(torch.Tensor(list(preds.shape[1:]))) + + elif len(preds.shape) == len(target.shape) + 1: + if not preds_float: + raise ValueError("if preds have one dimension more than target, preds should be a float tensor") + if not preds.shape[:-1] == target.shape: + if preds.shape[2:] != target.shape[1:]: + raise ValueError( + "if preds if preds have one dimension more than target, the shape of preds should be" + "either of shape (N, C, ...) or (N, ..., C), and of targets of shape (N, ...)" + ) + + extra_dim_size = preds.shape[-1 if preds.shape[:-1] == target.shape else 1] + + if len(preds.shape) == 2: + case = "multi-class" + else: + case = "multi-dim multi-class" + else: + raise ValueError( + "preds and target should both have the (same) shape (N, ...), or target (N, ...)" + " and preds (N, C, ...) or (N, ..., C)" + ) + + if preds.shape != target.shape and is_multiclass is False and extra_dim_size != 2: + raise ValueError( + "You have set is_multiclass=False, but have more than 2 classes in your data," + " based on the C dimension of preds." + ) + + # Check that num_classes is consistent + if not num_classes: + if preds.shape != target.shape and target.max() >= extra_dim_size: + raise ValueError("The highest label in targets should be smaller than the size of C dimension") + else: + if case == "binary": + if num_classes > 2: + raise ValueError("Your data is binary, but num_classes is larger than 2.") + elif num_classes == 2 and not is_multiclass: + raise ValueError( + "Your data is binary and num_classes=2, but is_multiclass is not True." + "Set it to True if you want to transform binary data to multi-class format." + ) + elif num_classes == 1 and is_multiclass: + raise ValueError( + "You have binary data and have set is_multiclass=True, but num_classes is 1." + "Either leave is_multiclass unset or set it to 2 to transform binary data to multi-class format." + ) + elif "multi-class" in case: + if num_classes == 1 and is_multiclass is not False: + raise ValueError( + "You have set num_classes=1, but predictions are integers." + "If you want to convert (multi-dimensional) multi-class data with 2 classes" + "to binary/multi-label, set is_multiclass=False." + ) + elif num_classes > 1: + if is_multiclass is False: + if implied_classes != num_classes: + raise ValueError( + "You have set is_multiclass=False, but the implied number of classes " + "(from shape of inputs) does not match num_classes. If you are trying to" + "transform multi-dim multi-class data with 2 classes to multi-label, num_classes" + "should be either None or the product of the size of extra dimensions (...)." + "See Input Types in Metrics documentation." + ) + if num_classes <= target.max(): + raise ValueError("The highest label in targets should be smaller than num_classes") + if num_classes <= preds.max(): + raise ValueError("The highest label in preds should be smaller than num_classes") + if preds.shape != target.shape and num_classes != extra_dim_size: + raise ValueError("The size of C dimension of preds does not match num_classes") + + elif case == "multi-label": + if is_multiclass and num_classes != 2: + raise ValueError( + "Your have set is_multiclass=True, but num_classes is not equal to 2." + "If you are trying to transform multi-label data to 2 class multi-dimensional" + "multi-class, you should set num_classes to either 2 or None." + ) + if not is_multiclass and num_classes != implied_classes: + raise ValueError("The implied number of classes (from shape of inputs) does not match num_classes.") + + # Check that if top_k > 1, we have (multi-class) multi-dim with probabilities + if top_k > 1: + if preds.shape == target.shape: + raise ValueError( + "You have set top_k above 1, but your data is not (multi-dimensional) multi-class" + "with probability predictions." + ) + + +def _input_format_classification( + preds: torch.Tensor, + target: torch.Tensor, + threshold: float = 0.5, + top_k: int = 1, + num_classes: Optional[int] = None, + is_multiclass: Optional[bool] = None, +) -> Tuple[torch.Tensor, torch.Tensor, str]: + """Convert preds and target tensors into common format. + + Preds and targets are supposed to fall into one of these categories (and are + validated to make sure this is the case): + + * Both preds and target are of shape ``(N,)``, and both are integers (multi-class) + * Both preds and target are of shape ``(N,)``, and target is binary, while preds + are a float (binary) + * preds are of shape ``(N, C)`` and are floats, and target is of shape ``(N,)`` and + is integer (multi-class) + * preds and target are of shape ``(N, ...)``, target is binary and preds is a float + (multi-label) + * preds are of shape ``(N, ..., C)`` or ``(N, C, ...)`` and are floats, target is of + shape ``(N, ...)`` and is integer (multi-dimensional multi-class) + * preds and target are of shape ``(N, ...)`` both are integers (multi-dimensional + multi-class) + + To avoid ambiguities, all dimensions of size 1, except the first one, are squeezed out. + + The returned output tensors will be binary tensors of the same shape, either ``(N, C)`` + of ``(N, C, X)``, the details for each case are described below. The function also returns + a ``mode`` string, which describes which of the above cases the inputs belonged to - regardless + of whether this was "overridden" by other settings (like ``is_multiclass``). + + In binary case, targets are normally returned as ``(N,1)`` tensor, while preds are transformed + into a binary tensor (elements become 1 if the probability is greater than or equal to + ``threshold`` or 0 otherwise). If ``is_multiclass=True``, then then both targets are preds + become ``(N, 2)`` tensors by a one-hot transformation; with the thresholding being applied to + preds first. + + In multi-class case, normally both preds and targets become ``(N, C)`` binary tensors; targets + by a one-hot transformation and preds by selecting ``top_k`` largest entries (if their original + shape was ``(N,C)``). However, if ``is_multiclass=False``, then targets and preds will be + returned as ``(N,1)`` tensor. + + In multi-label case, normally targets and preds are returned as ``(N, C)`` binary tensors, with + preds being binarized as in the binary case. Here the ``C`` dimension is obtained by flattening + all dimensions after the first one. However if ``is_multiclass=True``, then both are returned as + ``(N, 2, C)``, by an equivalent transformation as in the binary case. + + In multi-dimensional multi-class case, normally both target and preds are returned as + ``(N, C, X)`` tensors, with ``X`` resulting from flattening of all dimensions except ``N`` and + ``C``. The transformations performed here are equivalent to the multi-class case. However, if + ``is_multiclass=False`` (and there are up to two classes), then the data is returned as + ``(N, X)`` binary tensors (multi-label). + + Also, in multi-dimensional multi-class case, if the position of the ``C`` + dimension is ambiguous (e.g. if targets are a ``(7, 3)`` tensor, while predictions are a + ``(7, 3, 3)`` tensor), it will be assumed that the ``C`` dimension is the second dimension. + If this is not the case, you should move it from the last to second place using + ``torch.movedim(preds, -1, 1)``. + + Note that where a one-hot transformation needs to be performed and the number of classes + is not implicitly given by a ``C`` dimension, the new ``C`` dimension will either be + equal to ``num_classes``, if it is given, or the maximum label value in preds and + target. + + Args: + preds: tensor with predictions + target: tensor with ground truth labels, always integers + threshold: + Threshold probability value for transforming probability predictions to binary + (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + num_classes: number of classes + top_k: number of highest probability entries for each sample to convert to 1s, relevant + only for (multi-dimensional) multi-class cases. + is_multiclass: if True, treat binary and multi-label inputs as multi-class or multi-dim + multi-class with 2 classes, respectively. If False, treat multi-class and multi-dim + multi-class inputs with 1 or 2 classes as binary and multi-label, respectively. + Defaults to None, which treats inputs as they appear. + + Returns: + preds: binary tensor of shape (N, C) or (N, C, X) + target: binary tensor of shape (N, C) or (N, C, X) + """ + preds, target = preds.clone().detach(), target.clone().detach() + + # Remove excess dimensions + if preds.shape[0] == 1: + preds, target = preds.squeeze().unsqueeze(0), target.squeeze().unsqueeze(0) + else: + preds, target = preds.squeeze(), target.squeeze() + + _check_classification_inputs( + preds, + target, + threshold=threshold, + num_classes=num_classes, + is_multiclass=is_multiclass, + top_k=top_k, + ) + + preds_float = preds.is_floating_point() + + if len(preds.shape) == len(target.shape) == 1 and preds_float: + mode = "binary" + preds = (preds >= threshold).int() + + if is_multiclass: + target = to_onehot(target, 2) + preds = to_onehot(preds, 2) + else: + preds = preds.unsqueeze(-1) + target = target.unsqueeze(-1) + + elif len(preds.shape) == len(target.shape) and preds_float: + mode = "multi-label" + preds = (preds >= threshold).int() + + if is_multiclass: + preds = to_onehot(preds, 2).reshape(preds.shape[0], 2, -1) + target = to_onehot(target, 2).reshape(target.shape[0], 2, -1) + else: + preds = preds.reshape(preds.shape[0], -1) + target = target.reshape(target.shape[0], -1) + + elif len(preds.shape) == len(target.shape) + 1 == 2: + mode = "multi-class" + if not num_classes: + num_classes = preds.shape[1] + + target = to_onehot(target, num_classes) + preds = select_topk(preds, top_k) + + # If is_multiclass=False, force to binary + if is_multiclass is False: + target = target[:, [1]] + preds = preds[:, [1]] + + elif len(preds.shape) == len(target.shape) == 1 and not preds_float: + mode = "multi-class" + + if not num_classes: + num_classes = max(preds.max(), target.max()) + 1 + + # If is_multiclass=False, force to binary + if is_multiclass is False: + preds = preds.unsqueeze(1) + target = target.unsqueeze(1) + else: + preds = to_onehot(preds, num_classes) + target = to_onehot(target, num_classes) + + # Multi-dim multi-class (N, ...) with integers + elif preds.shape == target.shape and not preds_float: + mode = "multi-dim multi-class" + + if not num_classes: + num_classes = max(preds.max(), target.max()) + 1 + + # If is_multiclass=False, force to multi-label + if is_multiclass is False: + preds = preds.reshape(preds.shape[0], -1) + target = target.reshape(target.shape[0], -1) + else: + target = to_onehot(target, num_classes) + target = target.reshape(target.shape[0], target.shape[1], -1) + preds = to_onehot(preds, num_classes) + preds = preds.reshape(preds.shape[0], preds.shape[1], -1) + + # Multi-dim multi-class (N, C, ...) and (N, ..., C) + else: + mode = "multi-dim multi-class" + if preds.shape[:-1] == target.shape: + preds = torch.movedim(preds, -1, 1) + + num_classes = preds.shape[1] + + if is_multiclass is False: + target = target.reshape(target.shape[0], -1) + preds = select_topk(preds, 1)[:, 1, ...] + preds = preds.reshape(preds.shape[0], -1) + else: + target = to_onehot(target, num_classes) + target = target.reshape(target.shape[0], target.shape[1], -1) + preds = select_topk(preds, top_k).reshape(preds.shape[0], preds.shape[1], -1) + + return preds, target, mode diff --git a/pytorch_lightning/metrics/utils.py b/pytorch_lightning/metrics/utils.py index e1ff95b94f471..1ce56b30cf9e5 100644 --- a/pytorch_lightning/metrics/utils.py +++ b/pytorch_lightning/metrics/utils.py @@ -20,6 +20,7 @@ def dim_zero_cat(x): + x = x if isinstance(x, (list, tuple)) else [x] return torch.cat(x, dim=0) @@ -36,8 +37,8 @@ def _flatten(x): def to_onehot( - tensor: torch.Tensor, - num_classes: int, + tensor: torch.Tensor, + num_classes: int, ) -> torch.Tensor: """ Converts a dense label tensor to one-hot format @@ -57,24 +58,46 @@ def to_onehot( [0, 0, 0, 1]]) """ dtype, device, shape = tensor.dtype, tensor.device, tensor.shape - tensor_onehot = torch.zeros(shape[0], num_classes, *shape[1:], - dtype=dtype, device=device) + tensor_onehot = torch.zeros(shape[0], num_classes, *shape[1:], dtype=dtype, device=device) index = tensor.long().unsqueeze(1).expand_as(tensor_onehot) return tensor_onehot.scatter_(1, index, 1.0) +def select_topk(tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch.Tensor: + """ + Convert a probability tensor to binary by selecting top-k highest entries. + + Args: + tensor: dense tensor of shape ``[..., C, ...]``, where ``C`` is in the + position defined by the ``dim`` argument + topk: number of highest entries to turn into 1s + dim: dimension on which to compare entries + + Output: + A binary tensor of the same shape as the input tensor of type torch.int32 + + Example: + >>> x = torch.tensor([[1.1, 2.0, 3.0], [2.0, 1.0, 0.5]]) + >>> select_topk(x, topk=2) + tensor([[0, 1, 1], + [1, 1, 0]], dtype=torch.int32) + """ + zeros = torch.zeros_like(tensor, device=tensor.device) + topk_tensor = zeros.scatter(1, tensor.topk(k=topk, dim=dim).indices, 1.0) + + return topk_tensor.int() + + def _check_same_shape(pred: torch.Tensor, target: torch.Tensor): """ Check that predictions and target have the same shape, else raise error """ if pred.shape != target.shape: - raise RuntimeError('Predictions and targets are expected to have the same shape') + raise RuntimeError("Predictions and targets are expected to have the same shape") def _input_format_classification( - preds: torch.Tensor, - target: torch.Tensor, - threshold: float = 0.5 + preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5 ) -> Tuple[torch.Tensor, torch.Tensor]: - """ Convert preds and target tensors into label tensors + """Convert preds and target tensors into label tensors Args: preds: either tensor with labels, tensor with probabilities/logits or @@ -87,9 +110,7 @@ def _input_format_classification( target: tensor with labels """ if not (len(preds.shape) == len(target.shape) or len(preds.shape) == len(target.shape) + 1): - raise ValueError( - "preds and target must have same number of dimensions, or one additional dimension for preds" - ) + raise ValueError("preds and target must have same number of dimensions, or one additional dimension for preds") if len(preds.shape) == len(target.shape) + 1: # multi class probabilites @@ -102,13 +123,9 @@ def _input_format_classification( def _input_format_classification_one_hot( - num_classes: int, - preds: torch.Tensor, - target: torch.Tensor, - threshold: float = 0.5, - multilabel: bool = False + num_classes: int, preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5, multilabel: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: - """ Convert preds and target tensors into one hot spare label tensors + """Convert preds and target tensors into one hot spare label tensors Args: num_classes: number of classes @@ -123,9 +140,7 @@ def _input_format_classification_one_hot( target: one hot tensors of shape [num_classes, -1] with true labels """ if not (len(preds.shape) == len(target.shape) or len(preds.shape) == len(target.shape) + 1): - raise ValueError( - "preds and target must have same number of dimensions, or one additional dimension for preds" - ) + raise ValueError("preds and target must have same number of dimensions, or one additional dimension for preds") if len(preds.shape) == len(target.shape) + 1: # multi class probabilites diff --git a/tests/metrics/classification/inputs.py b/tests/metrics/classification/inputs.py index 9613df3b6f8ca..e648aaf10093e 100644 --- a/tests/metrics/classification/inputs.py +++ b/tests/metrics/classification/inputs.py @@ -29,12 +29,21 @@ target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)) ) +_multilabel_multidim_prob_inputs = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)) +) _multilabel_inputs = Input( preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)) ) +_multilabel_multidim_inputs = Input( + preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)) +) + # Generate edge multilabel edge case, where nothing matches (scores are undefined) __temp_preds = torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)) __temp_target = abs(__temp_preds - 1) @@ -61,8 +70,13 @@ target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)) ) +# Class dimension last +_multidim_multiclass_prob_inputs1 = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM, NUM_CLASSES), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)) +) _multidim_multiclass_inputs = Input( - preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, EXTRA_DIM, BATCH_SIZE)), - target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, EXTRA_DIM, BATCH_SIZE)) + preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)) ) diff --git a/tests/metrics/classification/test_inputs.py b/tests/metrics/classification/test_inputs.py new file mode 100644 index 0000000000000..8d17d5624fac0 --- /dev/null +++ b/tests/metrics/classification/test_inputs.py @@ -0,0 +1,301 @@ +import pytest +import torch +from torch import randint, rand + +from pytorch_lightning.metrics.utils import to_onehot, select_topk +from pytorch_lightning.metrics.classification.utils import _input_format_classification +from tests.metrics.classification.inputs import ( + Input, + _binary_inputs as _bin, + _binary_prob_inputs as _bin_prob, + _multiclass_inputs as _mc, + _multiclass_prob_inputs as _mc_prob, + _multidim_multiclass_inputs as _mdmc, + _multidim_multiclass_prob_inputs as _mdmc_prob, + _multidim_multiclass_prob_inputs1 as _mdmc_prob1, + _multilabel_inputs as _ml, + _multilabel_prob_inputs as _ml_prob, + _multilabel_multidim_inputs as _mlmd, + _multilabel_multidim_prob_inputs as _mlmd_prob, +) +from tests.metrics.utils import NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM, THRESHOLD + +torch.manual_seed(42) + +# Some additional inputs to test on +_mc_prob_2cls = Input(rand(NUM_BATCHES, BATCH_SIZE, 2), randint(high=2, size=(NUM_BATCHES, BATCH_SIZE))) +_mdmc_prob_many_dims = Input( + rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM, EXTRA_DIM), + randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM, EXTRA_DIM)), +) +_mdmc_prob_many_dims1 = Input( + rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM, EXTRA_DIM, NUM_CLASSES), + randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM, EXTRA_DIM)), +) +_mdmc_prob_2cls = Input( + rand(NUM_BATCHES, BATCH_SIZE, 2, EXTRA_DIM), randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)) +) +_mdmc_prob_2cls1 = Input( + rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM, 2), randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)) +) + +# Some utils +T = torch.Tensor + + +def idn(x): + return x + + +def usq(x): + return x.unsqueeze(-1) + + +def toint(x): + return x.int() + + +def thrs(x): + return x >= THRESHOLD + + +def rshp1(x): + return x.reshape(x.shape[0], -1) + + +def rshp2(x): + return x.reshape(x.shape[0], x.shape[1], -1) + + +def onehot(x): + return to_onehot(x, NUM_CLASSES) + + +def onehot2(x): + return to_onehot(x, 2) + + +def top1(x): + return select_topk(x, 1) + + +def top2(x): + return select_topk(x, 2) + + +def mvdim(x): + return torch.movedim(x, -1, 1) + + +# To avoid ugly black line wrapping +def ml_preds_tr(x): + return rshp1(toint(thrs(x))) + + +def onehot_rshp1(x): + return onehot(rshp1(x)) + + +def onehot2_rshp1(x): + return onehot2(rshp1(x)) + + +def top1_rshp2(x): + return top1(rshp2(x)) + + +def top2_rshp2(x): + return top2(rshp2(x)) + + +def mdmc1_top1_tr(x): + return top1(rshp2(mvdim(x))) + + +def mdmc1_top2_tr(x): + return top2(rshp2(mvdim(x))) + + +def probs_to_mc_preds_tr(x): + return toint(onehot2(thrs(x))) + + +def mlmd_prob_to_mc_preds_tr(x): + return onehot2(rshp1(toint(thrs(x)))) + + +def mdmc_prob_to_ml_preds_tr(x): + return top1(mvdim(x))[:, 1] + + +######################## +# Test correct inputs +######################## + + +@pytest.mark.parametrize( + "inputs, threshold, num_classes, is_multiclass, top_k, exp_mode, post_preds, post_target", + [ + ############################# + # Test usual expected cases + (_bin, THRESHOLD, None, False, 1, "multi-class", usq, usq), + (_bin_prob, THRESHOLD, None, None, 1, "binary", lambda x: usq(toint(thrs(x))), usq), + (_ml_prob, THRESHOLD, None, None, 1, "multi-label", lambda x: toint(thrs(x)), idn), + (_ml, THRESHOLD, None, False, 1, "multi-dim multi-class", idn, idn), + (_ml_prob, THRESHOLD, None, None, 1, "multi-label", ml_preds_tr, rshp1), + (_mlmd, THRESHOLD, None, False, 1, "multi-dim multi-class", rshp1, rshp1), + (_mc, THRESHOLD, NUM_CLASSES, None, 1, "multi-class", onehot, onehot), + (_mc_prob, THRESHOLD, None, None, 1, "multi-class", top1, onehot), + (_mc_prob, THRESHOLD, None, None, 2, "multi-class", top2, onehot), + (_mdmc, THRESHOLD, NUM_CLASSES, None, 1, "multi-dim multi-class", onehot, onehot), + (_mdmc_prob, THRESHOLD, None, None, 1, "multi-dim multi-class", top1_rshp2, onehot), + (_mdmc_prob, THRESHOLD, None, None, 2, "multi-dim multi-class", top2_rshp2, onehot), + (_mdmc_prob_many_dims, THRESHOLD, None, None, 1, "multi-dim multi-class", top1_rshp2, onehot_rshp1), + (_mdmc_prob_many_dims, THRESHOLD, None, None, 2, "multi-dim multi-class", top2_rshp2, onehot_rshp1), + # Test with C dim in last place + (_mdmc_prob1, THRESHOLD, None, None, 1, "multi-dim multi-class", mdmc1_top1_tr, onehot), + (_mdmc_prob1, THRESHOLD, None, None, 2, "multi-dim multi-class", mdmc1_top2_tr, onehot), + (_mdmc_prob_many_dims1, THRESHOLD, None, None, 1, "multi-dim multi-class", mdmc1_top1_tr, onehot_rshp1), + (_mdmc_prob_many_dims1, THRESHOLD, None, None, 2, "multi-dim multi-class", mdmc1_top2_tr, onehot_rshp1), + ########################### + # Test some special cases + # Binary as multiclass + (_bin, THRESHOLD, None, None, 1, "multi-class", onehot2, onehot2), + # Binary probs as multiclass + (_bin_prob, THRESHOLD, None, True, 1, "binary", probs_to_mc_preds_tr, onehot2), + # Multilabel as multiclass + (_ml, THRESHOLD, None, True, 1, "multi-dim multi-class", onehot2, onehot2), + # Multilabel probs as multiclass + (_ml_prob, THRESHOLD, None, True, 1, "multi-label", probs_to_mc_preds_tr, onehot2), + # Multidim multilabel as multiclass + (_mlmd, THRESHOLD, None, True, 1, "multi-dim multi-class", onehot2_rshp1, onehot2_rshp1), + # Multidim multilabel probs as multiclass + (_mlmd_prob, THRESHOLD, None, True, 1, "multi-label", mlmd_prob_to_mc_preds_tr, onehot2_rshp1), + # Multiclass prob with 2 classes as binary + (_mc_prob_2cls, THRESHOLD, None, False, 1, "multi-class", lambda x: top1(x)[:, [1]], usq), + # Multi-dim multi-class with 2 classes as multi-label + (_mdmc_prob_2cls, THRESHOLD, None, False, 1, "multi-dim multi-class", lambda x: top1(x)[:, 1], idn), + (_mdmc_prob_2cls1, THRESHOLD, None, False, 1, "multi-dim multi-class", mdmc_prob_to_ml_preds_tr, idn), + ], +) +def test_usual_cases(inputs, threshold, num_classes, is_multiclass, top_k, exp_mode, post_preds, post_target): + preds_out, target_out, mode = _input_format_classification( + preds=inputs.preds[0], + target=inputs.target[0], + threshold=threshold, + num_classes=num_classes, + is_multiclass=is_multiclass, + top_k=top_k, + ) + + assert mode == exp_mode + assert torch.equal(preds_out, post_preds(inputs.preds[0])) + assert torch.equal(target_out, post_target(inputs.target[0])) + + # Test that things work when batch_size = 1 + preds_out, target_out, mode = _input_format_classification( + preds=inputs.preds[0][[0], ...], + target=inputs.target[0][[0], ...], + threshold=threshold, + num_classes=num_classes, + is_multiclass=is_multiclass, + top_k=top_k, + ) + + assert mode == exp_mode + assert torch.equal(preds_out, post_preds(inputs.preds[0][[0], ...])) + assert torch.equal(target_out, post_target(inputs.target[0][[0], ...])) + + +# Test that threshold is correctly applied +def test_threshold(): + target = T([1, 1, 1]).int() + preds_probs = T([0.5 - 1e-5, 0.5, 0.5 + 1e-5]) + + preds_probs_out, _, _ = _input_format_classification(preds_probs, target, threshold=0.5) + + assert torch.equal(torch.tensor([0, 1, 1]), preds_probs_out.squeeze().long()) + + +######################################################################## +# Test incorrect inputs +######################################################################## + + +@pytest.mark.parametrize( + "preds, target, threshold, num_classes, is_multiclass, top_k", + [ + # Target not integer + (randint(high=2, size=(7,)), randint(high=2, size=(7,)).float(), 0.5, None, None, 1), + # Target negative + (randint(high=2, size=(7,)), -randint(high=2, size=(7,)), 0.5, None, None, 1), + # Preds negative integers + (-randint(high=2, size=(7,)), randint(high=2, size=(7,)), 0.5, None, None, 1), + # Negative probabilities + (-rand(size=(7,)), randint(high=2, size=(7,)), 0.5, None, None, 1), + # Threshold outside of [0,1] + (rand(size=(7,)), randint(high=2, size=(7,)), 1.5, None, None, 1), + # is_multiclass=False and target > 1 + (rand(size=(7,)), randint(low=2, high=4, size=(7,)), 0.5, None, False, 1), + # is_multiclass=False and preds integers with > 1 + (randint(low=2, high=4, size=(7,)), randint(high=2, size=(7,)), 0.5, None, False, 1), + # Wrong batch size + (randint(high=2, size=(8,)), randint(high=2, size=(7,)), 0.5, None, None, 1), + # Completely wrong shape + (randint(high=2, size=(7,)), randint(high=2, size=(7, 4)), 0.5, None, None, 1), + # Same #dims, different shape + (randint(high=2, size=(7, 3)), randint(high=2, size=(7, 4)), 0.5, None, None, 1), + # Same shape and preds floats, target not binary + (rand(size=(7, 3)), randint(low=2, high=4, size=(7, 3)), 0.5, None, None, 1), + # #dims in preds = 1 + #dims in target, C shape not second or last + (rand(size=(7, 3, 4, 3)), randint(high=4, size=(7, 3, 3)), 0.5, None, None, 1), + # #dims in preds = 1 + #dims in target, preds not float + (randint(high=2, size=(7, 3, 3, 4)), randint(high=4, size=(7, 3, 3)), 0.5, None, None, 1), + # is_multiclass=False, with C dimension > 2 + (rand(size=(7, 3, 5)), randint(high=2, size=(7, 5)), 0.5, None, False, 1), + # Max target larger or equal to C dimension + (rand(size=(7, 3)), randint(low=4, high=6, size=(7,)), 0.5, None, None, 1), + # C dimension not equal to num_classes + (rand(size=(7, 3, 4)), randint(high=4, size=(7, 3)), 0.5, 7, None, 1), + # Max target larger than num_classes (with #dim preds = 1 + #dims target) + (rand(size=(7, 3, 4)), randint(low=5, high=7, size=(7, 3)), 0.5, 4, None, 1), + # Max target larger than num_classes (with #dim preds = #dims target) + (randint(high=4, size=(7, 3)), randint(low=5, high=7, size=(7, 3)), 0.5, 4, None, 1), + # Max preds larger than num_classes (with #dim preds = #dims target) + (randint(low=5, high=7, size=(7, 3)), randint(high=4, size=(7, 3)), 0.5, 4, None, 1), + # Num_classes=1, but is_multiclass not false + (randint(high=2, size=(7,)), randint(high=2, size=(7,)), 0.5, 1, None, 1), + # is_multiclass=False, but implied class dimension (for multi-label, from shape) != num_classes + (randint(high=2, size=(7, 3, 3)), randint(high=2, size=(7, 3, 3)), 0.5, 4, False, 1), + # Multilabel input with implied class dimension != num_classes + (rand(size=(7, 3, 3)), randint(high=2, size=(7, 3, 3)), 0.5, 4, False, 1), + # Multilabel input with is_multiclass=True, but num_classes != 2 (or None) + (rand(size=(7, 3)), randint(high=2, size=(7, 3)), 0.5, 4, True, 1), + # Binary input, num_classes > 2 + (rand(size=(7,)), randint(high=2, size=(7,)), 0.5, 4, None, 1), + # Binary input, num_classes == 2 and is_multiclass not True + (rand(size=(7,)), randint(high=2, size=(7,)), 0.5, 2, None, 1), + (rand(size=(7,)), randint(high=2, size=(7,)), 0.5, 2, False, 1), + # Binary input, num_classes == 1 and is_multiclass=True + (rand(size=(7,)), randint(high=2, size=(7,)), 0.5, 1, True, 1), + # Topk > 1 with non (md)mc prob data + (_bin.preds[0], _bin.target[0], 0.5, None, None, 2), + (_bin_prob.preds[0], _bin_prob.target[0], 0.5, None, None, 2), + (_mc.preds[0], _mc.target[0], 0.5, None, None, 2), + (_ml.preds[0], _ml.target[0], 0.5, None, None, 2), + (_mlmd.preds[0], _mlmd.target[0], 0.5, None, None, 2), + (_ml_prob.preds[0], _ml_prob.target[0], 0.5, None, None, 2), + (_mlmd_prob.preds[0], _mlmd_prob.target[0], 0.5, None, None, 2), + (_mdmc.preds[0], _mdmc.target[0], 0.5, None, None, 2), + ], +) +def test_incorrect_inputs(preds, target, threshold, num_classes, is_multiclass, top_k): + with pytest.raises(ValueError): + _input_format_classification( + preds=preds, + target=target, + threshold=threshold, + num_classes=num_classes, + is_multiclass=is_multiclass, + top_k=top_k, + ) From 06790157ed025ba98950b9616ef43835f9b6c4e7 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 24 Nov 2020 21:47:30 +0100 Subject: [PATCH 02/61] Change metrics documentation layout --- docs/source/metrics.rst | 188 ++++++++++++++++++++++++++-------------- 1 file changed, 121 insertions(+), 67 deletions(-) diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index d47c872f35047..407b64d3d2948 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -196,12 +196,71 @@ Metric API .. autoclass:: pytorch_lightning.metrics.Metric :noindex: -************* -Class metrics -************* +*************************** +Class vs Functional Metrics +*************************** +The functional metrics follow the simple paradigm input in, output out. This means, they don't provide any advanced mechanisms for syncing across DDP nodes or aggregation over batches. They simply compute the metric value based on the given inputs. + +Also the integration within other parts of PyTorch Lightning will never be as tight as with the class-based interface. +If you look for just computing the values, the functional metrics are the way to go. However, if you are looking for the best integration and user experience, please consider also to use the class interface. + +********************** Classification Metrics ----------------------- +********************** + +Input types +----------- + +For the purposes of classification metrics, inputs (predictions and targets) are split +into these categories (``N`` stands for the batch size and ``C`` for number of classes): + +.. csv-table:: \*dtype ``binary`` means integers that are either 0 or 1 + :header: "Type", "preds shape", "preds dtype", "target shape", "target dtype" + :widths: 20, 10, 10, 10, 10 + + "Binary", "(N,)", "``float``", "(N,)", "``binary``\*" + "Multi-class", "(N,)", "``int``", "(N,)", "``int``" + "Multi-class with probabilities", "(N, C)", "``float``", "(N,)", "``int``" + "Multi-label", "(N, ...)", "``float``", "(N, ...)", "``binary``\*" + "Multi-dimensional multi-class", "(N, ...)", "``int``", "(N, ...)", "``int``" + "Multi-dimensional multi-class with probabilities", "(N, C, ...) or (N, ..., C)", "``float``", "(N, ...)", "``int``" + +.. note:: + All dimensions of size 1 (except ``N``) are "squeezed out" at the beginning, so + that, for example, a tensor of shape ``(N, 1)`` is treated as ``(N, )``. + +When predictions or targets are integers, it is assumed that class labels start at , i.e. +the possible class labels are 0, 1, 2, 3, etc. Below are some examples of different input types + +.. code-block:: python + + # Binary inputs + binary_preds = torch.tensor([0.6, 0.1, 0.9]) + binary_target = torch.tensor([1, 0, 2]) + + # Multi-class inputs + mc_preds = torch.tensor([0, 2, 1]) + mc_target = torch.tensor([0, 1, 2]) + + # Multi-class inputs with probabilities + mc_preds_probs = torch.tensor([[0.8, 0.2, 0], [0.1, 0.2, 0.7], [0.3, 0.6, 0.1]]) + mc_target_probs = torch.tensor([0, 1, 2]) + + # Multi-label inputs + ml_preds = torch.tensor([[0.2, 0.8, 0.9], [0.5, 0.6, 0.1], [0.3, 0.1, 0.1]]) + ml_target = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 0]]) + +In some rare cases, you might have inputs which appear to be (multi-dimensional) multi-class, +but are actually binary/multi-label. For example, if both predictions and targets are 1d +binary tensors. Or it could be the other way around, you want to treat binary/multi-label +inputs as 2-class (multi-dimensional) multi-class inputs. + +For these cases, the metrics where this distinction would make a difference, expose the +``is_multiclass`` argument. + +Class Metrics (Classification) +------------------------------ Accuracy ~~~~~~~~ @@ -239,61 +298,8 @@ ConfusionMatrix .. autoclass:: pytorch_lightning.metrics.classification.ConfusionMatrix :noindex: -Regression Metrics ------------------- - -MeanSquaredError -~~~~~~~~~~~~~~~~ - -.. autoclass:: pytorch_lightning.metrics.regression.MeanSquaredError - :noindex: - - -MeanAbsoluteError -~~~~~~~~~~~~~~~~~ - -.. autoclass:: pytorch_lightning.metrics.regression.MeanAbsoluteError - :noindex: - - -MeanSquaredLogError -~~~~~~~~~~~~~~~~~~~ - -.. autoclass:: pytorch_lightning.metrics.regression.MeanSquaredLogError - :noindex: - - -ExplainedVariance -~~~~~~~~~~~~~~~~~ - -.. autoclass:: pytorch_lightning.metrics.regression.ExplainedVariance - :noindex: - - -PSNR -~~~~ - -.. autoclass:: pytorch_lightning.metrics.regression.PSNR - :noindex: - - -SSIM -~~~~ - -.. autoclass:: pytorch_lightning.metrics.regression.SSIM - :noindex: - -****************** -Functional Metrics -****************** - -The functional metrics follow the simple paradigm input in, output out. This means, they don't provide any advanced mechanisms for syncing across DDP nodes or aggregation over batches. They simply compute the metric value based on the given inputs. - -Also the integration within other parts of PyTorch Lightning will never be as tight as with the class-based interface. -If you look for just computing the values, the functional metrics are the way to go. However, if you are looking for the best integration and user experience, please consider also to use the class interface. - -Classification --------------- +Functional Metrics (Classification) +----------------------------------- accuracy [func] ~~~~~~~~~~~~~~~ @@ -434,9 +440,57 @@ to_onehot [func] .. autofunction:: pytorch_lightning.metrics.functional.classification.to_onehot :noindex: +****************** +Regression Metrics +****************** + +Class Metrics (Regression) +-------------------------- + +MeanSquaredError +~~~~~~~~~~~~~~~~ + +.. autoclass:: pytorch_lightning.metrics.regression.MeanSquaredError + :noindex: + + +MeanAbsoluteError +~~~~~~~~~~~~~~~~~ + +.. autoclass:: pytorch_lightning.metrics.regression.MeanAbsoluteError + :noindex: + + +MeanSquaredLogError +~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: pytorch_lightning.metrics.regression.MeanSquaredLogError + :noindex: + + +ExplainedVariance +~~~~~~~~~~~~~~~~~ + +.. autoclass:: pytorch_lightning.metrics.regression.ExplainedVariance + :noindex: + + +PSNR +~~~~ + +.. autoclass:: pytorch_lightning.metrics.regression.PSNR + :noindex: + -Regression ----------- +SSIM +~~~~ + +.. autoclass:: pytorch_lightning.metrics.regression.SSIM + :noindex: + + +Functional Metrics (Regression) +------------------------------- explained_variance [func] ~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -479,22 +533,22 @@ ssim [func] .. autofunction:: pytorch_lightning.metrics.functional.ssim :noindex: - +*** NLP ---- +*** bleu_score [func] -~~~~~~~~~~~~~~~~~ +----------------- .. autofunction:: pytorch_lightning.metrics.functional.nlp.bleu_score :noindex: - +******** Pairwise --------- +******** embedding_similarity [func] -~~~~~~~~~~~~~~~~~~~~~~~~~~~ +--------------------------- .. autofunction:: pytorch_lightning.metrics.functional.self_supervised.embedding_similarity :noindex: From 35627b50092df312ed71d9ab1c828a2451306d5d Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 24 Nov 2020 22:07:40 +0100 Subject: [PATCH 03/61] Add stuff --- docs/source/metrics.rst | 13 +- pytorch_lightning/metrics/__init__.py | 1 + .../metrics/classification/__init__.py | 2 +- .../metrics/classification/accuracy.py | 162 +++++++++++++--- .../metrics/functional/__init__.py | 3 +- .../metrics/functional/accuracy.py | 155 +++++++++++++++ tests/metrics/classification/test_accuracy.py | 181 +++++++++++------- tests/metrics/utils.py | 153 +++++++-------- 8 files changed, 494 insertions(+), 176 deletions(-) create mode 100644 pytorch_lightning/metrics/functional/accuracy.py diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 407b64d3d2948..e7f89ba1b853b 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -268,6 +268,12 @@ Accuracy .. autoclass:: pytorch_lightning.metrics.classification.Accuracy :noindex: +Hamming Loss +~~~~~~~~~~~~ + +.. autoclass:: pytorch_lightning.metrics.classification.HammingLoss + :noindex: + Precision ~~~~~~~~~ @@ -304,9 +310,14 @@ Functional Metrics (Classification) accuracy [func] ~~~~~~~~~~~~~~~ -.. autofunction:: pytorch_lightning.metrics.functional.classification.accuracy +.. autofunction:: pytorch_lightning.metrics.functional.accuracy :noindex: +hamming_loss [func] +~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: pytorch_lightning.metrics.functional.hamming_loss + :noindex: auc [func] ~~~~~~~~~~ diff --git a/pytorch_lightning/metrics/__init__.py b/pytorch_lightning/metrics/__init__.py index 653ad23c68f7e..d10aa2e5e995c 100644 --- a/pytorch_lightning/metrics/__init__.py +++ b/pytorch_lightning/metrics/__init__.py @@ -15,6 +15,7 @@ from pytorch_lightning.metrics.classification import ( Accuracy, + HammingLoss, Precision, Recall, FBeta, diff --git a/pytorch_lightning/metrics/classification/__init__.py b/pytorch_lightning/metrics/classification/__init__.py index db643c227abed..45d4dd03e430e 100644 --- a/pytorch_lightning/metrics/classification/__init__.py +++ b/pytorch_lightning/metrics/classification/__init__.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning.metrics.classification.accuracy import Accuracy +from pytorch_lightning.metrics.classification.accuracy import Accuracy, HammingLoss from pytorch_lightning.metrics.classification.precision_recall import Precision, Recall from pytorch_lightning.metrics.classification.f_beta import FBeta, F1 from pytorch_lightning.metrics.classification.confusion_matrix import ConfusionMatrix diff --git a/pytorch_lightning/metrics/classification/accuracy.py b/pytorch_lightning/metrics/classification/accuracy.py index 0f01fb9813407..8297cc73f9540 100644 --- a/pytorch_lightning/metrics/classification/accuracy.py +++ b/pytorch_lightning/metrics/classification/accuracy.py @@ -11,38 +11,54 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import math -import functools -from abc import ABC, abstractmethod -from typing import Any, Callable, Optional, Union -from collections.abc import Mapping, Sequence -from collections import namedtuple +from typing import Any, Callable, Optional import torch -from torch import nn from pytorch_lightning.metrics.metric import Metric -from pytorch_lightning.metrics.utils import _input_format_classification +from pytorch_lightning.metrics.functional.accuracy import ( + _accuracy_update, + _hamming_loss_update, + _accuracy_compute, + _hamming_loss_compute, +) class Accuracy(Metric): """ - Computes accuracy. Works with binary, multiclass, and multilabel data. - Accepts logits from a model output or integer class values in prediction. - Works with multi-dimensional preds and target. + Computes the share of entirely correctly predicted samples. - Forward accepts + This metric generalizes to subset accuracy for multilabel data, and similarly for + multi-dimensional multi-class data: for the sample to be counted as correct, the the + class has to be correctly predicted across all extra dimension for each sample in the + ``N`` dimension. Consider using :class:`~pytorch_lightning.metrics.classification.HammingLoss` + is this is not what you want. - - ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes - - ``target`` (long tensor): ``(N, ...)`` + For multi-class and multi-dimensional multi-class data with probability predictions, the + parameter ``top_k`` generalizes this metric to a Top-K accuracy metric: for each sample the + top-K highest probability items are considered to find the correct label. - If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument. - This is the case for binary and multi-label logits. - - If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``. + Accepts all input types listed in :ref:`metrics:Input types`. Args: + top_k: + Number of highest probability predictions considered to find the correct label, for + (multi-dimensional) multi-class inputs with probability predictions. Default 1 + + If your inputs are not (multi-dimensional) multi-class inputs with probability predictions, + an error will be raised if ``top_k`` is set to a value other than 1. + mdmc_accuracy: + Determines how should the extra dimension be handeled in case of multi-dimensional multi-class + inputs. Options are ``"global"`` or ``"subset"``. + + If ``"global"``, then the inputs are treated as if the sample (``N``) and the extra dimension + were unrolled into a new sample dimension. + + If ``"subset"``, than the equivalent of subset accuracy is performed for each sample on the + ``N`` dimension - that is, for the sample to count as correct, all labels on its extra dimension + must be predicted correctly (the ``top_k`` option still applies here). threshold: - Threshold value for binary or multi-label logits. default: 0.5 + Threshold probability value for transforming probability predictions to binary + (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 compute_on_step: Forward only calls ``update()`` and return None if this is set to False. default: True dist_sync_on_step: @@ -63,7 +79,97 @@ class Accuracy(Metric): >>> accuracy(preds, target) tensor(0.5000) + >>> target = torch.tensor([0, 1, 2]) + >>> preds = torch.tensor([[0.1, 0.9, 0], [0.3, 0.1, 0.6], [0.2, 0.5, 0.3]]) + >>> accuracy = Accuracy(top_k=2) + >>> accuracy(preds, target) + tensor(0.6667) + """ + + def __init__( + self, + top_k: int = 1, + mdmc_accuracy: str = "subset", + threshold: float = 0.5, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, + ): + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + dist_sync_fn=dist_sync_fn, + ) + + self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum") + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + + self.threshold = threshold + self.top_k = top_k + self.mdmc_accuracy = mdmc_accuracy + + def update(self, preds: torch.Tensor, target: torch.Tensor): + """ + Update state with predictions and targets. See :ref:`metrics:Input types` for more information + on input types. + + Args: + preds: Predictions from model (probabilities, or labels) + target: Ground truth values + """ + + correct, total = _accuracy_update(preds, target, self.threshold, self.top_k, self.mdmc_accuracy) + + self.correct += correct + self.total += total + + def compute(self) -> torch.Tensor: + """ + Computes accuracy based on inputs passed in to ``update`` previously. + """ + return _accuracy_compute(self.correct, self.total) + + +class HammingLoss(Metric): + """ + Computes the share of wrongly predicted labels. + + This is the same as ``1-accuracy`` for binary data, while for all other types of inputs it + treats each possible label separately - meaning that, for example, multi-class data is + treated as if it were multi-label. If this is not what you want, consider using + :class:`~pytorch_lightning.metrics.classification.Accuracy`. + + Accepts all input types listed in :ref:`metrics:Input types`. + + Args: + threshold: + Threshold probability value for transforming probability predictions to binary + (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + compute_on_step: + Forward only calls ``update()`` and return None if this is set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. default: False + process_group: + Specify the process group on which synchronization is called. default: None (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. When `None`, DDP + will be used to perform the allgather. default: None + + Example: + + >>> from pytorch_lightning.metrics import HammingLoss + >>> target = torch.tensor([[0, 1], [1, 1]]) + >>> preds = torch.tensor([[0, 1], [0, 1]]) + >>> hamming_loss = HammingLoss() + >>> hamming_loss(preds, target) + tensor(0.2500) + + """ + def __init__( self, threshold: float = 0.5, @@ -86,20 +192,20 @@ def __init__( def update(self, preds: torch.Tensor, target: torch.Tensor): """ - Update state with predictions and targets. + Update state with predictions and targets. See :ref:`metrics:Input types` for more information + on input types. Args: - preds: Predictions from model + preds: Predictions from model (probabilities, or labels) target: Ground truth values """ - preds, target = _input_format_classification(preds, target, self.threshold) - assert preds.shape == target.shape + correct, total = _hamming_loss_update(preds, target, self.threshold) - self.correct += torch.sum(preds == target) - self.total += target.numel() + self.correct += correct + self.total += total - def compute(self): + def compute(self) -> torch.Tensor: """ - Computes accuracy over state. + Computes hamming loss based on inputs passed in to ``update`` previously. """ - return self.correct.float() / self.total + return _hamming_loss_compute(self.correct, self.total) diff --git a/pytorch_lightning/metrics/functional/__init__.py b/pytorch_lightning/metrics/functional/__init__.py index 3bb5313db7b27..42029335afe9f 100644 --- a/pytorch_lightning/metrics/functional/__init__.py +++ b/pytorch_lightning/metrics/functional/__init__.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from pytorch_lightning.metrics.functional.classification import ( - accuracy, auc, auroc, average_precision, @@ -42,5 +41,7 @@ from pytorch_lightning.metrics.functional.mean_squared_log_error import mean_squared_log_error from pytorch_lightning.metrics.functional.psnr import psnr from pytorch_lightning.metrics.functional.ssim import ssim + +from pytorch_lightning.metrics.functional.accuracy import accuracy, hamming_loss from pytorch_lightning.metrics.functional.confusion_matrix import confusion_matrix from pytorch_lightning.metrics.functional.f_beta import fbeta, f1 diff --git a/pytorch_lightning/metrics/functional/accuracy.py b/pytorch_lightning/metrics/functional/accuracy.py new file mode 100644 index 0000000000000..f9861388cceda --- /dev/null +++ b/pytorch_lightning/metrics/functional/accuracy.py @@ -0,0 +1,155 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Tuple, Union + +import torch +from pytorch_lightning.metrics.classification.utils import _input_format_classification + +################################ +# Accuracy +################################ + + +def _accuracy_update( + preds: torch.Tensor, target: torch.Tensor, threshold: float, top_k: int, mdmc_accuracy: str +) -> Tuple[torch.Tensor, torch.Tensor]: + + preds, target, mode = _input_format_classification(preds, target, threshold=threshold, top_k=top_k) + + if mode in ["binary", "multi-label"]: + correct = (preds == target).all(dim=1).sum() + total = target.shape[0] + elif mdmc_accuracy == "global": + correct = (preds * target).sum() + total = target.sum() + elif mdmc_accuracy == "subset": + extra_dims = list(range(1, len(preds.shape))) + sample_correct = (preds * target).sum(dim=extra_dims) + sample_total = target.sum(dim=extra_dims) + + correct = (sample_correct == sample_total).sum() + total = target.shape[0] + + return (torch.tensor(correct, device=preds.device), torch.tensor(total, device=preds.device)) + + +def _accuracy_compute(correct: torch.Tensor, total: torch.Tensor) -> torch.Tensor: + return correct / total + + +def accuracy( + preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5, top_k: int = 1, mdmc_accuracy: str = "subset" +) -> torch.Tensor: + """ + Computes the share of entirely correctly predicted samples. + + This metric generalizes to subset accuracy for multilabel data, and similarly for + multi-dimensional multi-class data: for the sample to be counted as correct, the the + class has to be correctly predicted across all extra dimension for each sample in the + ``N`` dimension. Consider using :class:`~pytorch_lightning.metrics.classification.HammingLoss` + is this is not what you want. + + For multi-class and multi-dimensional multi-class data with probability predictions, the + parameter ``top_k`` generalizes this metric to a Top-K accuracy metric: for each sample the + top-K highest probability items are considered to find the correct label. + + Accepts all input types listed in :ref:`metrics:Input types`. + + Args: + preds: Predictions from model (probabilities, or labels) + target: Ground truth values + top_k: + Number of highest probability predictions considered to find the correct label, for + (multi-dimensional) multi-class inputs with probability predictions. Default 1 + + If your inputs are not (multi-dimensional) multi-class inputs with probability predictions, + an error will be raised if ``top_k`` is set to a value other than 1. + mdmc_accuracy: + Determines how should the extra dimension be handeled in case of multi-dimensional multi-class + inputs. Options are ``"global"`` or ``"subset"``. + + If ``"global"``, then the inputs are treated as if the sample (``N``) and the extra dimension + were unrolled into a new sample dimension. + + If ``"subset"``, than the equivalent of subset accuracy is performed for each sample on the + ``N`` dimension - that is, for the sample to count as correct, all labels on its extra dimension + must be predicted correctly (the ``top_k`` option still applies here). + threshold: + Threshold probability value for transforming probability predictions to binary + (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + + Example: + + >>> from pytorch_lightning.metrics.functional import accuracy + >>> target = torch.tensor([0, 1, 2, 3]) + >>> preds = torch.tensor([0, 2, 1, 3]) + >>> accuracy(preds, target) + tensor(0.5000) + + >>> target = torch.tensor([0, 1, 2]) + >>> preds = torch.tensor([[0.1, 0.9, 0], [0.3, 0.1, 0.6], [0.2, 0.5, 0.3]]) + >>> accuracy(preds, target, top_k=2) + tensor(0.6667) + """ + + correct, total = _accuracy_update(preds, target, threshold, top_k, mdmc_accuracy) + return _accuracy_compute(correct, total) + + +################################ +# Hamming loss +################################ + + +def _hamming_loss_update(preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5) -> Tuple[torch.Tensor, int]: + preds, target, _ = _input_format_classification(preds, target, threshold=threshold) + + correct = (preds == target).sum() + total = preds.numel() + + return correct, total + + +def _hamming_loss_compute(correct: torch.Tensor, total: Union[int, torch.Tensor]) -> torch.Tensor: + return 1 - correct.float() / total + + +def hamming_loss(preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5) -> torch.Tensor: + """ + Computes the share of wrongly predicted labels. + + This is the same as ``1-accuracy`` for binary data, while for all other types of inputs it + treats each possible label separately - meaning that, for example, multi-class data is + treated as if it were multi-label. If this is not what you want, consider using + :class:`~pytorch_lightning.metrics.classification.Accuracy`. + + Accepts all input types listed in :ref:`metrics:Input types`. + + Args: + threshold: + Threshold probability value for transforming probability predictions to binary + (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + + Example: + + >>> from pytorch_lightning.metrics.functional import hamming_loss + >>> target = torch.tensor([[0, 1], [1, 1]]) + >>> preds = torch.tensor([[0, 1], [0, 1]]) + >>> hamming_loss(preds, target) + tensor(0.2500) + + """ + + correct, total = _hamming_loss_update(preds, target, threshold) + return _hamming_loss_compute(correct, total) diff --git a/tests/metrics/classification/test_accuracy.py b/tests/metrics/classification/test_accuracy.py index 017438269bdbf..8d8386683748c 100644 --- a/tests/metrics/classification/test_accuracy.py +++ b/tests/metrics/classification/test_accuracy.py @@ -1,9 +1,11 @@ import numpy as np import pytest import torch -from sklearn.metrics import accuracy_score +from sklearn.metrics import accuracy_score as sk_accuracy, hamming_loss as sk_hamming_loss -from pytorch_lightning.metrics.classification.accuracy import Accuracy +from pytorch_lightning.metrics import Accuracy, HammingLoss +from pytorch_lightning.metrics.functional import accuracy, hamming_loss +from pytorch_lightning.metrics.classification.utils import _input_format_classification from tests.metrics.classification.inputs import ( _binary_inputs, _binary_prob_inputs, @@ -13,101 +15,140 @@ _multidim_multiclass_prob_inputs, _multilabel_inputs, _multilabel_prob_inputs, + _multilabel_multidim_prob_inputs, + _multilabel_multidim_inputs, ) from tests.metrics.utils import THRESHOLD, MetricTester torch.manual_seed(42) -def _sk_accuracy_binary_prob(preds, target): - sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) - sk_target = target.view(-1).numpy() +def _sk_accuracy(preds, target): + sk_preds, sk_target, _ = _input_format_classification(preds, target, threshold=THRESHOLD) + sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy() + sk_preds, sk_target = sk_preds.reshape(sk_preds.shape[0], -1), sk_target.reshape(sk_target.shape[0], -1) - return accuracy_score(y_true=sk_target, y_pred=sk_preds) + return sk_accuracy(y_true=sk_target, y_pred=sk_preds) -def _sk_accuracy_binary(preds, target): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() +def _sk_hamming_loss(preds, target): + sk_preds, sk_target, _ = _input_format_classification(preds, target, threshold=THRESHOLD) + sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy() + sk_preds, sk_target = sk_preds.reshape(sk_preds.shape[0], -1), sk_target.reshape(sk_target.shape[0], -1) - return accuracy_score(y_true=sk_target, y_pred=sk_preds) + return sk_hamming_loss(y_true=sk_target, y_pred=sk_preds) -def _sk_accuracy_multilabel_prob(preds, target): - sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) - sk_target = target.view(-1).numpy() - - return accuracy_score(y_true=sk_target, y_pred=sk_preds) - - -def _sk_accuracy_multilabel(preds, target): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - - return accuracy_score(y_true=sk_target, y_pred=sk_preds) - - -def _sk_accuracy_multiclass_prob(preds, target): - sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy() - sk_target = target.view(-1).numpy() +@pytest.mark.parametrize( + "metric, fn_metric, sk_metric", [(Accuracy, accuracy, _sk_accuracy), (HammingLoss, hamming_loss, _sk_hamming_loss)] +) +@pytest.mark.parametrize( + "preds, target", + [ + (_binary_prob_inputs.preds, _binary_prob_inputs.target), + (_binary_inputs.preds, _binary_inputs.target), + (_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target), + (_multilabel_inputs.preds, _multilabel_inputs.target), + (_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target), + (_multiclass_inputs.preds, _multiclass_inputs.target), + (_multidim_multiclass_prob_inputs.preds, _multidim_multiclass_prob_inputs.target), + (_multidim_multiclass_inputs.preds, _multidim_multiclass_inputs.target), + (_multilabel_multidim_prob_inputs.preds, _multilabel_multidim_prob_inputs.target), + (_multilabel_multidim_inputs.preds, _multilabel_multidim_inputs.target), + ], +) +class TestAccuracies(MetricTester): + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [False, True]) + def test_accuracy_class(self, ddp, dist_sync_on_step, preds, target, metric, sk_metric, fn_metric): + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=metric, + sk_metric=sk_metric, + dist_sync_on_step=dist_sync_on_step, + metric_args={"threshold": THRESHOLD}, + ) - return accuracy_score(y_true=sk_target, y_pred=sk_preds) + def test_accuracy_fn(self, preds, target, metric, sk_metric, fn_metric): + self.run_functional_metric_test( + preds, + target, + metric_functional=fn_metric, + sk_metric=sk_metric, + metric_args={"threshold": THRESHOLD}, + ) -def _sk_accuracy_multiclass(preds, target): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() +l1to4 = [.1, .2, .3, .4] +l1to4t3 = np.array([l1to4, l1to4, l1to4]) +l1to4t3_mc = [l1to4t3.T, l1to4t3.T, l1to4t3.T] - return accuracy_score(y_true=sk_target, y_pred=sk_preds) +# The preds in these examples always put highest probability on class 3, second highest on class 2, +# third highest on class 1, and lowest on class 0 +topk_preds_mc = torch.tensor([l1to4t3, l1to4t3]).float() +topk_target_mc = torch.tensor([[1, 2, 3], [2, 1, 0]]) +# This is like for MC case, but one sample in each batch is sabotaged with 0 class prediction :) +topk_preds_mdmc = torch.tensor([l1to4t3_mc, l1to4t3_mc]).float() +topk_target_mdmc = torch.tensor([[[1, 1, 0], [2, 2, 2], [3, 3, 3]], [[2, 2, 0], [1, 1, 1], [0, 0, 0]]]) -def _sk_accuracy_multidim_multiclass_prob(preds, target): - sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy() - sk_target = target.view(-1).numpy() - return accuracy_score(y_true=sk_target, y_pred=sk_preds) +# Replace with a proper sk_metric test once sklearn 0.24 hits :) +@pytest.mark.parametrize( + "preds, target, exp_result, k, mdmc_accuracy", + [ + (topk_preds_mc, topk_target_mc, 1 / 6, 1, "global"), + (topk_preds_mc, topk_target_mc, 3 / 6, 2, "global"), + (topk_preds_mc, topk_target_mc, 5 / 6, 3, "global"), + (topk_preds_mc, topk_target_mc, 1 / 6, 1, "subset"), + (topk_preds_mc, topk_target_mc, 3 / 6, 2, "subset"), + (topk_preds_mc, topk_target_mc, 5 / 6, 3, "subset"), + (topk_preds_mdmc, topk_target_mdmc, 1 / 6, 1, "global"), + (topk_preds_mdmc, topk_target_mdmc, 8 / 18, 2, "global"), + (topk_preds_mdmc, topk_target_mdmc, 13 / 18, 3, "global"), + (topk_preds_mdmc, topk_target_mdmc, 1 / 6, 1, "subset"), + (topk_preds_mdmc, topk_target_mdmc, 2 / 6, 2, "subset"), + (topk_preds_mdmc, topk_target_mdmc, 3 / 6, 3, "subset"), + ], +) +def test_topk_accuracy(preds, target, exp_result, k, mdmc_accuracy): + topk = Accuracy(top_k=k, mdmc_accuracy=mdmc_accuracy) + for batch in range(preds.shape[0]): + topk(preds[batch], target[batch]) -def _sk_accuracy_multidim_multiclass(preds, target): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() + assert topk.compute() == exp_result - return accuracy_score(y_true=sk_target, y_pred=sk_preds) + # Test functional + total_samples = target.shape[0] * target.shape[1] + preds = preds.view(total_samples, 4, -1) + target = target.view(total_samples, -1) -def test_accuracy_invalid_shape(): - with pytest.raises(ValueError): - acc = Accuracy() - acc.update(preds=torch.rand(1), target=torch.rand(1, 2, 3)) + assert accuracy(preds, target, top_k=k, mdmc_accuracy=mdmc_accuracy) == exp_result -@pytest.mark.parametrize("ddp", [True, False]) -@pytest.mark.parametrize("dist_sync_on_step", [True, False]) +# Only MC and MDMC with probs input type should be accepted @pytest.mark.parametrize( - "preds, target, sk_metric", + "preds, target", [ - (_binary_prob_inputs.preds, _binary_prob_inputs.target, _sk_accuracy_binary_prob), - (_binary_inputs.preds, _binary_inputs.target, _sk_accuracy_binary), - (_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target, _sk_accuracy_multilabel_prob), - (_multilabel_inputs.preds, _multilabel_inputs.target, _sk_accuracy_multilabel), - (_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target, _sk_accuracy_multiclass_prob), - (_multiclass_inputs.preds, _multiclass_inputs.target, _sk_accuracy_multiclass), - ( - _multidim_multiclass_prob_inputs.preds, - _multidim_multiclass_prob_inputs.target, - _sk_accuracy_multidim_multiclass_prob, - ), - (_multidim_multiclass_inputs.preds, _multidim_multiclass_inputs.target, _sk_accuracy_multidim_multiclass), + (_binary_prob_inputs.preds, _binary_prob_inputs.target), + (_binary_inputs.preds, _binary_inputs.target), + (_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target), + (_multilabel_inputs.preds, _multilabel_inputs.target), + (_multiclass_inputs.preds, _multiclass_inputs.target), + (_multidim_multiclass_inputs.preds, _multidim_multiclass_inputs.target), + (_multilabel_multidim_prob_inputs.preds, _multilabel_multidim_prob_inputs.target), + (_multilabel_multidim_inputs.preds, _multilabel_multidim_inputs.target), ], ) -class TestAccuracy(MetricTester): - def test_accuracy(self, ddp, dist_sync_on_step, preds, target, sk_metric): - self.run_class_metric_test( - ddp=ddp, - preds=preds, - target=target, - metric_class=Accuracy, - sk_metric=sk_metric, - dist_sync_on_step=dist_sync_on_step, - metric_args={"threshold": THRESHOLD}, - ) +def test_topk_accuracy_wrong_input_types(preds, target): + topk = Accuracy(top_k=2) + + with pytest.raises(ValueError): + topk(preds[0], target[0]) + + with pytest.raises(ValueError): + accuracy(preds[0], target[0], top_k=2) diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py index 34abee8473863..b0010916b6476 100644 --- a/tests/metrics/utils.py +++ b/tests/metrics/utils.py @@ -21,10 +21,10 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ - os.environ["MASTER_ADDR"] = 'localhost' - os.environ['MASTER_PORT'] = '8088' + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "8088" - if torch.distributed.is_available() and sys.platform not in ('win32', 'cygwin'): + if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) @@ -41,23 +41,23 @@ def _class_test( check_batch: bool = True, atol: float = 1e-8, ): - """ Utility function doing the actual comparison between lightning class metric - and reference metric. - - Args: - rank: rank of current process - worldsize: number of processes - preds: torch tensor with predictions - target: torch tensor with targets - metric_class: lightning metric class that should be tested - sk_metric: callable function that is used for comparison - dist_sync_on_step: bool, if true will synchronize metric state across - processes at each ``forward()`` - metric_args: dict with additional arguments used for class initialization - check_dist_sync_on_step: bool, if true will check if the metric is also correctly - calculated per batch per device (and not just at the end) - check_batch: bool, if true will check if the metric is also correctly - calculated across devices for each batch (and not just at the end) + """Utility function doing the actual comparison between lightning class metric + and reference metric. + + Args: + rank: rank of current process + worldsize: number of processes + preds: torch tensor with predictions + target: torch tensor with targets + metric_class: lightning metric class that should be tested + sk_metric: callable function that is used for comparison + dist_sync_on_step: bool, if true will synchronize metric state across + processes at each ``forward()`` + metric_args: dict with additional arguments used for class initialization + check_dist_sync_on_step: bool, if true will check if the metric is also correctly + calculated per batch per device (and not just at the end) + check_batch: bool, if true will check if the metric is also correctly + calculated across devices for each batch (and not just at the end) """ # Instanciate lightning metric metric = metric_class(compute_on_step=True, dist_sync_on_step=dist_sync_on_step, **metric_args) @@ -71,8 +71,8 @@ def _class_test( if metric.dist_sync_on_step: if rank == 0: - ddp_preds = torch.stack([preds[i + r] for r in range(worldsize)]) - ddp_target = torch.stack([target[i + r] for r in range(worldsize)]) + ddp_preds = torch.cat([preds[i + r] for r in range(worldsize)]) + ddp_target = torch.cat([target[i + r] for r in range(worldsize)]) sk_batch_result = sk_metric(ddp_preds, ddp_target) # assert for dist_sync_on_step if check_dist_sync_on_step: @@ -87,8 +87,8 @@ def _class_test( result = metric.compute() assert isinstance(result, torch.Tensor) - total_preds = torch.stack([preds[i] for i in range(NUM_BATCHES)]) - total_target = torch.stack([target[i] for i in range(NUM_BATCHES)]) + total_preds = torch.cat([preds[i] for i in range(NUM_BATCHES)]) + total_target = torch.cat([target[i] for i in range(NUM_BATCHES)]) sk_result = sk_metric(total_preds, total_target) # assert after aggregation @@ -101,17 +101,17 @@ def _functional_test( metric_functional: Callable, sk_metric: Callable, metric_args: dict = {}, - atol: float = 1e-8 + atol: float = 1e-8, ): - """ Utility function doing the actual comparison between lightning functional metric - and reference metric. - - Args: - preds: torch tensor with predictions - target: torch tensor with targets - metric_functional: lightning metric functional that should be tested - sk_metric: callable function that is used for comparison - metric_args: dict with additional arguments used for class initialization + """Utility function doing the actual comparison between lightning functional metric + and reference metric. + + Args: + preds: torch tensor with predictions + target: torch tensor with targets + metric_functional: lightning metric functional that should be tested + sk_metric: callable function that is used for comparison + metric_args: dict with additional arguments used for class initialization """ metric = partial(metric_functional, **metric_args) @@ -124,22 +124,23 @@ def _functional_test( class MetricTester: - """ Class used for efficiently run alot of parametrized tests in ddp mode. - Makes sure that ddp is only setup once and that pool of processes are - used for all tests. + """Class used for efficiently run alot of parametrized tests in ddp mode. + Makes sure that ddp is only setup once and that pool of processes are + used for all tests. - All tests should subclass from this and implement a new method called - `test_metric_name` - where the method `self.run_metric_test` is called inside. + All tests should subclass from this and implement a new method called + `test_metric_name` + where the method `self.run_metric_test` is called inside. """ + atol = 1e-8 def setup_class(self): - """ Setup the metric class. This will spawn the pool of workers that are - used for metric testing and setup_ddp + """Setup the metric class. This will spawn the pool of workers that are + used for metric testing and setup_ddp """ try: - set_start_method('spawn') + set_start_method("spawn") except RuntimeError: pass self.poolSize = NUM_PROCESSES @@ -157,24 +158,26 @@ def run_functional_metric_test( target: torch.Tensor, metric_functional: Callable, sk_metric: Callable, - metric_args: dict = {} + metric_args: dict = {}, ): - """ Main method that should be used for testing functions. Call this inside - testing method - - Args: - preds: torch tensor with predictions - target: torch tensor with targets - metric_functional: lightning metric class that should be tested - sk_metric: callable function that is used for comparison - metric_args: dict with additional arguments used for class initialization + """Main method that should be used for testing functions. Call this inside + testing method + + Args: + preds: torch tensor with predictions + target: torch tensor with targets + metric_functional: lightning metric class that should be tested + sk_metric: callable function that is used for comparison + metric_args: dict with additional arguments used for class initialization """ - _functional_test(preds=preds, - target=target, - metric_functional=metric_functional, - sk_metric=sk_metric, - metric_args=metric_args, - atol=self.atol) + _functional_test( + preds=preds, + target=target, + metric_functional=metric_functional, + sk_metric=sk_metric, + metric_args=metric_args, + atol=self.atol, + ) def run_class_metric_test( self, @@ -188,22 +191,22 @@ def run_class_metric_test( check_dist_sync_on_step: bool = True, check_batch: bool = True, ): - """ Main method that should be used for testing class. Call this inside testing - methods. - - Args: - ddp: bool, if running in ddp mode or not - preds: torch tensor with predictions - target: torch tensor with targets - metric_class: lightning metric class that should be tested - sk_metric: callable function that is used for comparison - dist_sync_on_step: bool, if true will synchronize metric state across - processes at each ``forward()`` - metric_args: dict with additional arguments used for class initialization - check_dist_sync_on_step: bool, if true will check if the metric is also correctly - calculated per batch per device (and not just at the end) - check_batch: bool, if true will check if the metric is also correctly - calculated across devices for each batch (and not just at the end) + """Main method that should be used for testing class. Call this inside testing + methods. + + Args: + ddp: bool, if running in ddp mode or not + preds: torch tensor with predictions + target: torch tensor with targets + metric_class: lightning metric class that should be tested + sk_metric: callable function that is used for comparison + dist_sync_on_step: bool, if true will synchronize metric state across + processes at each ``forward()`` + metric_args: dict with additional arguments used for class initialization + check_dist_sync_on_step: bool, if true will check if the metric is also correctly + calculated per batch per device (and not just at the end) + check_batch: bool, if true will check if the metric is also correctly + calculated across devices for each batch (and not just at the end) """ if ddp: if sys.platform == "win32": From 0282f3c74c44aaeb04fb8036d909514f009faf70 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 24 Nov 2020 22:40:33 +0100 Subject: [PATCH 04/61] Add stat scores --- docs/source/metrics.rst | 67 ++++- pytorch_lightning/metrics/__init__.py | 3 +- .../metrics/classification/__init__.py | 1 + .../metrics/classification/stat_scores.py | 230 +++++++++++++++ .../metrics/functional/__init__.py | 3 +- .../metrics/functional/stat_scores.py | 274 ++++++++++++++++++ .../classification/test_stat_scores.py | 0 tests/metrics/utils.py | 161 +++++----- 8 files changed, 648 insertions(+), 91 deletions(-) create mode 100644 pytorch_lightning/metrics/classification/stat_scores.py create mode 100644 pytorch_lightning/metrics/functional/stat_scores.py create mode 100644 tests/metrics/classification/test_stat_scores.py diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 407b64d3d2948..01d95436ac30e 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -257,7 +257,56 @@ binary tensors. Or it could be the other way around, you want to treat binary/mu inputs as 2-class (multi-dimensional) multi-class inputs. For these cases, the metrics where this distinction would make a difference, expose the -``is_multiclass`` argument. +``is_multiclass`` argument. Let's see how this is used with the +:class:`~pytorch_lightning.metrics.classification.StatScores` metric. + +.. testcode:: + :skipif: True + + from pytorch_lightning.metrics import StatScores + + # These inputs are supposed to be binary, but appear as multi-class + mc_binary_preds = torch.tensor([0,1,0]) + mc_binary_target = torch.tensor([1,1,0]) + +First, let's check that what happens usually - without setting ``is_multiclass`` flag. + +.. testcode:: + :skipif: True + + # Treating inputs as they appear (multi-class) + stat_scores_mc = StatScores(average='none', num_classes=2) + stat_scores_mc(mc_binary_preds, mc_binary_target) + +Out: + +.. testoutput:: + :skipif: True + + torch.tensor([[1, 1, 1, 0, 1]], + [[1, 0, 1, 1, 2]], dtype=torch.int32) + +As expected, the metric interpreted the inputs as 2 class multi-class inputs (note that if +change ``num_classes`` above to anything but 2 we would get an error). Now let's see what +happens when we set ``is_multiclass=False``: + +.. testcode:: + :skipif: True + + # Treating inputs as binary + stat_scores_binary = StatScores(average='none', num_classes=1, + is_multiclass=False) + stat_scores_binary(mc_binary_preds, mc_binary_target) + +Out: + +.. testoutput:: + :skipif: True + + torch.tensor([[1, 0, 1, 1, 2]], dtype=torch.int32) + +Now the metric correctly interpreted the inputs as binary, and thus returned result +only for one class. Class Metrics (Classification) ------------------------------ @@ -298,6 +347,13 @@ ConfusionMatrix .. autoclass:: pytorch_lightning.metrics.classification.ConfusionMatrix :noindex: +StatScores +~~~~~~~~~~ + +.. autoclass:: pytorch_lightning.metrics.classification.StatScores + :noindex: + + Functional Metrics (Classification) ----------------------------------- @@ -416,14 +472,7 @@ roc [func] stat_scores [func] ~~~~~~~~~~~~~~~~~~ -.. autofunction:: pytorch_lightning.metrics.functional.classification.stat_scores - :noindex: - - -stat_scores_multiple_classes [func] -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.classification.stat_scores_multiple_classes +.. autofunction:: pytorch_lightning.metrics.functional.stat_scores :noindex: diff --git a/pytorch_lightning/metrics/__init__.py b/pytorch_lightning/metrics/__init__.py index 653ad23c68f7e..2fe01845719a2 100644 --- a/pytorch_lightning/metrics/__init__.py +++ b/pytorch_lightning/metrics/__init__.py @@ -19,7 +19,8 @@ Recall, FBeta, F1, - ConfusionMatrix + ConfusionMatrix, + StatScores ) from pytorch_lightning.metrics.regression import ( diff --git a/pytorch_lightning/metrics/classification/__init__.py b/pytorch_lightning/metrics/classification/__init__.py index db643c227abed..0bf82c7201231 100644 --- a/pytorch_lightning/metrics/classification/__init__.py +++ b/pytorch_lightning/metrics/classification/__init__.py @@ -15,3 +15,4 @@ from pytorch_lightning.metrics.classification.precision_recall import Precision, Recall from pytorch_lightning.metrics.classification.f_beta import FBeta, F1 from pytorch_lightning.metrics.classification.confusion_matrix import ConfusionMatrix +from pytorch_lightning.metrics.classification.stat_scores import StatScores diff --git a/pytorch_lightning/metrics/classification/stat_scores.py b/pytorch_lightning/metrics/classification/stat_scores.py new file mode 100644 index 0000000000000..bfb4058d7e898 --- /dev/null +++ b/pytorch_lightning/metrics/classification/stat_scores.py @@ -0,0 +1,230 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Any, Callable + +import torch +from pytorch_lightning.metrics.utils import dim_zero_cat +from pytorch_lightning.metrics import Metric +from pytorch_lightning.metrics.functional.stat_scores import _stat_scores_update, _stat_scores_compute + + +def _dim_zero_cat_and_put_back(tensor: torch.Tensor): + """ Needed as we don't need the process dimension in sync reduce """ + + out = dim_zero_cat(tensor) + out = out.reshape(-1, *out.shape[2:]) + + return out + + +class StatScores(Metric): + """Computes the number of true positives, false positives, true negatives, false negatives. + + The reduction method (how the statistics are aggregated) is controlled by the + ``reduce`` parameter, and additionally by the ``mdmc_reduce`` parameter in the + multi-dimensional multi-class case. Accepts all inputs listed in :ref:`metrics:Input types`. + + Args: + reduce: + Defines the reduction that is applied. Should be one of the following: + + - ``'micro'`` [default]: Counts the statistics by summing over all [sample, class] + combinations (globally). Produces a one element tensor for each statistic. + - ``'macro'``: Counts the statistics for each class separately (over all samples). + Produces a ``(C, )`` 1d tensor. Requires ``num_classes`` to be set. + - ``'samples'``: Counts the statistics for each sample separately (over all classes). + Produces a ``(N, )`` 1d tensor. + + Note that what is considered a sample in the multi-dimensional multi-class case + depends on the value of ``mdmc_reduce``. + + mdmc_reduce: + Defines how the multi-dimensional multi-class inputs are handeled. Should be + one of the following: + + - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional + multi-class. + + - ``'samplewise'``: In this case, the statistics are computed separately for each + sample on the ``N`` axis, and then concatenating the outputs together. This is + done by, for each sample, treating the flattened extra axes ``...`` (see + :ref:`metrics:Input types`) as the ``N`` dimension within the sample, and computing + the statistics for the sample based on that. + + - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs (see :ref:`metrics:Input types`) + are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they + were ``(N_X, C)``. From here on the ``reduce`` parameter applies as usual. + + num_classes: + Number of classes. Necessary for (multi-dimensional) multi-class or multi-label data. + + threshold: + Threshold probability value for transforming probability predictions to binary + (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + is_multiclass: + If ``False``, treat multi-class and multi-dim multi-class inputs with 1 or 2 classes as + binary and multi-label, respectively. If ``True``, treat binary and multi-label inputs + as multi-class or multi-dim multi-class with 2 classes, respectively. + Defaults to ``None``, which treats inputs as they appear. + ignore_index: + Integer specifying a target class to ignore. If given, this class index does not contribute + to the returned score, regardless of reduction method. Has no effect if given an int that + is not in the range ``[0, C-1]``, or if ``C=1``, where ``C`` is the number of classes. + + If an index is ignored, and ``reduce='macro'``, the class statistics for the ignored + class will all be returned as ``nan`` (to not break the indexing of other labels). + compute_on_step: + Forward only calls ``update()`` and return None if this is set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. default: False + process_group: + Specify the process group on which synchronization is called. default: None (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. When `None`, DDP + will be used to perform the allgather. default: None + + Example: + + >>> from pytorch_lightning.metrics.classification import StatScores + >>> preds = torch.tensor([1, 0, 2, 1]) + >>> target = torch.tensor([1, 1, 2, 0]) + >>> stat_scores = StatScores(reduce='macro', num_classes=3) + >>> stat_scores(preds, target) + tensor([[0, 1, 2, 1, 1], + [1, 1, 1, 1, 2], + [1, 0, 3, 0, 1]]) + >>> stat_scores = StatScores(reduce='micro') + >>> stat_scores(preds, target) + tensor([2, 2, 6, 2, 4]) + + """ + + def __init__( + self, + reduce: str = "micro", + mdmc_reduce: Optional[str] = None, + threshold: float = 0.5, + num_classes: Optional[int] = None, + is_multiclass: Optional[bool] = None, + ignore_index: Optional[int] = None, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, + ): + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + dist_sync_fn=dist_sync_fn, + ) + + self.reduce = reduce + self.mdmc_reduce = mdmc_reduce + self.num_classes = num_classes + self.threshold = threshold + self.is_multiclass = is_multiclass + self.ignore_index = ignore_index + + if reduce not in ["micro", "macro", "samples"]: + raise ValueError("reduce %s is not valid." % reduce) + + if mdmc_reduce not in [None, "samplewise", "global"]: + raise ValueError("mdmc_reduce %s is not valid." % mdmc_reduce) + + if reduce == "macro" and (not num_classes or num_classes < 1): + raise ValueError("When you set reduce as macro, you have to provide the number of classes.") + + if mdmc_reduce != "samplewise": + if reduce == "micro": + default, reduce_fn = torch.tensor(0), "sum" + elif reduce == "macro": + default, reduce_fn = torch.zeros((num_classes,), dtype=torch.int), "sum" + elif reduce == "samples": + default, reduce_fn = torch.empty(0), _dim_zero_cat_and_put_back + else: + default, reduce_fn = torch.empty(0), _dim_zero_cat_and_put_back + + for s in ("tp", "fp", "tn", "fn"): + self.add_state(s, default=default.detach().clone(), dist_reduce_fx=reduce_fn) + + def update(self, preds: torch.Tensor, target: torch.Tensor): + """ + Update state with predictions and targets. See :ref:`metrics:Input types` for more information + on input types. + + Args: + preds: Predictions from model (probabilities or labels) + target: Ground truth values + """ + + tp, fp, tn, fn = _stat_scores_update( + preds, + target, + reduce=self.reduce, + mdmc_reduce=self.mdmc_reduce, + threshold=self.threshold, + num_classes=self.num_classes, + is_multiclass=self.is_multiclass, + ignore_index=self.ignore_index, + ) + + # Update states + if self.reduce != "samples" and self.mdmc_reduce != "samplewise": + self.tp += tp + self.fp += fp + self.tn += tn + self.fn += fn + else: + self.tp = torch.cat((self.tp, tp)) + self.fp = torch.cat((self.fp, fp)) + self.tn = torch.cat((self.tn, tn)) + self.fn = torch.cat((self.fn, fn)) + + def compute(self) -> torch.Tensor: + """ + Computes the stat scores based on inputs passed in to ``update`` previously. + + Return: + The metric returns a tensor of shape ``(..., 5)``, where the last dimension corresponds + to ``[tp, fp, tn, fn, sup]`` (``sup`` stands for support and equals ``tp + fn``). The + shape depends on the ``reduce`` and ``mdmc_reduce`` (in case of multi-dimensional + multi-class data) parameters: + + - If the data is not multi-dimensional multi-class, then + + - If ``reduce='micro'``, the shape will be ``(5, )`` + - If ``reduce='macro'``, the shape will be ``(C, 5)``, + where ``C`` stands for the number of classes + - If ``reduce='samples'``, the shape will be ``(N, 5)``, where ``N`` stands for + the number of samples + + - If the data is multi-dimensional multi-class and ``mdmc_reduce='global'``, then + + - If ``reduce='micro'``, the shape will be ``(5, )`` + - If ``reduce='macro'``, the shape will be ``(C, 5)`` + - If ``reduce='samples'``, the shape will be ``(N*X, 5)``, where ``X`` stands for + the product of sizes of all "extra" dimensions of the data (i.e. all dimensions + except for ``C`` and ``N``) + + - If the data is multi-dimensional multi-class and ``mdmc_reduce='samplewise'``, then + + - If ``reduce='micro'``, the shape will be ``(N, 5)`` + - If ``reduce='macro'``, the shape will be ``(N, C, 5)`` + - If ``reduce='samples'``, the shape will be ``(N, X, 5)`` + + """ + + return _stat_scores_compute(self.tp, self.fp, self.tn, self.fn) diff --git a/pytorch_lightning/metrics/functional/__init__.py b/pytorch_lightning/metrics/functional/__init__.py index 3bb5313db7b27..8147e4c3cfcd5 100644 --- a/pytorch_lightning/metrics/functional/__init__.py +++ b/pytorch_lightning/metrics/functional/__init__.py @@ -25,8 +25,6 @@ precision_recall_curve, recall, roc, - stat_scores, - stat_scores_multiple_classes, to_categorical, to_onehot, iou, @@ -44,3 +42,4 @@ from pytorch_lightning.metrics.functional.ssim import ssim from pytorch_lightning.metrics.functional.confusion_matrix import confusion_matrix from pytorch_lightning.metrics.functional.f_beta import fbeta, f1 +from pytorch_lightning.metrics.functional.stat_scores import stat_scores diff --git a/pytorch_lightning/metrics/functional/stat_scores.py b/pytorch_lightning/metrics/functional/stat_scores.py new file mode 100644 index 0000000000000..63fcd57ff9bf7 --- /dev/null +++ b/pytorch_lightning/metrics/functional/stat_scores.py @@ -0,0 +1,274 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Tuple, Union, Optional + +import torch +from pytorch_lightning.metrics.classification.utils import _input_format_classification + + +def _del_column(tensor: torch.Tensor, index: int): + """ Delete the column at index.""" + + return torch.cat([tensor[:, :index], tensor[:, (index + 1) :]], 1) + + +def _stat_scores( + preds: torch.Tensor, target: torch.Tensor, reduce: str = "micro" +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Calculate the number of tp, fp, tn, fn. + + The shape of the returned tensors depnds on the shape of the inputs + and the `reduce` parameter: + + * If inputs are of the shape (N, C), then + + * If reduce is 'micro', the returned tensors are 1 element tensors + * If reduce is one of 'macro', 'weighted', 'none' or None, the returned + tensors are (C,) 1d tensors + * If reduce is 'samples, the returned tensors are 1d (N,) tensors + + * If inputs are of the shape (N, C, X), then + + * If reduce is 'micro', the returned tensors are (N,) 1d tensors + * If reduce is one of 'macro', 'weighted', 'none' or None, the returned + tensors are (N,C) 2d tensors + * If reduce is 'samples, the returned tensors are 1d (N,X) 2d tensors + + Parameters + ---------- + labels + An (N, C) or (N, C, X) tensor of true labels (0 or 1) + preds + An (N, C) or (N, C, X) tensor of predictions (0 or 1) + reduce + One of 'micro', 'macro', 'samples' + + Returns + ------- + tp, fp, tn, fn + """ + is_multidim = len(preds.shape) == 3 + + if reduce == "micro": + dim = [0, 1] if not is_multidim else [1, 2] + elif reduce == "macro": + dim = 0 if not is_multidim else 2 + elif reduce == "samples": + dim = 1 + + true_pred, false_pred = target == preds, target != preds + + tp = (true_pred * (preds == 1)).sum(dim=dim) + fp = (false_pred * (preds == 1)).sum(dim=dim) + + tn = (true_pred * (preds == 0)).sum(dim=dim) + fn = (false_pred * (preds == 0)).sum(dim=dim) + + return tp.int(), fp.int(), tn.int(), fn.int() + + +def _stat_scores_update( + preds: torch.Tensor, + target: torch.Tensor, + reduce: str = "micro", + mdmc_reduce: Optional[str] = None, + threshold: float = 0.5, + num_classes: Optional[int] = None, + is_multiclass: Optional[bool] = None, + ignore_index: Optional[int] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + + preds, target, _ = _input_format_classification( + preds, + target, + threshold=threshold, + num_classes=num_classes, + is_multiclass=is_multiclass, + ) + + if len(preds.shape) == 3: + if not mdmc_reduce: + raise ValueError( + "When your inputs are multi-dimensional multi-class," + "you have to set mdmc_reduce to either 'samplewise' or 'global'" + ) + if mdmc_reduce == "global": + preds = torch.movedim(preds, 1, -1).reshape(-1, preds.shape[1]) + target = torch.movedim(target, 1, -1).reshape(-1, target.shape[1]) + + # Delete what is in ignore_index, if applicable (and classes don't matter): + if ignore_index and reduce in ["micro", "samples"] and preds.shape[1] > 1: + if 0 <= ignore_index < preds.shape[1]: + preds = _del_column(preds, ignore_index) + target = _del_column(target, ignore_index) + + tp, fp, tn, fn = _stat_scores(preds, target, reduce=reduce) + + # Take care of ignore_index + if ignore_index and reduce == "macro": + if num_classes > 1 and 0 <= ignore_index < num_classes: + if mdmc_reduce == "global" or not mdmc_reduce: + tp[ignore_index] = -1 + fp[ignore_index] = -1 + tn[ignore_index] = -1 + fn[ignore_index] = -1 + else: + tp[:, ignore_index] = -1 + fp[:, ignore_index] = -1 + tn[:, ignore_index] = -1 + fn[:, ignore_index] = -1 + + return tp, fp, tn, fn + + +def _stat_scores_compute(tp: torch.Tensor, fp: torch.Tensor, tn: torch.Tensor, fn: torch.Tensor) -> torch.Tensor: + + outputs = [ + tp.unsqueeze(-1), + fp.unsqueeze(-1), + tn.unsqueeze(-1), + fn.unsqueeze(-1), + tp.unsqueeze(-1) + fn.unsqueeze(-1), # support + ] + outputs = torch.cat(outputs, -1).long() + + # To standardzie ignore_index statistics as -1 + outputs = torch.where(outputs < 0, -1, outputs) + + return outputs + + +def stat_scores( + preds: torch.Tensor, + target: torch.Tensor, + reduce: str = "micro", + mdmc_reduce: Optional[str] = None, + threshold: float = 0.5, + num_classes: Optional[int] = None, + is_multiclass: Optional[bool] = None, + ignore_index: Optional[int] = None, +) -> torch.Tensor: + """Computes the number of true positives, false positives, true negatives, false negatives. + + The reduction method (how the statistics are aggregated) is controlled by the + ``reduce`` parameter, and additionally by the ``mdmc_reduce`` parameter in the + multi-dimensional multi-class case. Accepts all inputs listed in :ref:`metrics:Input types`. + + Args: + preds: Predictions from model (probabilities, or labels) + target: Ground truth values + reduce: + Defines the reduction that is applied. Should be one of the following: + + - ``'micro'`` [default]: Counts the statistics by summing over all [sample, class] + combinations (globally). Produces a one element tensor for each statistic. + - ``'macro'``: Counts the statistics for each class separately (over all samples). + Produces a ``(C, )`` 1d tensor. Requires ``num_classes`` to be set. + - ``'samples'``: Counts the statistics for each sample separately (over all classes). + Produces a ``(N, )`` 1d tensor. + + Note that what is considered a sample in the multi-dimensional multi-class case + depends on the value of ``mdmc_reduce``. + + mdmc_reduce: + Defines how the multi-dimensional multi-class inputs are handeled. Should be + one of the following: + + - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional + multi-class. + + - ``'samplewise'``: In this case, the statistics are computed separately for each + sample on the ``N`` axis, and then concatenating the outputs together. This is + done by, for each sample, treating the flattened extra axes ``...`` (see + :ref:`metrics:Input types`) as the ``N`` dimension within the sample, and computing + the statistics for the sample based on that. + + - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs (see :ref:`metrics:Input types`) + are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they + were ``(N_X, C)``. From here on the ``reduce`` parameter applies as usual. + + num_classes: + Number of classes. Necessary for (multi-dimensional) multi-class or multi-label data. + + threshold: + Threshold probability value for transforming probability predictions to binary + (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + is_multiclass: + If ``False``, treat multi-class and multi-dim multi-class inputs with 1 or 2 classes as + binary and multi-label, respectively. If ``True``, treat binary and multi-label inputs + as multi-class or multi-dim multi-class with 2 classes, respectively. + Defaults to ``None``, which treats inputs as they appear. + ignore_index: + Integer specifying a target class to ignore. If given, this class index does not contribute + to the returned score, regardless of reduction method. Has no effect if given an int that + is not in the range ``[0, C-1]``, or if ``C=1``, where ``C`` is the number of classes. + + If an index is ignored, and ``reduce='macro'``, the class statistics for the ignored + class will all be returned as ``nan`` (to not break the indexing of other labels). + + Return: + The metric returns a tensor of shape ``(..., 5)``, where the last dimension corresponds + to ``[tp, fp, tn, fn, sup]`` (``sup`` stands for support and equals ``tp + fn``). The + shape depends on the ``reduce`` and ``mdmc_reduce`` (in case of multi-dimensional + multi-class data) parameters: + + - If the data is not multi-dimensional multi-class, then + + - If ``reduce='micro'``, the shape will be ``(5, )`` + - If ``reduce='macro'``, the shape will be ``(C, 5)``, + where ``C`` stands for the number of classes + - If ``reduce='samples'``, the shape will be ``(N, 5)``, where ``N`` stands for + the number of samples + + - If the data is multi-dimensional multi-class and ``mdmc_reduce='global'``, then + + - If ``reduce='micro'``, the shape will be ``(5, )`` + - If ``reduce='macro'``, the shape will be ``(C, 5)`` + - If ``reduce='samples'``, the shape will be ``(N*X, 5)``, where ``X`` stands for + the product of sizes of all "extra" dimensions of the data (i.e. all dimensions + except for ``C`` and ``N``) + + - If the data is multi-dimensional multi-class and ``mdmc_reduce='samplewise'``, then + + - If ``reduce='micro'``, the shape will be ``(N, 5)`` + - If ``reduce='macro'``, the shape will be ``(N, C, 5)`` + - If ``reduce='samples'``, the shape will be ``(N, X, 5)`` + + Example: + + >>> from pytorch_lightning.metrics.functional import stat_scores + >>> preds = torch.tensor([1, 0, 2, 1]) + >>> target = torch.tensor([1, 1, 2, 0]) + >>> stat_scores(preds, target, reduce='macro', num_classes=3) + tensor([[0, 1, 2, 1, 1], + [1, 1, 1, 1, 2], + [1, 0, 3, 0, 1]]) + >>> stat_scores(preds, target, reduce='micro') + tensor([2, 2, 6, 2, 4]) + + """ + + if reduce not in ["micro", "macro", "samples"]: + raise ValueError("reduce %s is not valid." % reduce) + + if mdmc_reduce not in [None, "samplewise", "global"]: + raise ValueError("mdmc_reduce %s is not valid." % mdmc_reduce) + + if reduce == "macro" and (not num_classes or num_classes < 1): + raise ValueError("When you set reduce as macro, you have to provide the number of classes.") + + tp, fp, tn, fn = _stat_scores_update( + preds, target, reduce, mdmc_reduce, threshold, num_classes, is_multiclass, ignore_index + ) + return _stat_scores_compute(tp, fp, tn, fn) diff --git a/tests/metrics/classification/test_stat_scores.py b/tests/metrics/classification/test_stat_scores.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py index 34abee8473863..8ec14c41b1360 100644 --- a/tests/metrics/utils.py +++ b/tests/metrics/utils.py @@ -21,10 +21,10 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ - os.environ["MASTER_ADDR"] = 'localhost' - os.environ['MASTER_PORT'] = '8088' + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "8088" - if torch.distributed.is_available() and sys.platform not in ('win32', 'cygwin'): + if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) @@ -41,23 +41,23 @@ def _class_test( check_batch: bool = True, atol: float = 1e-8, ): - """ Utility function doing the actual comparison between lightning class metric - and reference metric. - - Args: - rank: rank of current process - worldsize: number of processes - preds: torch tensor with predictions - target: torch tensor with targets - metric_class: lightning metric class that should be tested - sk_metric: callable function that is used for comparison - dist_sync_on_step: bool, if true will synchronize metric state across - processes at each ``forward()`` - metric_args: dict with additional arguments used for class initialization - check_dist_sync_on_step: bool, if true will check if the metric is also correctly - calculated per batch per device (and not just at the end) - check_batch: bool, if true will check if the metric is also correctly - calculated across devices for each batch (and not just at the end) + """Utility function doing the actual comparison between lightning class metric + and reference metric. + + Args: + rank: rank of current process + worldsize: number of processes + preds: torch tensor with predictions + target: torch tensor with targets + metric_class: lightning metric class that should be tested + sk_metric: callable function that is used for comparison + dist_sync_on_step: bool, if true will synchronize metric state across + processes at each ``forward()`` + metric_args: dict with additional arguments used for class initialization + check_dist_sync_on_step: bool, if true will check if the metric is also correctly + calculated per batch per device (and not just at the end) + check_batch: bool, if true will check if the metric is also correctly + calculated across devices for each batch (and not just at the end) """ # Instanciate lightning metric metric = metric_class(compute_on_step=True, dist_sync_on_step=dist_sync_on_step, **metric_args) @@ -71,28 +71,28 @@ def _class_test( if metric.dist_sync_on_step: if rank == 0: - ddp_preds = torch.stack([preds[i + r] for r in range(worldsize)]) - ddp_target = torch.stack([target[i + r] for r in range(worldsize)]) + ddp_preds = torch.cat([preds[i + r] for r in range(worldsize)]) + ddp_target = torch.cat([target[i + r] for r in range(worldsize)]) sk_batch_result = sk_metric(ddp_preds, ddp_target) # assert for dist_sync_on_step if check_dist_sync_on_step: - assert np.allclose(batch_result.numpy(), sk_batch_result, atol=atol) + assert np.allclose(batch_result.numpy(), sk_batch_result, atol=atol, equal_nan=True) else: sk_batch_result = sk_metric(preds[i], target[i]) # assert for batch if check_batch: - assert np.allclose(batch_result.numpy(), sk_batch_result, atol=atol) + assert np.allclose(batch_result.numpy(), sk_batch_result, atol=atol, equal_nan=True) # check on all batches on all ranks result = metric.compute() assert isinstance(result, torch.Tensor) - total_preds = torch.stack([preds[i] for i in range(NUM_BATCHES)]) - total_target = torch.stack([target[i] for i in range(NUM_BATCHES)]) + total_preds = torch.cat([preds[i] for i in range(NUM_BATCHES)]) + total_target = torch.cat([target[i] for i in range(NUM_BATCHES)]) sk_result = sk_metric(total_preds, total_target) # assert after aggregation - assert np.allclose(result.numpy(), sk_result, atol=atol) + assert np.allclose(result.numpy(), sk_result, atol=atol, equal_nan=True) def _functional_test( @@ -101,17 +101,17 @@ def _functional_test( metric_functional: Callable, sk_metric: Callable, metric_args: dict = {}, - atol: float = 1e-8 + atol: float = 1e-8, ): - """ Utility function doing the actual comparison between lightning functional metric - and reference metric. - - Args: - preds: torch tensor with predictions - target: torch tensor with targets - metric_functional: lightning metric functional that should be tested - sk_metric: callable function that is used for comparison - metric_args: dict with additional arguments used for class initialization + """Utility function doing the actual comparison between lightning functional metric + and reference metric. + + Args: + preds: torch tensor with predictions + target: torch tensor with targets + metric_functional: lightning metric functional that should be tested + sk_metric: callable function that is used for comparison + metric_args: dict with additional arguments used for class initialization """ metric = partial(metric_functional, **metric_args) @@ -120,26 +120,27 @@ def _functional_test( sk_result = sk_metric(preds[i], target[i]) # assert its the same - assert np.allclose(lightning_result.numpy(), sk_result, atol=atol) + assert np.allclose(lightning_result.numpy(), sk_result, atol=atol, equal_nan=True) class MetricTester: - """ Class used for efficiently run alot of parametrized tests in ddp mode. - Makes sure that ddp is only setup once and that pool of processes are - used for all tests. + """Class used for efficiently run alot of parametrized tests in ddp mode. + Makes sure that ddp is only setup once and that pool of processes are + used for all tests. - All tests should subclass from this and implement a new method called - `test_metric_name` - where the method `self.run_metric_test` is called inside. + All tests should subclass from this and implement a new method called + `test_metric_name` + where the method `self.run_metric_test` is called inside. """ + atol = 1e-8 def setup_class(self): - """ Setup the metric class. This will spawn the pool of workers that are - used for metric testing and setup_ddp + """Setup the metric class. This will spawn the pool of workers that are + used for metric testing and setup_ddp """ try: - set_start_method('spawn') + set_start_method("spawn") except RuntimeError: pass self.poolSize = NUM_PROCESSES @@ -157,24 +158,26 @@ def run_functional_metric_test( target: torch.Tensor, metric_functional: Callable, sk_metric: Callable, - metric_args: dict = {} + metric_args: dict = {}, ): - """ Main method that should be used for testing functions. Call this inside - testing method - - Args: - preds: torch tensor with predictions - target: torch tensor with targets - metric_functional: lightning metric class that should be tested - sk_metric: callable function that is used for comparison - metric_args: dict with additional arguments used for class initialization + """Main method that should be used for testing functions. Call this inside + testing method + + Args: + preds: torch tensor with predictions + target: torch tensor with targets + metric_functional: lightning metric class that should be tested + sk_metric: callable function that is used for comparison + metric_args: dict with additional arguments used for class initialization """ - _functional_test(preds=preds, - target=target, - metric_functional=metric_functional, - sk_metric=sk_metric, - metric_args=metric_args, - atol=self.atol) + _functional_test( + preds=preds, + target=target, + metric_functional=metric_functional, + sk_metric=sk_metric, + metric_args=metric_args, + atol=self.atol, + ) def run_class_metric_test( self, @@ -188,22 +191,22 @@ def run_class_metric_test( check_dist_sync_on_step: bool = True, check_batch: bool = True, ): - """ Main method that should be used for testing class. Call this inside testing - methods. - - Args: - ddp: bool, if running in ddp mode or not - preds: torch tensor with predictions - target: torch tensor with targets - metric_class: lightning metric class that should be tested - sk_metric: callable function that is used for comparison - dist_sync_on_step: bool, if true will synchronize metric state across - processes at each ``forward()`` - metric_args: dict with additional arguments used for class initialization - check_dist_sync_on_step: bool, if true will check if the metric is also correctly - calculated per batch per device (and not just at the end) - check_batch: bool, if true will check if the metric is also correctly - calculated across devices for each batch (and not just at the end) + """Main method that should be used for testing class. Call this inside testing + methods. + + Args: + ddp: bool, if running in ddp mode or not + preds: torch tensor with predictions + target: torch tensor with targets + metric_class: lightning metric class that should be tested + sk_metric: callable function that is used for comparison + dist_sync_on_step: bool, if true will synchronize metric state across + processes at each ``forward()`` + metric_args: dict with additional arguments used for class initialization + check_dist_sync_on_step: bool, if true will check if the metric is also correctly + calculated per batch per device (and not just at the end) + check_batch: bool, if true will check if the metric is also correctly + calculated across devices for each batch (and not just at the end) """ if ddp: if sys.platform == "win32": From 55fdaaf16a185d5ddf2cf37871f389658b8de2c4 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 24 Nov 2020 22:41:43 +0100 Subject: [PATCH 05/61] Change testing utils --- tests/metrics/utils.py | 161 +++++++++++++++++++++-------------------- 1 file changed, 82 insertions(+), 79 deletions(-) diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py index 34abee8473863..8ec14c41b1360 100644 --- a/tests/metrics/utils.py +++ b/tests/metrics/utils.py @@ -21,10 +21,10 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ - os.environ["MASTER_ADDR"] = 'localhost' - os.environ['MASTER_PORT'] = '8088' + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "8088" - if torch.distributed.is_available() and sys.platform not in ('win32', 'cygwin'): + if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) @@ -41,23 +41,23 @@ def _class_test( check_batch: bool = True, atol: float = 1e-8, ): - """ Utility function doing the actual comparison between lightning class metric - and reference metric. - - Args: - rank: rank of current process - worldsize: number of processes - preds: torch tensor with predictions - target: torch tensor with targets - metric_class: lightning metric class that should be tested - sk_metric: callable function that is used for comparison - dist_sync_on_step: bool, if true will synchronize metric state across - processes at each ``forward()`` - metric_args: dict with additional arguments used for class initialization - check_dist_sync_on_step: bool, if true will check if the metric is also correctly - calculated per batch per device (and not just at the end) - check_batch: bool, if true will check if the metric is also correctly - calculated across devices for each batch (and not just at the end) + """Utility function doing the actual comparison between lightning class metric + and reference metric. + + Args: + rank: rank of current process + worldsize: number of processes + preds: torch tensor with predictions + target: torch tensor with targets + metric_class: lightning metric class that should be tested + sk_metric: callable function that is used for comparison + dist_sync_on_step: bool, if true will synchronize metric state across + processes at each ``forward()`` + metric_args: dict with additional arguments used for class initialization + check_dist_sync_on_step: bool, if true will check if the metric is also correctly + calculated per batch per device (and not just at the end) + check_batch: bool, if true will check if the metric is also correctly + calculated across devices for each batch (and not just at the end) """ # Instanciate lightning metric metric = metric_class(compute_on_step=True, dist_sync_on_step=dist_sync_on_step, **metric_args) @@ -71,28 +71,28 @@ def _class_test( if metric.dist_sync_on_step: if rank == 0: - ddp_preds = torch.stack([preds[i + r] for r in range(worldsize)]) - ddp_target = torch.stack([target[i + r] for r in range(worldsize)]) + ddp_preds = torch.cat([preds[i + r] for r in range(worldsize)]) + ddp_target = torch.cat([target[i + r] for r in range(worldsize)]) sk_batch_result = sk_metric(ddp_preds, ddp_target) # assert for dist_sync_on_step if check_dist_sync_on_step: - assert np.allclose(batch_result.numpy(), sk_batch_result, atol=atol) + assert np.allclose(batch_result.numpy(), sk_batch_result, atol=atol, equal_nan=True) else: sk_batch_result = sk_metric(preds[i], target[i]) # assert for batch if check_batch: - assert np.allclose(batch_result.numpy(), sk_batch_result, atol=atol) + assert np.allclose(batch_result.numpy(), sk_batch_result, atol=atol, equal_nan=True) # check on all batches on all ranks result = metric.compute() assert isinstance(result, torch.Tensor) - total_preds = torch.stack([preds[i] for i in range(NUM_BATCHES)]) - total_target = torch.stack([target[i] for i in range(NUM_BATCHES)]) + total_preds = torch.cat([preds[i] for i in range(NUM_BATCHES)]) + total_target = torch.cat([target[i] for i in range(NUM_BATCHES)]) sk_result = sk_metric(total_preds, total_target) # assert after aggregation - assert np.allclose(result.numpy(), sk_result, atol=atol) + assert np.allclose(result.numpy(), sk_result, atol=atol, equal_nan=True) def _functional_test( @@ -101,17 +101,17 @@ def _functional_test( metric_functional: Callable, sk_metric: Callable, metric_args: dict = {}, - atol: float = 1e-8 + atol: float = 1e-8, ): - """ Utility function doing the actual comparison between lightning functional metric - and reference metric. - - Args: - preds: torch tensor with predictions - target: torch tensor with targets - metric_functional: lightning metric functional that should be tested - sk_metric: callable function that is used for comparison - metric_args: dict with additional arguments used for class initialization + """Utility function doing the actual comparison between lightning functional metric + and reference metric. + + Args: + preds: torch tensor with predictions + target: torch tensor with targets + metric_functional: lightning metric functional that should be tested + sk_metric: callable function that is used for comparison + metric_args: dict with additional arguments used for class initialization """ metric = partial(metric_functional, **metric_args) @@ -120,26 +120,27 @@ def _functional_test( sk_result = sk_metric(preds[i], target[i]) # assert its the same - assert np.allclose(lightning_result.numpy(), sk_result, atol=atol) + assert np.allclose(lightning_result.numpy(), sk_result, atol=atol, equal_nan=True) class MetricTester: - """ Class used for efficiently run alot of parametrized tests in ddp mode. - Makes sure that ddp is only setup once and that pool of processes are - used for all tests. + """Class used for efficiently run alot of parametrized tests in ddp mode. + Makes sure that ddp is only setup once and that pool of processes are + used for all tests. - All tests should subclass from this and implement a new method called - `test_metric_name` - where the method `self.run_metric_test` is called inside. + All tests should subclass from this and implement a new method called + `test_metric_name` + where the method `self.run_metric_test` is called inside. """ + atol = 1e-8 def setup_class(self): - """ Setup the metric class. This will spawn the pool of workers that are - used for metric testing and setup_ddp + """Setup the metric class. This will spawn the pool of workers that are + used for metric testing and setup_ddp """ try: - set_start_method('spawn') + set_start_method("spawn") except RuntimeError: pass self.poolSize = NUM_PROCESSES @@ -157,24 +158,26 @@ def run_functional_metric_test( target: torch.Tensor, metric_functional: Callable, sk_metric: Callable, - metric_args: dict = {} + metric_args: dict = {}, ): - """ Main method that should be used for testing functions. Call this inside - testing method - - Args: - preds: torch tensor with predictions - target: torch tensor with targets - metric_functional: lightning metric class that should be tested - sk_metric: callable function that is used for comparison - metric_args: dict with additional arguments used for class initialization + """Main method that should be used for testing functions. Call this inside + testing method + + Args: + preds: torch tensor with predictions + target: torch tensor with targets + metric_functional: lightning metric class that should be tested + sk_metric: callable function that is used for comparison + metric_args: dict with additional arguments used for class initialization """ - _functional_test(preds=preds, - target=target, - metric_functional=metric_functional, - sk_metric=sk_metric, - metric_args=metric_args, - atol=self.atol) + _functional_test( + preds=preds, + target=target, + metric_functional=metric_functional, + sk_metric=sk_metric, + metric_args=metric_args, + atol=self.atol, + ) def run_class_metric_test( self, @@ -188,22 +191,22 @@ def run_class_metric_test( check_dist_sync_on_step: bool = True, check_batch: bool = True, ): - """ Main method that should be used for testing class. Call this inside testing - methods. - - Args: - ddp: bool, if running in ddp mode or not - preds: torch tensor with predictions - target: torch tensor with targets - metric_class: lightning metric class that should be tested - sk_metric: callable function that is used for comparison - dist_sync_on_step: bool, if true will synchronize metric state across - processes at each ``forward()`` - metric_args: dict with additional arguments used for class initialization - check_dist_sync_on_step: bool, if true will check if the metric is also correctly - calculated per batch per device (and not just at the end) - check_batch: bool, if true will check if the metric is also correctly - calculated across devices for each batch (and not just at the end) + """Main method that should be used for testing class. Call this inside testing + methods. + + Args: + ddp: bool, if running in ddp mode or not + preds: torch tensor with predictions + target: torch tensor with targets + metric_class: lightning metric class that should be tested + sk_metric: callable function that is used for comparison + dist_sync_on_step: bool, if true will synchronize metric state across + processes at each ``forward()`` + metric_args: dict with additional arguments used for class initialization + check_dist_sync_on_step: bool, if true will check if the metric is also correctly + calculated per batch per device (and not just at the end) + check_batch: bool, if true will check if the metric is also correctly + calculated across devices for each batch (and not just at the end) """ if ddp: if sys.platform == "win32": From 5cbf56a5422d0efc29ad84d85d5768863692496b Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 24 Nov 2020 22:50:50 +0100 Subject: [PATCH 06/61] Replace len(*.shape) with *.ndim --- .../metrics/classification/utils.py | 20 +++++++++---------- pytorch_lightning/metrics/utils.py | 16 +++++++-------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index b8f5af2e988d8..62fe3e2095a78 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -90,25 +90,25 @@ def _check_classification_inputs( raise ValueError("If you set is_multiclass=False and preds are integers, then preds should not exceed 1.") # Check that shape/types fall into one of the cases - if len(preds.shape) == len(target.shape): + if preds.ndim == target.ndim: if preds.shape != target.shape: raise ValueError("if preds and target have the same number of dimensions, they should have the same shape") if preds_float and target.max() > 1: raise ValueError("if preds and target are of shape (N, ...) and preds are floats, target should be binary") # Get the case - if len(preds.shape) == 1 and preds_float: + if preds.ndim == 1 and preds_float: case = "binary" - elif len(preds.shape) == 1 and not preds_float: + elif preds.ndim == 1 and not preds_float: case = "multi-class" - elif len(preds.shape) > 1 and preds_float: + elif preds.ndim > 1 and preds_float: case = "multi-label" else: case = "multi-dim multi-class" implied_classes = torch.prod(torch.Tensor(list(preds.shape[1:]))) - elif len(preds.shape) == len(target.shape) + 1: + elif preds.ndim == target.ndim + 1: if not preds_float: raise ValueError("if preds have one dimension more than target, preds should be a float tensor") if not preds.shape[:-1] == target.shape: @@ -120,7 +120,7 @@ def _check_classification_inputs( extra_dim_size = preds.shape[-1 if preds.shape[:-1] == target.shape else 1] - if len(preds.shape) == 2: + if preds.ndim == 2: case = "multi-class" else: case = "multi-dim multi-class" @@ -299,7 +299,7 @@ def _input_format_classification( preds_float = preds.is_floating_point() - if len(preds.shape) == len(target.shape) == 1 and preds_float: + if preds.ndim == target.ndim == 1 and preds_float: mode = "binary" preds = (preds >= threshold).int() @@ -310,7 +310,7 @@ def _input_format_classification( preds = preds.unsqueeze(-1) target = target.unsqueeze(-1) - elif len(preds.shape) == len(target.shape) and preds_float: + elif preds.ndim == target.ndim and preds_float: mode = "multi-label" preds = (preds >= threshold).int() @@ -321,7 +321,7 @@ def _input_format_classification( preds = preds.reshape(preds.shape[0], -1) target = target.reshape(target.shape[0], -1) - elif len(preds.shape) == len(target.shape) + 1 == 2: + elif preds.ndim == target.ndim + 1 == 2: mode = "multi-class" if not num_classes: num_classes = preds.shape[1] @@ -334,7 +334,7 @@ def _input_format_classification( target = target[:, [1]] preds = preds[:, [1]] - elif len(preds.shape) == len(target.shape) == 1 and not preds_float: + elif preds.ndim == target.ndim == 1 and not preds_float: mode = "multi-class" if not num_classes: diff --git a/pytorch_lightning/metrics/utils.py b/pytorch_lightning/metrics/utils.py index 1ce56b30cf9e5..0f71c531fe6ac 100644 --- a/pytorch_lightning/metrics/utils.py +++ b/pytorch_lightning/metrics/utils.py @@ -109,14 +109,14 @@ def _input_format_classification( preds: tensor with labels target: tensor with labels """ - if not (len(preds.shape) == len(target.shape) or len(preds.shape) == len(target.shape) + 1): + if not (preds.ndim == target.ndim or preds.ndim == target.ndim + 1): raise ValueError("preds and target must have same number of dimensions, or one additional dimension for preds") - if len(preds.shape) == len(target.shape) + 1: + if preds.ndim == target.ndim + 1: # multi class probabilites preds = torch.argmax(preds, dim=1) - if len(preds.shape) == len(target.shape) and preds.dtype == torch.float: + if preds.ndim == target.ndim and preds.dtype == torch.float: # binary or multilabel probablities preds = (preds >= threshold).long() return preds, target @@ -139,24 +139,24 @@ def _input_format_classification_one_hot( preds: one hot tensor of shape [num_classes, -1] with predicted labels target: one hot tensors of shape [num_classes, -1] with true labels """ - if not (len(preds.shape) == len(target.shape) or len(preds.shape) == len(target.shape) + 1): + if not (preds.ndim == target.ndim or preds.ndim == target.ndim + 1): raise ValueError("preds and target must have same number of dimensions, or one additional dimension for preds") - if len(preds.shape) == len(target.shape) + 1: + if preds.ndim == target.ndim + 1: # multi class probabilites preds = torch.argmax(preds, dim=1) - if len(preds.shape) == len(target.shape) and preds.dtype == torch.long and num_classes > 1 and not multilabel: + if preds.ndim == target.ndim and preds.dtype == torch.long and num_classes > 1 and not multilabel: # multi-class preds = to_onehot(preds, num_classes=num_classes) target = to_onehot(target, num_classes=num_classes) - elif len(preds.shape) == len(target.shape) and preds.dtype == torch.float: + elif preds.ndim == target.ndim and preds.dtype == torch.float: # binary or multilabel probablities preds = (preds >= threshold).long() # transpose class as first dim and reshape - if len(preds.shape) > 1: + if preds.ndim > 1: preds = preds.transpose(1, 0) target = target.transpose(1, 0) From 9c33d0b3c14e50a984c59596eff7f6a513f19e2c Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 24 Nov 2020 22:55:56 +0100 Subject: [PATCH 07/61] More descriptive error message for input formatting --- pytorch_lightning/metrics/classification/utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index 62fe3e2095a78..6a47c7adbbae1 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -76,9 +76,7 @@ def _check_classification_inputs( if preds_float: if preds.min() < 0 or preds.max() > 1: - raise ValueError( - "preds should be probabilities, but values were detected outside of [0,1] range" - ) + raise ValueError("preds should be probabilities, but values were detected outside of [0,1] range") if threshold > 1 or threshold < 0: raise ValueError("Threshold should be a probability in [0,1]") @@ -92,7 +90,10 @@ def _check_classification_inputs( # Check that shape/types fall into one of the cases if preds.ndim == target.ndim: if preds.shape != target.shape: - raise ValueError("if preds and target have the same number of dimensions, they should have the same shape") + raise ValueError( + "preds and targets should have the same shape", + f" got preds shape = {preds.shape} and target shape = {target.shape}.", + ) if preds_float and target.max() > 1: raise ValueError("if preds and target are of shape (N, ...) and preds are floats, target should be binary") From 65622058778a8186bdc8b7f8852fc4690405de09 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 24 Nov 2020 23:24:22 +0100 Subject: [PATCH 08/61] Replace movedim with permute --- pytorch_lightning/metrics/classification/utils.py | 10 +++++++--- tests/metrics/classification/test_inputs.py | 7 ++++++- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index 6a47c7adbbae1..f58e61dcd3127 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -13,7 +13,6 @@ # limitations under the License. from typing import Tuple, Optional -import numpy as np import torch from pytorch_lightning.metrics.utils import to_onehot, select_topk @@ -256,7 +255,8 @@ def _input_format_classification( dimension is ambiguous (e.g. if targets are a ``(7, 3)`` tensor, while predictions are a ``(7, 3, 3)`` tensor), it will be assumed that the ``C`` dimension is the second dimension. If this is not the case, you should move it from the last to second place using - ``torch.movedim(preds, -1, 1)``. + ``torch.movedim(preds, -1, 1)``, or using ``preds.permute``, if you are using an older + version of Pytorch. Note that where a one-hot transformation needs to be performed and the number of classes is not implicitly given by a ``C`` dimension, the new ``C`` dimension will either be @@ -370,7 +370,11 @@ def _input_format_classification( else: mode = "multi-dim multi-class" if preds.shape[:-1] == target.shape: - preds = torch.movedim(preds, -1, 1) + shape_permute = list(range(preds.ndim)) + shape_permute[1] = shape_permute[-1] + shape_permute[2:] = range(1, len(shape_permute) - 1) + + preds = preds.permute(*shape_permute) num_classes = preds.shape[1] diff --git a/tests/metrics/classification/test_inputs.py b/tests/metrics/classification/test_inputs.py index 8d17d5624fac0..19828dc07eb34 100644 --- a/tests/metrics/classification/test_inputs.py +++ b/tests/metrics/classification/test_inputs.py @@ -84,7 +84,12 @@ def top2(x): def mvdim(x): - return torch.movedim(x, -1, 1) + """ Equivalent of torch.movedim(x, -1, 1) """ + shape_permute = list(range(x.ndim)) + shape_permute[1] = shape_permute[-1] + shape_permute[2:] = range(1, len(shape_permute) - 1) + + return x.permute(*shape_permute) # To avoid ugly black line wrapping From cbbc769cf615220fba7b1422c657c35c9d9a507a Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 24 Nov 2020 23:28:05 +0100 Subject: [PATCH 09/61] PEP 8 compliance --- pytorch_lightning/metrics/functional/accuracy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/metrics/functional/accuracy.py b/pytorch_lightning/metrics/functional/accuracy.py index f9861388cceda..e56b4b9f76eb3 100644 --- a/pytorch_lightning/metrics/functional/accuracy.py +++ b/pytorch_lightning/metrics/functional/accuracy.py @@ -68,7 +68,7 @@ class has to be correctly predicted across all extra dimension for each sample i Args: preds: Predictions from model (probabilities, or labels) - target: Ground truth values + target: Ground truth values top_k: Number of highest probability predictions considered to find the correct label, for (multi-dimensional) multi-class inputs with probability predictions. Default 1 From 33166c5a873f69f54f480f84e37455c6c6a6bf9e Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 24 Nov 2020 23:35:31 +0100 Subject: [PATCH 10/61] WIP --- docs/source/metrics.rst | 11 +- .../classification/precision_recall.py | 324 +++++++++++------- .../metrics/functional/__init__.py | 4 +- .../metrics/functional/precision_recall.py | 276 +++++++++++++++ .../classification/test_precision_recall.py | 245 +++++++++---- 5 files changed, 653 insertions(+), 207 deletions(-) create mode 100644 pytorch_lightning/metrics/functional/precision_recall.py diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 01d95436ac30e..f5874d93e3d23 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -437,14 +437,7 @@ multiclass_roc [func] precision [func] ~~~~~~~~~~~~~~~~ -.. autofunction:: pytorch_lightning.metrics.functional.classification.precision - :noindex: - - -precision_recall [func] -~~~~~~~~~~~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.classification.precision_recall +.. autofunction:: pytorch_lightning.metrics.functional.precision :noindex: @@ -458,7 +451,7 @@ precision_recall_curve [func] recall [func] ~~~~~~~~~~~~~ -.. autofunction:: pytorch_lightning.metrics.functional.classification.recall +.. autofunction:: pytorch_lightning.metrics.functional.recall :noindex: diff --git a/pytorch_lightning/metrics/classification/precision_recall.py b/pytorch_lightning/metrics/classification/precision_recall.py index 40002af96f5f6..110e4c2614466 100644 --- a/pytorch_lightning/metrics/classification/precision_recall.py +++ b/pytorch_lightning/metrics/classification/precision_recall.py @@ -11,48 +11,80 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import math -import functools -from abc import ABC, abstractmethod -from typing import Any, Callable, Optional, Union -from collections.abc import Mapping, Sequence -from collections import namedtuple +from typing import Optional, Any, Callable import torch -from torch import nn -from pytorch_lightning.metrics.metric import Metric -from pytorch_lightning.metrics.utils import to_onehot, METRIC_EPS, _input_format_classification_one_hot +from pytorch_lightning.metrics.classification.stat_scores import StatScores +from pytorch_lightning.metrics.functional.precision_recall import _precision_compute, _recall_compute -class Precision(Metric): - """ - Computes the precision metric. - - Works with binary, multiclass, and multilabel data. - Accepts logits from a model output or integer class values in prediction. - Works with multi-dimensional preds and target. - - Forward accepts +class Precision(StatScores): + """Computes the precision score (the ratio ``tp / (tp + fp)``). - - ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes - - ``target`` (long tensor): ``(N, ...)`` + The reduction method (how the precision scores are aggregated) is controlled by the + ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the + multi-dimensional multi-class case. Accepts all inputs listed in :ref:`metrics:Input types`. - If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument. - This is the case for binary and multi-label logits. - - If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``. + In case where you need to ignore a class in computing the score, anI ``ignore_index`` + parameter is availible. Args: - num_classes: Number of classes in the dataset. - beta: Beta coefficient in the F measure. - threshold: - Threshold value for binary or multi-label logits. default: 0.5 - average: - * `'micro'` computes metric globally - * `'macro'` computes metric for each class and then takes the mean + Defines the reduction that is applied. Should be one of the following: + + - ``'micro'`` [default]: Calculate the metric globally, by counting the statistics + (tp, fp, tn, fn) accross all samples and classes. + - ``'macro'``: Calculate the metric for each class separately, and average the + metrics accross classes (with equal weights for each class). + - ``'weighted'``: Calculate the metric for each class separately, and average the + metrics accross classes, weighting each class by its support (``tp + fn``). + - ``'none'`` or ``None``: Calculate the metric for each class separately, and return + the metric for every class. + - ``'samples'``: Calculate the metric for each sample, and average the metrics + across samples (with equal weights for each sample). + + Note that what is considered a sample in the multi-dimensional multi-class case + depends on the value of ``mdmc_average``. + + mdmc_average: + Defines how averaging is done for multi-dimensional multi-class inputs (on top of the + ``average`` parameter). Should be one of the following: + + - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional + multi-class. + + - ``'samplewise'``: In this case, the statistics are computed separately for each + sample on the ``N`` axis, and then averaged over samples. + The computation for each sample is done by treating the flattened extra axes ``...`` + (see :ref:`metrics:Input types`) as the ``N`` dimension within the sample, + and computing the metric for the sample based on that. + + - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs (see :ref:`metrics:Input types`) + are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they + were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. + + num_classes: + Number of classes. Necessary for (multi-dimensional) multi-class or multi-label data. + + threshold: + Threshold probability value for transforming probability predictions to binary + (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + is_multiclass: + If ``False``, treat multi-class and multi-dim multi-class inputs with 1 or 2 classes as + binary and multi-label, respectively. If ``True``, treat binary and multi-label inputs + as multi-class or multi-dim multi-class with 2 classes, respectively. + Defaults to ``None``, which treats inputs as they appear. + ignore_index: + Integer specifying a target class to ignore. If given, this class index does not contribute + to the returned score, regardless of reduction method. Has no effect if given an int that + is not in the range ``[0, C-1]``, or if ``C=1``, where ``C`` is the number of classes. + + If an index is ignored, and ``average=None`` or ``'none'``, the score for the ignored class + will be returned as ``nan`` (to not break the indexing of other labels). + zero_division: + Score to use for classes/samples, whose score has 0 in the denominator. Has to be either + 0 [default] or 1. - multilabel: If predictions are from multilabel classification. compute_on_step: Forward only calls ``update()`` and return None if this is set to False. default: True dist_sync_on_step: @@ -60,89 +92,141 @@ class Precision(Metric): before returning the value at the step. default: False process_group: Specify the process group on which synchronization is called. default: None (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. When `None`, DDP + will be used to perform the allgather. default: None Example: - >>> from pytorch_lightning.metrics import Precision - >>> target = torch.tensor([0, 1, 2, 0, 1, 2]) - >>> preds = torch.tensor([0, 2, 1, 0, 0, 1]) - >>> precision = Precision(num_classes=3) + >>> from pytorch_lightning.metrics.classification import Precision + >>> preds = torch.tensor([2, 0, 2, 1]) + >>> target = torch.tensor([1, 1, 2, 0]) + >>> precision = Precision(average='macro', num_classes=3) >>> precision(preds, target) - tensor(0.3333) + tensor(0.1667) + >>> precision = Precision(average='micro') + >>> precision(preds, target) + tensor(0.2500) """ + def __init__( self, - num_classes: int = 1, + average: str = "micro", + mdmc_average: Optional[str] = None, threshold: float = 0.5, - average: str = 'micro', - multilabel: bool = False, + num_classes: Optional[int] = None, + is_multiclass: Optional[bool] = None, + ignore_index: Optional[int] = None, + zero_division: int = 0, compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, ): super().__init__( + reduce="macro" if average in ["weighted", "none", None] else average, + mdmc_reduce=mdmc_average, + threshold=threshold, + num_classes=num_classes, + is_multiclass=is_multiclass, + ignore_index=ignore_index, compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group, + dist_sync_fn=dist_sync_fn, ) - self.num_classes = num_classes - self.threshold = threshold - self.average = average - self.multilabel = multilabel - - assert self.average in ('micro', 'macro'), \ - "average passed to the function must be either `micro` or `macro`" - - self.add_state("true_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum") - self.add_state("predicted_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum") + if zero_division not in [0, 1]: + raise ValueError("zero_division has to be either 0 or 1") - def update(self, preds: torch.Tensor, target: torch.Tensor): - preds, target = _input_format_classification_one_hot( - self.num_classes, preds, target, self.threshold, self.multilabel - ) - - # multiply because we are counting (1, 1) pair for true positives - self.true_positives += torch.sum(preds * target, dim=1) - self.predicted_positives += torch.sum(preds, dim=1) + self.zero_division = zero_division + self.average = average - def compute(self): - if self.average == 'micro': - return self.true_positives.sum().float() / (self.predicted_positives.sum() + METRIC_EPS) - elif self.average == 'macro': - return (self.true_positives.float() / (self.predicted_positives + METRIC_EPS)).mean() + def compute(self) -> torch.Tensor: + """ + Computes the precision score based on inputs passed in to ``update`` previously. + Return: + The of the returned tensor depends on the ``average`` parameter -class Recall(Metric): - """ - Computes the recall metric. + - If ``average in ['micro', 'macro', 'weighted', 'samples']``, a one-element tensor will be returned + - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number + of classes + """ - Works with binary, multiclass, and multilabel data. - Accepts logits from a model output or integer class values in prediction. - Works with multi-dimensional preds and target. + return _precision_compute( + self.tp, self.fp, self.tn, self.fn, self.average, self.mdmc_reduce, self.zero_division + ) - Forward accepts - - ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes - - ``target`` (long tensor): ``(N, ...)`` +class Recall(StatScores): + """Computes the recall score (the ratio ``tp / (tp + fn)``). - If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument. - This is the case for binary and multi-label logits. + The reduction method (how the recall scores are aggregated) is controlled by the + ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the + multi-dimensional multi-class case. Accepts all inputs listed in :ref:`metrics:Input types`. - If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``. + In case where you need to ignore a class in computing the score, an ``ignore_index`` + parameter is availible. Args: - num_classes: Number of classes in the dataset. - beta: Beta coefficient in the F measure. - threshold: - Threshold value for binary or multi-label logits. default: 0.5 - average: - * `'micro'` computes metric globally - * `'macro'` computes metric for each class and then takes the mean + Defines the reduction that is applied. Should be one of the following: + + - ``'micro'`` [default]: Calculate the metric globally, by counting the statistics + (tp, fp, tn, fn) accross all samples and classes. + - ``'macro'``: Calculate the metric for each class separately, and average the + metrics accross classes (with equal weights for each class). + - ``'weighted'``: Calculate the metric for each class separately, and average the + metrics accross classes, weighting each class by its support (``tp + fn``). + - ``'none'`` or ``None``: Calculate the metric for each class separately, and return + the metric for every class. + - ``'samples'``: Calculate the metric for each sample, and average the metrics + across samples (with equal weights for each sample). + + Note that what is considered a sample in the multi-dimensional multi-class case + depends on the value of ``mdmc_average``. + + mdmc_average: + Defines how averaging is done for multi-dimensional multi-class inputs (on top of the + ``average`` parameter). Should be one of the following: + + - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional + multi-class. + + - ``'samplewise'``: In this case, the statistics are computed separately for each + sample on the ``N`` axis, and then averaged over samples. + The computation for each sample is done by treating the flattened extra axes ``...`` + (see :ref:`metrics:Input types`) as the ``N`` dimension within the sample, + and computing the metric for the sample based on that. + + - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs (see :ref:`metrics:Input types`) + are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they + were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. + + num_classes: + Number of classes. Necessary for (multi-dimensional) multi-class or multi-label data. + + threshold: + Threshold probability value for transforming probability predictions to binary + (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + is_multiclass: + If ``False``, treat multi-class and multi-dim multi-class inputs with 1 or 2 classes as + binary and multi-label, respectively. If ``True``, treat binary and multi-label inputs + as multi-class or multi-dim multi-class with 2 classes, respectively. + Defaults to ``None``, which treats inputs as they appear. + ignore_index: + Integer specifying a target class to ignore. If given, this class index does not contribute + to the returned score, regardless of reduction method. Has no effect if given an int that + is not in the range ``[0, C-1]``, or if ``C=1``, where ``C`` is the number of classes. + + If an index is ignored, and ``average=None`` or ``'none'``, the score for the ignored class + will be returned as ``nan`` (to not break the indexing of other labels). + zero_division: + Score to use for classes/samples, whose score has 0 in the denominator. Has to be either + 0 [default] or 1. - multilabel: If predictions are from multilabel classification. compute_on_step: Forward only calls ``update()`` and return None if this is set to False. default: True dist_sync_on_step: @@ -150,65 +234,67 @@ class Recall(Metric): before returning the value at the step. default: False process_group: Specify the process group on which synchronization is called. default: None (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. When `None`, DDP + will be used to perform the allgather. default: None Example: - >>> from pytorch_lightning.metrics import Recall - >>> target = torch.tensor([0, 1, 2, 0, 1, 2]) - >>> preds = torch.tensor([0, 2, 1, 0, 0, 1]) - >>> recall = Recall(num_classes=3) + >>> from pytorch_lightning.metrics.classification import Recall + >>> preds = torch.tensor([2, 0, 2, 1]) + >>> target = torch.tensor([1, 1, 2, 0]) + >>> recall = Recall(average='macro', num_classes=3) >>> recall(preds, target) tensor(0.3333) + >>> recall = Recall(average='micro') + >>> recall(preds, target) + tensor(0.2500) """ + def __init__( self, - num_classes: int = 1, + average: str = "micro", + mdmc_average: Optional[str] = None, threshold: float = 0.5, - average: str = 'micro', - multilabel: bool = False, + num_classes: Optional[int] = None, + is_multiclass: Optional[bool] = None, + ignore_index: Optional[int] = None, + zero_division: int = 0, compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, ): super().__init__( + reduce="macro" if average in ["weighted", "none", None] else average, + mdmc_reduce=mdmc_average, + threshold=threshold, + num_classes=num_classes, + is_multiclass=is_multiclass, + ignore_index=ignore_index, compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group, + dist_sync_fn=dist_sync_fn, ) - self.num_classes = num_classes - self.threshold = threshold - self.average = average - self.multilabel = multilabel - - assert self.average in ('micro', 'macro'), \ - "average passed to the function must be either `micro` or `macro`" + if zero_division not in [0, 1]: + raise ValueError("zero_division has to be either 0 or 1") - self.add_state("true_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum") - self.add_state("actual_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum") - - def update(self, preds: torch.Tensor, target: torch.Tensor): - """ - Update state with predictions and targets. + self.zero_division = zero_division + self.average = average - Args: - preds: Predictions from model - target: Ground truth values + def compute(self) -> torch.Tensor: """ - preds, target = _input_format_classification_one_hot( - self.num_classes, preds, target, self.threshold, self.multilabel - ) + Computes the recall score based on inputs passed in to ``update`` previously. - # multiply because we are counting (1, 1) pair for true positives - self.true_positives += torch.sum(preds * target, dim=1) - self.actual_positives += torch.sum(target, dim=1) + Return: + The of the returned tensor depends on the ``average`` parameter - def compute(self): + - If ``average in ['micro', 'macro', 'weighted', 'samples']``, a one-element tensor will be returned + - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number + of classes """ - Computes accuracy over state. - """ - if self.average == 'micro': - return self.true_positives.sum().float() / (self.actual_positives.sum() + METRIC_EPS) - elif self.average == 'macro': - return (self.true_positives.float() / (self.actual_positives + METRIC_EPS)).mean() + + return _recall_compute(self.tp, self.fp, self.tn, self.fn, self.average, self.mdmc_reduce, self.zero_division) diff --git a/pytorch_lightning/metrics/functional/__init__.py b/pytorch_lightning/metrics/functional/__init__.py index 8147e4c3cfcd5..5055b21d87197 100644 --- a/pytorch_lightning/metrics/functional/__init__.py +++ b/pytorch_lightning/metrics/functional/__init__.py @@ -20,10 +20,7 @@ multiclass_precision_recall_curve, multiclass_roc, multiclass_auroc, - precision, - precision_recall, precision_recall_curve, - recall, roc, to_categorical, to_onehot, @@ -43,3 +40,4 @@ from pytorch_lightning.metrics.functional.confusion_matrix import confusion_matrix from pytorch_lightning.metrics.functional.f_beta import fbeta, f1 from pytorch_lightning.metrics.functional.stat_scores import stat_scores +from pytorch_lightning.metrics.functional.precision_recall import precision, recall diff --git a/pytorch_lightning/metrics/functional/precision_recall.py b/pytorch_lightning/metrics/functional/precision_recall.py new file mode 100644 index 0000000000000..e7fd2cca3fca8 --- /dev/null +++ b/pytorch_lightning/metrics/functional/precision_recall.py @@ -0,0 +1,276 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +import torch +from pytorch_lightning.metrics.functional.reduction import _reduce_scores +from pytorch_lightning.metrics.functional.stat_scores import _stat_scores_update + + +def _precision_compute( + tp: torch.Tensor, + fp: torch.Tensor, + tn: torch.Tensor, + fn: torch.Tensor, + average: str, + mdmc_average: Optional[str], + zero_division: int, +) -> torch.Tensor: + return _reduce_scores( + numerator=tp, + denominator=tp + fp, + weights=tp + fn, + average=average, + mdmc_average=mdmc_average, + zero_division=zero_division, + ) + + +def precision( + preds: torch.Tensor, + target: torch.Tensor, + average: str = "micro", + mdmc_average: Optional[str] = None, + threshold: float = 0.5, + num_classes: Optional[int] = None, + is_multiclass: Optional[bool] = None, + ignore_index: Optional[int] = None, + zero_division: int = 0, +) -> torch.Tensor: + """Computes the precision score (the ratio ``tp / (tp + fp)``). + + The reduction method (how the precision scores are aggregated) is controlled by the + ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the + multi-dimensional multi-class case. Accepts all inputs listed in :ref:`metrics:Input types`. + + In case where you need to ignore a class in computing the score, anI ``ignore_index`` + parameter is availible. + + Args: + preds: Predictions from model (probabilities or labels) + target: Ground truth values + average: + Defines the reduction that is applied. Should be one of the following: + + - ``'micro'`` [default]: Calculate the metric globally, by counting the statistics + (tp, fp, tn, fn) accross all samples and classes. + - ``'macro'``: Calculate the metric for each class separately, and average the + metrics accross classes (with equal weights for each class). + - ``'weighted'``: Calculate the metric for each class separately, and average the + metrics accross classes, weighting each class by its support (``tp + fn``). + - ``'none'`` or ``None``: Calculate the metric for each class separately, and return + the metric for every class. + - ``'samples'``: Calculate the metric for each sample, and average the metrics + across samples (with equal weights for each sample). + + Note that what is considered a sample in the multi-dimensional multi-class case + depends on the value of ``mdmc_average``. + + mdmc_average: + Defines how averaging is done for multi-dimensional multi-class inputs (on top of the + ``average`` parameter). Should be one of the following: + + - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional + multi-class. + + - ``'samplewise'``: In this case, the statistics are computed separately for each + sample on the ``N`` axis, and then averaged over samples. + The computation for each sample is done by treating the flattened extra axes ``...`` + (see :ref:`metrics:Input types`) as the ``N`` dimension within the sample, + and computing the metric for the sample based on that. + + - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs (see :ref:`metrics:Input types`) + are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they + were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. + + num_classes: + Number of classes. Necessary for (multi-dimensional) multi-class or multi-label data. + + threshold: + Threshold probability value for transforming probability predictions to binary + (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + is_multiclass: + If ``False``, treat multi-class and multi-dim multi-class inputs with 1 or 2 classes as + binary and multi-label, respectively. If ``True``, treat binary and multi-label inputs + as multi-class or multi-dim multi-class with 2 classes, respectively. + Defaults to ``None``, which treats inputs as they appear. + ignore_index: + Integer specifying a target class to ignore. If given, this class index does not contribute + to the returned score, regardless of reduction method. Has no effect if given an int that + is not in the range ``[0, C-1]``, or if ``C=1``, where ``C`` is the number of classes. + + If an index is ignored, and ``average=None`` or ``'none'``, the score for the ignored class + will be returned as ``nan`` (to not break the indexing of other labels). + zero_division: + Score to use for classes/samples, whose score has 0 in the denominator. Has to be either + 0 [default] or 1. + + Return: + The of the returned tensor depends on the ``average`` parameter + + - If ``average in ['micro', 'macro', 'weighted', 'samples']``, a one-element tensor will be returned + - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number + of classes + + Example: + + >>> from pytorch_lightning.metrics.functional import precision + >>> preds = torch.tensor([2, 0, 2, 1]) + >>> target = torch.tensor([1, 1, 2, 0]) + >>> precision(preds, target, average='macro', num_classes=3) + tensor(0.1667) + >>> precision(preds, target, average='micro') + tensor(0.2500) + + """ + + reduce = "macro" if average in ["weighted", "none", None] else average + + if zero_division not in [0, 1]: + raise ValueError("zero_division has to be either 0 or 1") + + tp, fp, tn, fn = _stat_scores_update( + preds, target, reduce, mdmc_average, threshold, num_classes, is_multiclass, ignore_index + ) + + return _precision_compute(tp, fp, tn, fn, average, mdmc_average, zero_division) + + +def _recall_compute( + tp: torch.Tensor, + fp: torch.Tensor, + tn: torch.Tensor, + fn: torch.Tensor, + average: str, + mdmc_average: Optional[str], + zero_division: int, +) -> torch.Tensor: + return _reduce_scores( + numerator=tp, + denominator=tp + fn, + weights=tp + fn, + average=average, + mdmc_average=mdmc_average, + zero_division=zero_division, + ) + + +def recall( + preds: torch.Tensor, + target: torch.Tensor, + average: str = "micro", + mdmc_average: Optional[str] = None, + threshold: float = 0.5, + num_classes: Optional[int] = None, + is_multiclass: Optional[bool] = None, + ignore_index: Optional[int] = None, + zero_division: int = 0, +) -> torch.Tensor: + """Computes the recall score (the ratio ``tp / (tp + fn)``). + + The reduction method (how the recall scores are aggregated) is controlled by the + ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the + multi-dimensional multi-class case. Accepts all inputs listed in :ref:`metrics:Input types`. + + In case where you need to ignore a class in computing the score, an ``ignore_index`` + parameter is availible. + + Args: + preds: Predictions from model (probabilities, or labels) + target: Ground truth values + average: + Defines the reduction that is applied. Should be one of the following: + + - ``'micro'`` [default]: Calculate the metric globally, by counting the statistics + (tp, fp, tn, fn) accross all samples and classes. + - ``'macro'``: Calculate the metric for each class separately, and average the + metrics accross classes (with equal weights for each class). + - ``'weighted'``: Calculate the metric for each class separately, and average the + metrics accross classes, weighting each class by its support (``tp + fn``). + - ``'none'`` or ``None``: Calculate the metric for each class separately, and return + the metric for every class. + - ``'samples'``: Calculate the metric for each sample, and average the metrics + across samples (with equal weights for each sample). + + Note that what is considered a sample in the multi-dimensional multi-class case + depends on the value of ``mdmc_average``. + + mdmc_average: + Defines how averaging is done for multi-dimensional multi-class inputs (on top of the + ``average`` parameter). Should be one of the following: + + - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional + multi-class. + + - ``'samplewise'``: In this case, the statistics are computed separately for each + sample on the ``N`` axis, and then averaged over samples. + The computation for each sample is done by treating the flattened extra axes ``...`` + (see :ref:`metrics:Input types`) as the ``N`` dimension within the sample, + and computing the metric for the sample based on that. + + - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs (see :ref:`metrics:Input types`) + are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they + were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. + + num_classes: + Number of classes. Necessary for (multi-dimensional) multi-class or multi-label data. + + threshold: + Threshold probability value for transforming probability predictions to binary + (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + is_multiclass: + If ``False``, treat multi-class and multi-dim multi-class inputs with 1 or 2 classes as + binary and multi-label, respectively. If ``True``, treat binary and multi-label inputs + as multi-class or multi-dim multi-class with 2 classes, respectively. + Defaults to ``None``, which treats inputs as they appear. + ignore_index: + Integer specifying a target class to ignore. If given, this class index does not contribute + to the returned score, regardless of reduction method. Has no effect if given an int that + is not in the range ``[0, C-1]``, or if ``C=1``, where ``C`` is the number of classes. + + If an index is ignored, and ``average=None`` or ``'none'``, the score for the ignored class + will be returned as ``nan`` (to not break the indexing of other labels). + zero_division: + Score to use for classes/samples, whose score has 0 in the denominator. Has to be either + 0 [default] or 1. + + Return: + The of the returned tensor depends on the ``average`` parameter + + - If ``average in ['micro', 'macro', 'weighted', 'samples']``, a one-element tensor will be returned + - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number + of classes + + Example: + + >>> from pytorch_lightning.metrics.functional import recall + >>> preds = torch.tensor([2, 0, 2, 1]) + >>> target = torch.tensor([1, 1, 2, 0]) + >>> recall(preds, target, average='macro', num_classes=3) + tensor(0.3333) + >>> recall(preds, target, average='micro') + tensor(0.2500) + + """ + + reduce = "macro" if average in ["weighted", "none", None] else average + + if zero_division not in [0, 1]: + raise ValueError("zero_division has to be either 0 or 1") + + tp, fp, tn, fn = _stat_scores_update( + preds, target, reduce, mdmc_average, threshold, num_classes, is_multiclass, ignore_index + ) + + return _recall_compute(tp, fp, tn, fn, average, mdmc_average, zero_division) diff --git a/tests/metrics/classification/test_precision_recall.py b/tests/metrics/classification/test_precision_recall.py index 967bc60e28307..f43a7f41f0bec 100644 --- a/tests/metrics/classification/test_precision_recall.py +++ b/tests/metrics/classification/test_precision_recall.py @@ -5,126 +5,219 @@ import torch from sklearn.metrics import precision_score, recall_score +from pytorch_lightning.metrics.classification.utils import _input_format_classification from pytorch_lightning.metrics import Precision, Recall +from pytorch_lightning.metrics.functional import precision, recall from tests.metrics.classification.inputs import ( _binary_inputs, _binary_prob_inputs, _multiclass_inputs, - _multiclass_prob_inputs, - _multidim_multiclass_inputs, - _multidim_multiclass_prob_inputs, - _multilabel_inputs, - _multilabel_prob_inputs, + _multiclass_prob_inputs as _mc_prob, + _multidim_multiclass_inputs as _mdmc, + _multidim_multiclass_prob_inputs as _mdmc_prob, + _multilabel_inputs as _ml, + _multilabel_prob_inputs as _ml_prob, + _multilabel_multidim_prob_inputs as _mlmd_prob, + _multilabel_multidim_inputs as _mlmd, ) -from tests.metrics.utils import NUM_CLASSES, THRESHOLD, MetricTester +from tests.metrics.utils import EXTRA_DIM, NUM_CLASSES, THRESHOLD, MetricTester torch.manual_seed(42) -def _sk_prec_recall_binary_prob(preds, target, sk_fn=precision_score, average='micro'): - sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) - sk_target = target.view(-1).numpy() +def _sk_prec_recall( + preds, target, sk_fn, num_classes, average, is_multiclass, zero_division, ignore_index, mdmc_average=None +): + if average == "none": + average = None + if num_classes == 1: + average = "binary" - return sk_fn(y_true=sk_target, y_pred=sk_preds, average='binary') + labels = list(range(num_classes)) + try: + labels.remove(ignore_index) + except ValueError: + pass + sk_preds, sk_target, _ = _input_format_classification( + preds, target, THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass + ) + sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy() -def _sk_prec_recall_binary(preds, target, sk_fn=precision_score, average='micro'): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() + sk_scores = sk_fn(sk_target, sk_preds, average=average, zero_division=zero_division, labels=labels) - return sk_fn(y_true=sk_target, y_pred=sk_preds, average='binary') + if len(labels) != num_classes and not average: + sk_scores = np.insert(sk_scores, ignore_index, np.nan) + return sk_scores -def _sk_prec_recall_multilabel_prob(preds, target, sk_fn=precision_score, average='micro'): - sk_preds = (preds.view(-1, NUM_CLASSES).numpy() >= THRESHOLD).astype(np.uint8) - sk_target = target.view(-1, NUM_CLASSES).numpy() - return sk_fn(y_true=sk_target, y_pred=sk_preds, average=average) +def _sk_prec_recall_mdmc( + preds, target, sk_fn, num_classes, average, is_multiclass, zero_division, ignore_index, mdmc_average +): + preds, target, _ = _input_format_classification( + preds, target, threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass + ) + if mdmc_average == "global": + preds = torch.movedim(preds, 1, -1).reshape(-1, preds.shape[1]) + target = torch.movedim(target, 1, -1).reshape(-1, target.shape[1]) -def _sk_prec_recall_multilabel(preds, target, sk_fn=precision_score, average='micro'): - sk_preds = preds.view(-1, NUM_CLASSES).numpy() - sk_target = target.view(-1, NUM_CLASSES).numpy() + return _sk_prec_recall(preds, target, sk_fn, num_classes, average, False, zero_division, ignore_index) + else: # mdmc_average == "samplewise" + scores = [] - return sk_fn(y_true=sk_target, y_pred=sk_preds, average=average) + for i in range(preds.shape[0]): + pred_i = preds[i, ...].T + target_i = target[i, ...].T + scores_i = _sk_prec_recall( + pred_i, target_i, sk_fn, num_classes, average, False, zero_division, ignore_index + ) + scores.append(np.expand_dims(scores_i, 0)) -def _sk_prec_recall_multiclass_prob(preds, target, sk_fn=precision_score, average='micro'): - sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy() - sk_target = target.view(-1).numpy() + return np.concatenate(scores).mean() - return sk_fn(y_true=sk_target, y_pred=sk_preds, average=average) +@pytest.mark.parametrize("metric, fn_metric", [(Precision, precision), (Recall, recall)]) +def test_wrong_params(metric, fn_metric): + with pytest.raises(ValueError): + metric(zero_division=None) -def _sk_prec_recall_multiclass(preds, target, sk_fn=precision_score, average='micro'): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() + with pytest.raises(ValueError): + fn_metric(_binary_inputs.preds[0], _binary_inputs.target[0], zero_division=None) - return sk_fn(y_true=sk_target, y_pred=sk_preds, average=average) +###################################################################################### +# Testing for MDMC inputs is partially skipped, because some cases appear where +# (with mdmc_average1 =! None, ignore_index=1, average='weighted') a sample in +# target contains only labels "1" - and as we are ignoring this index, weights of +# all labels will be zero. In this special edge case, sklearn handles the situation +# differently for each metric (recall, precision, fscore), which breaks ours handling +# everything in _reduce_scores (where the return value is 0 in this situation). +###################################################################################### -def _sk_prec_recall_multidim_multiclass_prob(preds, target, sk_fn=precision_score, average='micro'): - sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy() - sk_target = target.view(-1).numpy() - return sk_fn(y_true=sk_target, y_pred=sk_preds, average=average) - - -def _sk_prec_recall_multidim_multiclass(preds, target, sk_fn=precision_score, average='micro'): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - - return sk_fn(y_true=sk_target, y_pred=sk_preds, average=average) - - -@pytest.mark.parametrize("ddp", [True, False]) -@pytest.mark.parametrize("dist_sync_on_step", [True, False]) -@pytest.mark.parametrize("average", ['micro', 'macro']) @pytest.mark.parametrize( - "preds, target, sk_metric, num_classes, multilabel", - [ - (_binary_prob_inputs.preds, _binary_prob_inputs.target, _sk_prec_recall_binary_prob, 1, False), - (_binary_inputs.preds, _binary_inputs.target, _sk_prec_recall_binary, 1, False), - (_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target, _sk_prec_recall_multilabel_prob, NUM_CLASSES, True), - (_multilabel_inputs.preds, _multilabel_inputs.target, _sk_prec_recall_multilabel, NUM_CLASSES, True), - (_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target, _sk_prec_recall_multiclass_prob, NUM_CLASSES, False), - (_multiclass_inputs.preds, _multiclass_inputs.target, _sk_prec_recall_multiclass, NUM_CLASSES, False), - ( - _multidim_multiclass_prob_inputs.preds, - _multidim_multiclass_prob_inputs.target, - _sk_prec_recall_multidim_multiclass_prob, - NUM_CLASSES, - False, - ), - ( - _multidim_multiclass_inputs.preds, - _multidim_multiclass_inputs.target, - _sk_prec_recall_multidim_multiclass, - NUM_CLASSES, - False, - ), - ], + "metric_class, sk_fn, metric_fn", [(Precision, precision_score, precision), (Recall, recall_score, recall)] ) +@pytest.mark.parametrize("average", ["micro", "macro", None, "weighted", "samples"]) +@pytest.mark.parametrize("zero_division", [0, 1]) +@pytest.mark.parametrize("ignore_index", [None, 1]) @pytest.mark.parametrize( - "metric_class, sk_fn", [(Precision, precision_score), (Recall, recall_score)], + "preds, target, num_classes, is_multiclass, mdmc_average, sk_wrapper", + [ + (_binary_prob_inputs.preds, _binary_prob_inputs.target, 1, None, None, _sk_prec_recall), + (_binary_inputs.preds, _binary_inputs.target, 1, False, None, _sk_prec_recall), + (_ml_prob.preds, _ml_prob.target, NUM_CLASSES, None, None, _sk_prec_recall), + (_ml.preds, _ml.target, NUM_CLASSES, False, None, _sk_prec_recall), + (_mc_prob.preds, _mc_prob.target, NUM_CLASSES, None, None, _sk_prec_recall), + (_multiclass_inputs.preds, _multiclass_inputs.target, NUM_CLASSES, None, None, _sk_prec_recall), + (_mlmd_prob.preds, _mlmd_prob.target, EXTRA_DIM * NUM_CLASSES, None, None, _sk_prec_recall), + (_mlmd.preds, _mlmd.target, EXTRA_DIM * NUM_CLASSES, False, None, _sk_prec_recall), + (_mdmc.preds, _mdmc.target, NUM_CLASSES, None, "global", _sk_prec_recall_mdmc), + (_mdmc_prob.preds, _mdmc_prob.target, NUM_CLASSES, None, "global", _sk_prec_recall_mdmc), + (_mdmc.preds, _mdmc.target, NUM_CLASSES, None, "samplewise", _sk_prec_recall_mdmc), + (_mdmc_prob.preds, _mdmc_prob.target, NUM_CLASSES, None, "samplewise", _sk_prec_recall_mdmc), + ], ) class TestPrecisionRecall(MetricTester): - def test_precision_recall( - self, ddp, dist_sync_on_step, preds, target, sk_metric, metric_class, sk_fn, num_classes, multilabel, average + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + def test_precision_recall_class( + self, + ddp, + dist_sync_on_step, + preds, + target, + sk_wrapper, + metric_class, + metric_fn, + sk_fn, + is_multiclass, + num_classes, + average, + mdmc_average, + zero_division, + ignore_index, ): + if num_classes == 1 and average != "micro": + pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)") + + if average == "weighted" and ignore_index is not None and mdmc_average is not None: + pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average") + self.run_class_metric_test( ddp=ddp, preds=preds, target=target, metric_class=metric_class, - sk_metric=partial(sk_metric, sk_fn=sk_fn, average=average), + sk_metric=partial( + sk_wrapper, + sk_fn=sk_fn, + average=average, + num_classes=num_classes, + is_multiclass=is_multiclass, + zero_division=zero_division, + ignore_index=ignore_index, + mdmc_average=mdmc_average, + ), dist_sync_on_step=dist_sync_on_step, metric_args={ "num_classes": num_classes, "average": average, - "multilabel": multilabel, "threshold": THRESHOLD, + "is_multiclass": is_multiclass, + "zero_division": zero_division, + "ignore_index": ignore_index, + "mdmc_average": mdmc_average, + }, + check_dist_sync_on_step=True, + check_batch=True, + ) + + def test_precision_recall_fn( + self, + preds, + target, + sk_wrapper, + metric_class, + metric_fn, + sk_fn, + is_multiclass, + num_classes, + average, + mdmc_average, + zero_division, + ignore_index, + ): + if num_classes == 1 and average != "micro": + pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)") + + if average == "weighted" and ignore_index is not None and mdmc_average is not None: + pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average") + + self.run_functional_metric_test( + preds, + target, + metric_functional=metric_fn, + sk_metric=partial( + sk_wrapper, + sk_fn=sk_fn, + average=average, + num_classes=num_classes, + is_multiclass=is_multiclass, + zero_division=zero_division, + ignore_index=ignore_index, + mdmc_average=mdmc_average, + ), + metric_args={ + "num_classes": num_classes, + "average": average, + "threshold": THRESHOLD, + "is_multiclass": is_multiclass, + "zero_division": zero_division, + "ignore_index": ignore_index, + "mdmc_average": mdmc_average, }, - check_dist_sync_on_step=False if average == 'macro' else True, - check_batch=False if average == 'macro' else True, ) From 801abe8bc347f7ff804e2adb4c07d0f82014c76a Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 24 Nov 2020 23:37:14 +0100 Subject: [PATCH 11/61] Add reduce_scores function --- .../metrics/functional/reduction.py | 130 ++++++++++++------ 1 file changed, 87 insertions(+), 43 deletions(-) diff --git a/pytorch_lightning/metrics/functional/reduction.py b/pytorch_lightning/metrics/functional/reduction.py index 197b1dd7097a3..08e69baf8d3b3 100644 --- a/pytorch_lightning/metrics/functional/reduction.py +++ b/pytorch_lightning/metrics/functional/reduction.py @@ -11,6 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional + +import numpy as np import torch @@ -28,55 +31,96 @@ def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor: Raise: ValueError if an invalid reduction parameter was given """ - if reduction == 'elementwise_mean': + if reduction == "elementwise_mean": return torch.mean(to_reduce) - if reduction == 'none': + if reduction == "none": return to_reduce - if reduction == 'sum': + if reduction == "sum": return torch.sum(to_reduce) - raise ValueError('Reduction parameter unknown.') + raise ValueError("Reduction parameter unknown.") -def class_reduce(num: torch.Tensor, - denom: torch.Tensor, - weights: torch.Tensor, - class_reduction: str = 'none') -> torch.Tensor: - """ - Function used to reduce classification metrics of the form `num / denom * weights`. - For example for calculating standard accuracy the num would be number of - true positives per class, denom would be the support per class, and weights - would be a tensor of 1s +def _reduce_scores( + numerator: torch.Tensor, + denominator: torch.Tensor, + weights: torch.Tensor, + average: str, + mdmc_average: Optional[str], + zero_division: int, +) -> torch.Tensor: + """Reduces scores of type numerator/denominator (with possible weighting). - Args: - num: numerator tensor - decom: denominator tensor - weights: weights for each class - class_reduction: reduction method for multiclass problems + First, scores are computed by dividing the numerator by denominator. If + denominator is zero, then the score is set to the value of zero_division + parameters. + + If average='micro' or 'none', no reduction is needed. In case of 'none', + scores for classes whose weights are negative are set to nan. + + If average='macro' or 'weighted', the scores across each classes are + averaged (with weights). The scores for classes whose weights are + negative are ignored in averaging. - - ``'micro'``: calculate metrics globally (default) - - ``'macro'``: calculate metrics for each label, and find their unweighted mean. - - ``'weighted'``: calculate metrics for each label, and find their weighted mean. - - ``'none'`` or ``None``: returns calculated metric per class + If average='samples', the scores across all samples are averaged. + In case if mdmc_average='samplewise', then the transformations mentioned + above are first applied across dimension 1, and the scores then averaged + across dimension 0. + + Parameters + ---------- + numerator + A tensor with elements that are the upper part of the quotient + denominator + A tensor with elements that are the lower part of the quotient + weights + A tensor of weights for each class - will be used for weighting + only if average='weighted'. + + If a class is to be ignored (in case of macro or weighted average), + that class should have a negative weight. If average=none or None, + classes with negative weights will get a score of nan + average + The method to average the scores. Should be one of 'micro', 'macro', + 'weighted', 'none', None, 'samples' + mdmc_average + The method to average the scores if inputs were multi-dimensional multi-class. + Should be either 'global' or 'samplewise'. If inputs were not + multi-dimensional multi-class, it should be None + zero_division + Should be either zero (if there is zero division set metric to 0), or 1W """ - valid_reduction = ('micro', 'macro', 'weighted', 'none', None) - if class_reduction == 'micro': - fraction = torch.sum(num) / torch.sum(denom) - else: - fraction = num / denom - - # We need to take care of instances where the denom can be 0 - # for some (or all) classes which will produce nans - fraction[fraction != fraction] = 0 - - if class_reduction == 'micro': - return fraction - elif class_reduction == 'macro': - return torch.mean(fraction) - elif class_reduction == 'weighted': - return torch.sum(fraction * (weights.float() / torch.sum(weights))) - elif class_reduction == 'none' or class_reduction is None: - return fraction - - raise ValueError(f'Reduction parameter {class_reduction} unknown.' - f' Choose between one of these: {valid_reduction}') + numerator, denominator = numerator.double(), denominator.double() + weights = weights.double() + + zero_div_mask = denominator == 0 + denominator = torch.where(zero_div_mask, 1.0, denominator) + + scores = numerator / denominator + scores = torch.where(zero_div_mask, float(zero_division), scores) + + ignore_mask = weights < 0 + + weights = torch.where(ignore_mask, 0.0, 1.0 if average == "macro" else weights) + weights = weights.double() + weights_sum = weights.sum(dim=-1, keepdims=True) + + # In case if we ignore the only positive class (sum of weights is 0), + # return zero_division - this is to be consistent with sklearn and + # pass the tests + weights_sum = torch.where(weights_sum == 0, 1.0, weights_sum) + weights = weights / weights_sum + + if average in ["none", None]: + scores = torch.where(ignore_mask, np.nan, scores) + + elif average in ["macro", "weighted"]: + scores = (scores * weights).sum(dim=-1) + + elif average == "samples": + scores = scores.mean() + + if mdmc_average == "samplewise": + scores = scores.mean() + + return scores.float() From fbebd345ff61348882f5ed82efa1a1785642acfb Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 24 Nov 2020 23:43:15 +0100 Subject: [PATCH 12/61] Temporarily add back legacy class_reduce --- .../metrics/functional/reduction.py | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/pytorch_lightning/metrics/functional/reduction.py b/pytorch_lightning/metrics/functional/reduction.py index 08e69baf8d3b3..6491ba0784116 100644 --- a/pytorch_lightning/metrics/functional/reduction.py +++ b/pytorch_lightning/metrics/functional/reduction.py @@ -39,6 +39,46 @@ def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor: return torch.sum(to_reduce) raise ValueError("Reduction parameter unknown.") +def class_reduce(num: torch.Tensor, + denom: torch.Tensor, + weights: torch.Tensor, + class_reduction: str = 'none') -> torch.Tensor: + """ + Function used to reduce classification metrics of the form `num / denom * weights`. + For example for calculating standard accuracy the num would be number of + true positives per class, denom would be the support per class, and weights + would be a tensor of 1s + Args: + num: numerator tensor + decom: denominator tensor + weights: weights for each class + class_reduction: reduction method for multiclass problems + - ``'micro'``: calculate metrics globally (default) + - ``'macro'``: calculate metrics for each label, and find their unweighted mean. + - ``'weighted'``: calculate metrics for each label, and find their weighted mean. + - ``'none'`` or ``None``: returns calculated metric per class + """ + valid_reduction = ('micro', 'macro', 'weighted', 'none', None) + if class_reduction == 'micro': + fraction = torch.sum(num) / torch.sum(denom) + else: + fraction = num / denom + + # We need to take care of instances where the denom can be 0 + # for some (or all) classes which will produce nans + fraction[fraction != fraction] = 0 + + if class_reduction == 'micro': + return fraction + elif class_reduction == 'macro': + return torch.mean(fraction) + elif class_reduction == 'weighted': + return torch.sum(fraction * (weights.float() / torch.sum(weights))) + elif class_reduction == 'none' or class_reduction is None: + return fraction + + raise ValueError(f'Reduction parameter {class_reduction} unknown.' + f' Choose between one of these: {valid_reduction}') def _reduce_scores( numerator: torch.Tensor, From f45fc817e6b0cc47f2f95d6a79e582d719b90d4c Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 24 Nov 2020 23:59:55 +0100 Subject: [PATCH 13/61] Division with float --- pytorch_lightning/metrics/functional/accuracy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/metrics/functional/accuracy.py b/pytorch_lightning/metrics/functional/accuracy.py index e56b4b9f76eb3..fa423c611a450 100644 --- a/pytorch_lightning/metrics/functional/accuracy.py +++ b/pytorch_lightning/metrics/functional/accuracy.py @@ -45,7 +45,7 @@ def _accuracy_update( def _accuracy_compute(correct: torch.Tensor, total: torch.Tensor) -> torch.Tensor: - return correct / total + return correct.float() / total def accuracy( From 9d44a2643b64f4d5e6f4beb14c0b1ec9eab68ffc Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Wed, 25 Nov 2020 00:08:15 +0100 Subject: [PATCH 14/61] PEP 8 compliance --- .../metrics/functional/reduction.py | 26 ++++++++++--------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/metrics/functional/reduction.py b/pytorch_lightning/metrics/functional/reduction.py index 6491ba0784116..8889abda95082 100644 --- a/pytorch_lightning/metrics/functional/reduction.py +++ b/pytorch_lightning/metrics/functional/reduction.py @@ -39,10 +39,10 @@ def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor: return torch.sum(to_reduce) raise ValueError("Reduction parameter unknown.") -def class_reduce(num: torch.Tensor, - denom: torch.Tensor, - weights: torch.Tensor, - class_reduction: str = 'none') -> torch.Tensor: + +def class_reduce( + num: torch.Tensor, denom: torch.Tensor, weights: torch.Tensor, class_reduction: str = "none" +) -> torch.Tensor: """ Function used to reduce classification metrics of the form `num / denom * weights`. For example for calculating standard accuracy the num would be number of @@ -58,8 +58,8 @@ def class_reduce(num: torch.Tensor, - ``'weighted'``: calculate metrics for each label, and find their weighted mean. - ``'none'`` or ``None``: returns calculated metric per class """ - valid_reduction = ('micro', 'macro', 'weighted', 'none', None) - if class_reduction == 'micro': + valid_reduction = ("micro", "macro", "weighted", "none", None) + if class_reduction == "micro": fraction = torch.sum(num) / torch.sum(denom) else: fraction = num / denom @@ -68,17 +68,19 @@ def class_reduce(num: torch.Tensor, # for some (or all) classes which will produce nans fraction[fraction != fraction] = 0 - if class_reduction == 'micro': + if class_reduction == "micro": return fraction - elif class_reduction == 'macro': + elif class_reduction == "macro": return torch.mean(fraction) - elif class_reduction == 'weighted': + elif class_reduction == "weighted": return torch.sum(fraction * (weights.float() / torch.sum(weights))) - elif class_reduction == 'none' or class_reduction is None: + elif class_reduction == "none" or class_reduction is None: return fraction - raise ValueError(f'Reduction parameter {class_reduction} unknown.' - f' Choose between one of these: {valid_reduction}') + raise ValueError( + f"Reduction parameter {class_reduction} unknown." f" Choose between one of these: {valid_reduction}" + ) + def _reduce_scores( numerator: torch.Tensor, From 5ce7cd9c81e45a77e891166d88cbb0f0094b049b Mon Sep 17 00:00:00 2001 From: Tadej Date: Wed, 25 Nov 2020 00:52:54 +0100 Subject: [PATCH 15/61] Remove precision recall --- docs/source/metrics.rst | 11 +- .../classification/precision_recall.py | 324 +++++++----------- .../metrics/functional/__init__.py | 4 +- .../metrics/functional/precision_recall.py | 276 --------------- .../classification/test_precision_recall.py | 247 +++++-------- 5 files changed, 208 insertions(+), 654 deletions(-) delete mode 100644 pytorch_lightning/metrics/functional/precision_recall.py diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 40a2f914d981f..e831bda41647c 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -448,7 +448,14 @@ multiclass_roc [func] precision [func] ~~~~~~~~~~~~~~~~ -.. autofunction:: pytorch_lightning.metrics.functional.precision +.. autofunction:: pytorch_lightning.metrics.functional.classification.precision + :noindex: + + +precision_recall [func] +~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: pytorch_lightning.metrics.functional.classification.precision_recall :noindex: @@ -462,7 +469,7 @@ precision_recall_curve [func] recall [func] ~~~~~~~~~~~~~ -.. autofunction:: pytorch_lightning.metrics.functional.recall +.. autofunction:: pytorch_lightning.metrics.functional.classification.recall :noindex: diff --git a/pytorch_lightning/metrics/classification/precision_recall.py b/pytorch_lightning/metrics/classification/precision_recall.py index 110e4c2614466..9d666eabf53b8 100644 --- a/pytorch_lightning/metrics/classification/precision_recall.py +++ b/pytorch_lightning/metrics/classification/precision_recall.py @@ -11,80 +11,48 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Any, Callable +import math +import functools +from abc import ABC, abstractmethod +from typing import Any, Callable, Optional, Union +from collections.abc import Mapping, Sequence +from collections import namedtuple import torch -from pytorch_lightning.metrics.classification.stat_scores import StatScores -from pytorch_lightning.metrics.functional.precision_recall import _precision_compute, _recall_compute +from torch import nn +from pytorch_lightning.metrics.metric import Metric +from pytorch_lightning.metrics.utils import to_onehot, METRIC_EPS, _input_format_classification_one_hot -class Precision(StatScores): - """Computes the precision score (the ratio ``tp / (tp + fp)``). +class Precision(Metric): + """ + Computes the precision metric. - The reduction method (how the precision scores are aggregated) is controlled by the - ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the - multi-dimensional multi-class case. Accepts all inputs listed in :ref:`metrics:Input types`. + Works with binary, multiclass, and multilabel data. + Accepts logits from a model output or integer class values in prediction. + Works with multi-dimensional preds and target. - In case where you need to ignore a class in computing the score, anI ``ignore_index`` - parameter is availible. + Forward accepts - Args: - average: - Defines the reduction that is applied. Should be one of the following: - - - ``'micro'`` [default]: Calculate the metric globally, by counting the statistics - (tp, fp, tn, fn) accross all samples and classes. - - ``'macro'``: Calculate the metric for each class separately, and average the - metrics accross classes (with equal weights for each class). - - ``'weighted'``: Calculate the metric for each class separately, and average the - metrics accross classes, weighting each class by its support (``tp + fn``). - - ``'none'`` or ``None``: Calculate the metric for each class separately, and return - the metric for every class. - - ``'samples'``: Calculate the metric for each sample, and average the metrics - across samples (with equal weights for each sample). - - Note that what is considered a sample in the multi-dimensional multi-class case - depends on the value of ``mdmc_average``. - - mdmc_average: - Defines how averaging is done for multi-dimensional multi-class inputs (on top of the - ``average`` parameter). Should be one of the following: - - - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional - multi-class. - - - ``'samplewise'``: In this case, the statistics are computed separately for each - sample on the ``N`` axis, and then averaged over samples. - The computation for each sample is done by treating the flattened extra axes ``...`` - (see :ref:`metrics:Input types`) as the ``N`` dimension within the sample, - and computing the metric for the sample based on that. - - - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs (see :ref:`metrics:Input types`) - are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they - were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. - - num_classes: - Number of classes. Necessary for (multi-dimensional) multi-class or multi-label data. + - ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes + - ``target`` (long tensor): ``(N, ...)`` + + If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument. + This is the case for binary and multi-label logits. + + If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``. + Args: + num_classes: Number of classes in the dataset. + beta: Beta coefficient in the F measure. threshold: - Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 - is_multiclass: - If ``False``, treat multi-class and multi-dim multi-class inputs with 1 or 2 classes as - binary and multi-label, respectively. If ``True``, treat binary and multi-label inputs - as multi-class or multi-dim multi-class with 2 classes, respectively. - Defaults to ``None``, which treats inputs as they appear. - ignore_index: - Integer specifying a target class to ignore. If given, this class index does not contribute - to the returned score, regardless of reduction method. Has no effect if given an int that - is not in the range ``[0, C-1]``, or if ``C=1``, where ``C`` is the number of classes. - - If an index is ignored, and ``average=None`` or ``'none'``, the score for the ignored class - will be returned as ``nan`` (to not break the indexing of other labels). - zero_division: - Score to use for classes/samples, whose score has 0 in the denominator. Has to be either - 0 [default] or 1. + Threshold value for binary or multi-label logits. default: 0.5 + + average: + * `'micro'` computes metric globally + * `'macro'` computes metric for each class and then takes the mean + multilabel: If predictions are from multilabel classification. compute_on_step: Forward only calls ``update()`` and return None if this is set to False. default: True dist_sync_on_step: @@ -92,141 +60,89 @@ class Precision(StatScores): before returning the value at the step. default: False process_group: Specify the process group on which synchronization is called. default: None (which selects the entire world) - dist_sync_fn: - Callback that performs the allgather operation on the metric state. When `None`, DDP - will be used to perform the allgather. default: None Example: - >>> from pytorch_lightning.metrics.classification import Precision - >>> preds = torch.tensor([2, 0, 2, 1]) - >>> target = torch.tensor([1, 1, 2, 0]) - >>> precision = Precision(average='macro', num_classes=3) - >>> precision(preds, target) - tensor(0.1667) - >>> precision = Precision(average='micro') + >>> from pytorch_lightning.metrics import Precision + >>> target = torch.tensor([0, 1, 2, 0, 1, 2]) + >>> preds = torch.tensor([0, 2, 1, 0, 0, 1]) + >>> precision = Precision(num_classes=3) >>> precision(preds, target) - tensor(0.2500) + tensor(0.3333) """ - def __init__( self, - average: str = "micro", - mdmc_average: Optional[str] = None, + num_classes: int = 1, threshold: float = 0.5, - num_classes: Optional[int] = None, - is_multiclass: Optional[bool] = None, - ignore_index: Optional[int] = None, - zero_division: int = 0, + average: str = 'micro', + multilabel: bool = False, compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, - dist_sync_fn: Callable = None, ): super().__init__( - reduce="macro" if average in ["weighted", "none", None] else average, - mdmc_reduce=mdmc_average, - threshold=threshold, - num_classes=num_classes, - is_multiclass=is_multiclass, - ignore_index=ignore_index, compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group, - dist_sync_fn=dist_sync_fn, ) - if zero_division not in [0, 1]: - raise ValueError("zero_division has to be either 0 or 1") - - self.zero_division = zero_division + self.num_classes = num_classes + self.threshold = threshold self.average = average + self.multilabel = multilabel - def compute(self) -> torch.Tensor: - """ - Computes the precision score based on inputs passed in to ``update`` previously. + assert self.average in ('micro', 'macro'), \ + "average passed to the function must be either `micro` or `macro`" - Return: - The of the returned tensor depends on the ``average`` parameter + self.add_state("true_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum") + self.add_state("predicted_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum") - - If ``average in ['micro', 'macro', 'weighted', 'samples']``, a one-element tensor will be returned - - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number - of classes - """ - - return _precision_compute( - self.tp, self.fp, self.tn, self.fn, self.average, self.mdmc_reduce, self.zero_division + def update(self, preds: torch.Tensor, target: torch.Tensor): + preds, target = _input_format_classification_one_hot( + self.num_classes, preds, target, self.threshold, self.multilabel ) + # multiply because we are counting (1, 1) pair for true positives + self.true_positives += torch.sum(preds * target, dim=1) + self.predicted_positives += torch.sum(preds, dim=1) -class Recall(StatScores): - """Computes the recall score (the ratio ``tp / (tp + fn)``). + def compute(self): + if self.average == 'micro': + return self.true_positives.sum().float() / (self.predicted_positives.sum() + METRIC_EPS) + elif self.average == 'macro': + return (self.true_positives.float() / (self.predicted_positives + METRIC_EPS)).mean() - The reduction method (how the recall scores are aggregated) is controlled by the - ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the - multi-dimensional multi-class case. Accepts all inputs listed in :ref:`metrics:Input types`. - In case where you need to ignore a class in computing the score, an ``ignore_index`` - parameter is availible. +class Recall(Metric): + """ + Computes the recall metric. - Args: - average: - Defines the reduction that is applied. Should be one of the following: - - - ``'micro'`` [default]: Calculate the metric globally, by counting the statistics - (tp, fp, tn, fn) accross all samples and classes. - - ``'macro'``: Calculate the metric for each class separately, and average the - metrics accross classes (with equal weights for each class). - - ``'weighted'``: Calculate the metric for each class separately, and average the - metrics accross classes, weighting each class by its support (``tp + fn``). - - ``'none'`` or ``None``: Calculate the metric for each class separately, and return - the metric for every class. - - ``'samples'``: Calculate the metric for each sample, and average the metrics - across samples (with equal weights for each sample). - - Note that what is considered a sample in the multi-dimensional multi-class case - depends on the value of ``mdmc_average``. - - mdmc_average: - Defines how averaging is done for multi-dimensional multi-class inputs (on top of the - ``average`` parameter). Should be one of the following: - - - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional - multi-class. - - - ``'samplewise'``: In this case, the statistics are computed separately for each - sample on the ``N`` axis, and then averaged over samples. - The computation for each sample is done by treating the flattened extra axes ``...`` - (see :ref:`metrics:Input types`) as the ``N`` dimension within the sample, - and computing the metric for the sample based on that. - - - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs (see :ref:`metrics:Input types`) - are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they - were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. - - num_classes: - Number of classes. Necessary for (multi-dimensional) multi-class or multi-label data. + Works with binary, multiclass, and multilabel data. + Accepts logits from a model output or integer class values in prediction. + Works with multi-dimensional preds and target. + Forward accepts + + - ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes + - ``target`` (long tensor): ``(N, ...)`` + + If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument. + This is the case for binary and multi-label logits. + + If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``. + + Args: + num_classes: Number of classes in the dataset. + beta: Beta coefficient in the F measure. threshold: - Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 - is_multiclass: - If ``False``, treat multi-class and multi-dim multi-class inputs with 1 or 2 classes as - binary and multi-label, respectively. If ``True``, treat binary and multi-label inputs - as multi-class or multi-dim multi-class with 2 classes, respectively. - Defaults to ``None``, which treats inputs as they appear. - ignore_index: - Integer specifying a target class to ignore. If given, this class index does not contribute - to the returned score, regardless of reduction method. Has no effect if given an int that - is not in the range ``[0, C-1]``, or if ``C=1``, where ``C`` is the number of classes. - - If an index is ignored, and ``average=None`` or ``'none'``, the score for the ignored class - will be returned as ``nan`` (to not break the indexing of other labels). - zero_division: - Score to use for classes/samples, whose score has 0 in the denominator. Has to be either - 0 [default] or 1. + Threshold value for binary or multi-label logits. default: 0.5 + average: + * `'micro'` computes metric globally + * `'macro'` computes metric for each class and then takes the mean + + multilabel: If predictions are from multilabel classification. compute_on_step: Forward only calls ``update()`` and return None if this is set to False. default: True dist_sync_on_step: @@ -234,67 +150,65 @@ class Recall(StatScores): before returning the value at the step. default: False process_group: Specify the process group on which synchronization is called. default: None (which selects the entire world) - dist_sync_fn: - Callback that performs the allgather operation on the metric state. When `None`, DDP - will be used to perform the allgather. default: None Example: - >>> from pytorch_lightning.metrics.classification import Recall - >>> preds = torch.tensor([2, 0, 2, 1]) - >>> target = torch.tensor([1, 1, 2, 0]) - >>> recall = Recall(average='macro', num_classes=3) + >>> from pytorch_lightning.metrics import Recall + >>> target = torch.tensor([0, 1, 2, 0, 1, 2]) + >>> preds = torch.tensor([0, 2, 1, 0, 0, 1]) + >>> recall = Recall(num_classes=3) >>> recall(preds, target) tensor(0.3333) - >>> recall = Recall(average='micro') - >>> recall(preds, target) - tensor(0.2500) """ - def __init__( self, - average: str = "micro", - mdmc_average: Optional[str] = None, + num_classes: int = 1, threshold: float = 0.5, - num_classes: Optional[int] = None, - is_multiclass: Optional[bool] = None, - ignore_index: Optional[int] = None, - zero_division: int = 0, + average: str = 'micro', + multilabel: bool = False, compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, - dist_sync_fn: Callable = None, ): super().__init__( - reduce="macro" if average in ["weighted", "none", None] else average, - mdmc_reduce=mdmc_average, - threshold=threshold, - num_classes=num_classes, - is_multiclass=is_multiclass, - ignore_index=ignore_index, compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group, - dist_sync_fn=dist_sync_fn, ) - if zero_division not in [0, 1]: - raise ValueError("zero_division has to be either 0 or 1") - - self.zero_division = zero_division + self.num_classes = num_classes + self.threshold = threshold self.average = average + self.multilabel = multilabel - def compute(self) -> torch.Tensor: - """ - Computes the recall score based on inputs passed in to ``update`` previously. + assert self.average in ('micro', 'macro'), \ + "average passed to the function must be either `micro` or `macro`" - Return: - The of the returned tensor depends on the ``average`` parameter + self.add_state("true_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum") + self.add_state("actual_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum") + + def update(self, preds: torch.Tensor, target: torch.Tensor): + """ + Update state with predictions and targets. - - If ``average in ['micro', 'macro', 'weighted', 'samples']``, a one-element tensor will be returned - - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number - of classes + Args: + preds: Predictions from model + target: Ground truth values """ + preds, target = _input_format_classification_one_hot( + self.num_classes, preds, target, self.threshold, self.multilabel + ) + + # multiply because we are counting (1, 1) pair for true positives + self.true_positives += torch.sum(preds * target, dim=1) + self.actual_positives += torch.sum(target, dim=1) - return _recall_compute(self.tp, self.fp, self.tn, self.fn, self.average, self.mdmc_reduce, self.zero_division) + def compute(self): + """ + Computes accuracy over state. + """ + if self.average == 'micro': + return self.true_positives.sum().float() / (self.actual_positives.sum() + METRIC_EPS) + elif self.average == 'macro': + return (self.true_positives.float() / (self.actual_positives + METRIC_EPS)).mean() \ No newline at end of file diff --git a/pytorch_lightning/metrics/functional/__init__.py b/pytorch_lightning/metrics/functional/__init__.py index b3af66a129283..7582a5f3d38ac 100644 --- a/pytorch_lightning/metrics/functional/__init__.py +++ b/pytorch_lightning/metrics/functional/__init__.py @@ -19,7 +19,10 @@ multiclass_precision_recall_curve, multiclass_roc, multiclass_auroc, + precision, + precision_recall, precision_recall_curve, + recall, roc, to_categorical, to_onehot, @@ -41,4 +44,3 @@ from pytorch_lightning.metrics.functional.confusion_matrix import confusion_matrix from pytorch_lightning.metrics.functional.f_beta import fbeta, f1 from pytorch_lightning.metrics.functional.stat_scores import stat_scores -from pytorch_lightning.metrics.functional.precision_recall import precision, recall diff --git a/pytorch_lightning/metrics/functional/precision_recall.py b/pytorch_lightning/metrics/functional/precision_recall.py deleted file mode 100644 index e7fd2cca3fca8..0000000000000 --- a/pytorch_lightning/metrics/functional/precision_recall.py +++ /dev/null @@ -1,276 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Optional - -import torch -from pytorch_lightning.metrics.functional.reduction import _reduce_scores -from pytorch_lightning.metrics.functional.stat_scores import _stat_scores_update - - -def _precision_compute( - tp: torch.Tensor, - fp: torch.Tensor, - tn: torch.Tensor, - fn: torch.Tensor, - average: str, - mdmc_average: Optional[str], - zero_division: int, -) -> torch.Tensor: - return _reduce_scores( - numerator=tp, - denominator=tp + fp, - weights=tp + fn, - average=average, - mdmc_average=mdmc_average, - zero_division=zero_division, - ) - - -def precision( - preds: torch.Tensor, - target: torch.Tensor, - average: str = "micro", - mdmc_average: Optional[str] = None, - threshold: float = 0.5, - num_classes: Optional[int] = None, - is_multiclass: Optional[bool] = None, - ignore_index: Optional[int] = None, - zero_division: int = 0, -) -> torch.Tensor: - """Computes the precision score (the ratio ``tp / (tp + fp)``). - - The reduction method (how the precision scores are aggregated) is controlled by the - ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the - multi-dimensional multi-class case. Accepts all inputs listed in :ref:`metrics:Input types`. - - In case where you need to ignore a class in computing the score, anI ``ignore_index`` - parameter is availible. - - Args: - preds: Predictions from model (probabilities or labels) - target: Ground truth values - average: - Defines the reduction that is applied. Should be one of the following: - - - ``'micro'`` [default]: Calculate the metric globally, by counting the statistics - (tp, fp, tn, fn) accross all samples and classes. - - ``'macro'``: Calculate the metric for each class separately, and average the - metrics accross classes (with equal weights for each class). - - ``'weighted'``: Calculate the metric for each class separately, and average the - metrics accross classes, weighting each class by its support (``tp + fn``). - - ``'none'`` or ``None``: Calculate the metric for each class separately, and return - the metric for every class. - - ``'samples'``: Calculate the metric for each sample, and average the metrics - across samples (with equal weights for each sample). - - Note that what is considered a sample in the multi-dimensional multi-class case - depends on the value of ``mdmc_average``. - - mdmc_average: - Defines how averaging is done for multi-dimensional multi-class inputs (on top of the - ``average`` parameter). Should be one of the following: - - - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional - multi-class. - - - ``'samplewise'``: In this case, the statistics are computed separately for each - sample on the ``N`` axis, and then averaged over samples. - The computation for each sample is done by treating the flattened extra axes ``...`` - (see :ref:`metrics:Input types`) as the ``N`` dimension within the sample, - and computing the metric for the sample based on that. - - - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs (see :ref:`metrics:Input types`) - are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they - were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. - - num_classes: - Number of classes. Necessary for (multi-dimensional) multi-class or multi-label data. - - threshold: - Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 - is_multiclass: - If ``False``, treat multi-class and multi-dim multi-class inputs with 1 or 2 classes as - binary and multi-label, respectively. If ``True``, treat binary and multi-label inputs - as multi-class or multi-dim multi-class with 2 classes, respectively. - Defaults to ``None``, which treats inputs as they appear. - ignore_index: - Integer specifying a target class to ignore. If given, this class index does not contribute - to the returned score, regardless of reduction method. Has no effect if given an int that - is not in the range ``[0, C-1]``, or if ``C=1``, where ``C`` is the number of classes. - - If an index is ignored, and ``average=None`` or ``'none'``, the score for the ignored class - will be returned as ``nan`` (to not break the indexing of other labels). - zero_division: - Score to use for classes/samples, whose score has 0 in the denominator. Has to be either - 0 [default] or 1. - - Return: - The of the returned tensor depends on the ``average`` parameter - - - If ``average in ['micro', 'macro', 'weighted', 'samples']``, a one-element tensor will be returned - - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number - of classes - - Example: - - >>> from pytorch_lightning.metrics.functional import precision - >>> preds = torch.tensor([2, 0, 2, 1]) - >>> target = torch.tensor([1, 1, 2, 0]) - >>> precision(preds, target, average='macro', num_classes=3) - tensor(0.1667) - >>> precision(preds, target, average='micro') - tensor(0.2500) - - """ - - reduce = "macro" if average in ["weighted", "none", None] else average - - if zero_division not in [0, 1]: - raise ValueError("zero_division has to be either 0 or 1") - - tp, fp, tn, fn = _stat_scores_update( - preds, target, reduce, mdmc_average, threshold, num_classes, is_multiclass, ignore_index - ) - - return _precision_compute(tp, fp, tn, fn, average, mdmc_average, zero_division) - - -def _recall_compute( - tp: torch.Tensor, - fp: torch.Tensor, - tn: torch.Tensor, - fn: torch.Tensor, - average: str, - mdmc_average: Optional[str], - zero_division: int, -) -> torch.Tensor: - return _reduce_scores( - numerator=tp, - denominator=tp + fn, - weights=tp + fn, - average=average, - mdmc_average=mdmc_average, - zero_division=zero_division, - ) - - -def recall( - preds: torch.Tensor, - target: torch.Tensor, - average: str = "micro", - mdmc_average: Optional[str] = None, - threshold: float = 0.5, - num_classes: Optional[int] = None, - is_multiclass: Optional[bool] = None, - ignore_index: Optional[int] = None, - zero_division: int = 0, -) -> torch.Tensor: - """Computes the recall score (the ratio ``tp / (tp + fn)``). - - The reduction method (how the recall scores are aggregated) is controlled by the - ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the - multi-dimensional multi-class case. Accepts all inputs listed in :ref:`metrics:Input types`. - - In case where you need to ignore a class in computing the score, an ``ignore_index`` - parameter is availible. - - Args: - preds: Predictions from model (probabilities, or labels) - target: Ground truth values - average: - Defines the reduction that is applied. Should be one of the following: - - - ``'micro'`` [default]: Calculate the metric globally, by counting the statistics - (tp, fp, tn, fn) accross all samples and classes. - - ``'macro'``: Calculate the metric for each class separately, and average the - metrics accross classes (with equal weights for each class). - - ``'weighted'``: Calculate the metric for each class separately, and average the - metrics accross classes, weighting each class by its support (``tp + fn``). - - ``'none'`` or ``None``: Calculate the metric for each class separately, and return - the metric for every class. - - ``'samples'``: Calculate the metric for each sample, and average the metrics - across samples (with equal weights for each sample). - - Note that what is considered a sample in the multi-dimensional multi-class case - depends on the value of ``mdmc_average``. - - mdmc_average: - Defines how averaging is done for multi-dimensional multi-class inputs (on top of the - ``average`` parameter). Should be one of the following: - - - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional - multi-class. - - - ``'samplewise'``: In this case, the statistics are computed separately for each - sample on the ``N`` axis, and then averaged over samples. - The computation for each sample is done by treating the flattened extra axes ``...`` - (see :ref:`metrics:Input types`) as the ``N`` dimension within the sample, - and computing the metric for the sample based on that. - - - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs (see :ref:`metrics:Input types`) - are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they - were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. - - num_classes: - Number of classes. Necessary for (multi-dimensional) multi-class or multi-label data. - - threshold: - Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 - is_multiclass: - If ``False``, treat multi-class and multi-dim multi-class inputs with 1 or 2 classes as - binary and multi-label, respectively. If ``True``, treat binary and multi-label inputs - as multi-class or multi-dim multi-class with 2 classes, respectively. - Defaults to ``None``, which treats inputs as they appear. - ignore_index: - Integer specifying a target class to ignore. If given, this class index does not contribute - to the returned score, regardless of reduction method. Has no effect if given an int that - is not in the range ``[0, C-1]``, or if ``C=1``, where ``C`` is the number of classes. - - If an index is ignored, and ``average=None`` or ``'none'``, the score for the ignored class - will be returned as ``nan`` (to not break the indexing of other labels). - zero_division: - Score to use for classes/samples, whose score has 0 in the denominator. Has to be either - 0 [default] or 1. - - Return: - The of the returned tensor depends on the ``average`` parameter - - - If ``average in ['micro', 'macro', 'weighted', 'samples']``, a one-element tensor will be returned - - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number - of classes - - Example: - - >>> from pytorch_lightning.metrics.functional import recall - >>> preds = torch.tensor([2, 0, 2, 1]) - >>> target = torch.tensor([1, 1, 2, 0]) - >>> recall(preds, target, average='macro', num_classes=3) - tensor(0.3333) - >>> recall(preds, target, average='micro') - tensor(0.2500) - - """ - - reduce = "macro" if average in ["weighted", "none", None] else average - - if zero_division not in [0, 1]: - raise ValueError("zero_division has to be either 0 or 1") - - tp, fp, tn, fn = _stat_scores_update( - preds, target, reduce, mdmc_average, threshold, num_classes, is_multiclass, ignore_index - ) - - return _recall_compute(tp, fp, tn, fn, average, mdmc_average, zero_division) diff --git a/tests/metrics/classification/test_precision_recall.py b/tests/metrics/classification/test_precision_recall.py index f43a7f41f0bec..8ee114e2e1ae0 100644 --- a/tests/metrics/classification/test_precision_recall.py +++ b/tests/metrics/classification/test_precision_recall.py @@ -5,219 +5,126 @@ import torch from sklearn.metrics import precision_score, recall_score -from pytorch_lightning.metrics.classification.utils import _input_format_classification from pytorch_lightning.metrics import Precision, Recall -from pytorch_lightning.metrics.functional import precision, recall from tests.metrics.classification.inputs import ( _binary_inputs, _binary_prob_inputs, _multiclass_inputs, - _multiclass_prob_inputs as _mc_prob, - _multidim_multiclass_inputs as _mdmc, - _multidim_multiclass_prob_inputs as _mdmc_prob, - _multilabel_inputs as _ml, - _multilabel_prob_inputs as _ml_prob, - _multilabel_multidim_prob_inputs as _mlmd_prob, - _multilabel_multidim_inputs as _mlmd, + _multiclass_prob_inputs, + _multidim_multiclass_inputs, + _multidim_multiclass_prob_inputs, + _multilabel_inputs, + _multilabel_prob_inputs, ) -from tests.metrics.utils import EXTRA_DIM, NUM_CLASSES, THRESHOLD, MetricTester +from tests.metrics.utils import NUM_CLASSES, THRESHOLD, MetricTester torch.manual_seed(42) -def _sk_prec_recall( - preds, target, sk_fn, num_classes, average, is_multiclass, zero_division, ignore_index, mdmc_average=None -): - if average == "none": - average = None - if num_classes == 1: - average = "binary" +def _sk_prec_recall_binary_prob(preds, target, sk_fn=precision_score, average='micro'): + sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) + sk_target = target.view(-1).numpy() - labels = list(range(num_classes)) - try: - labels.remove(ignore_index) - except ValueError: - pass + return sk_fn(y_true=sk_target, y_pred=sk_preds, average='binary') - sk_preds, sk_target, _ = _input_format_classification( - preds, target, THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass - ) - sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy() - sk_scores = sk_fn(sk_target, sk_preds, average=average, zero_division=zero_division, labels=labels) +def _sk_prec_recall_binary(preds, target, sk_fn=precision_score, average='micro'): + sk_preds = preds.view(-1).numpy() + sk_target = target.view(-1).numpy() - if len(labels) != num_classes and not average: - sk_scores = np.insert(sk_scores, ignore_index, np.nan) + return sk_fn(y_true=sk_target, y_pred=sk_preds, average='binary') - return sk_scores +def _sk_prec_recall_multilabel_prob(preds, target, sk_fn=precision_score, average='micro'): + sk_preds = (preds.view(-1, NUM_CLASSES).numpy() >= THRESHOLD).astype(np.uint8) + sk_target = target.view(-1, NUM_CLASSES).numpy() -def _sk_prec_recall_mdmc( - preds, target, sk_fn, num_classes, average, is_multiclass, zero_division, ignore_index, mdmc_average -): - preds, target, _ = _input_format_classification( - preds, target, threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass - ) + return sk_fn(y_true=sk_target, y_pred=sk_preds, average=average) - if mdmc_average == "global": - preds = torch.movedim(preds, 1, -1).reshape(-1, preds.shape[1]) - target = torch.movedim(target, 1, -1).reshape(-1, target.shape[1]) - return _sk_prec_recall(preds, target, sk_fn, num_classes, average, False, zero_division, ignore_index) - else: # mdmc_average == "samplewise" - scores = [] +def _sk_prec_recall_multilabel(preds, target, sk_fn=precision_score, average='micro'): + sk_preds = preds.view(-1, NUM_CLASSES).numpy() + sk_target = target.view(-1, NUM_CLASSES).numpy() - for i in range(preds.shape[0]): - pred_i = preds[i, ...].T - target_i = target[i, ...].T - scores_i = _sk_prec_recall( - pred_i, target_i, sk_fn, num_classes, average, False, zero_division, ignore_index - ) + return sk_fn(y_true=sk_target, y_pred=sk_preds, average=average) - scores.append(np.expand_dims(scores_i, 0)) - return np.concatenate(scores).mean() +def _sk_prec_recall_multiclass_prob(preds, target, sk_fn=precision_score, average='micro'): + sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy() + sk_target = target.view(-1).numpy() + return sk_fn(y_true=sk_target, y_pred=sk_preds, average=average) -@pytest.mark.parametrize("metric, fn_metric", [(Precision, precision), (Recall, recall)]) -def test_wrong_params(metric, fn_metric): - with pytest.raises(ValueError): - metric(zero_division=None) - with pytest.raises(ValueError): - fn_metric(_binary_inputs.preds[0], _binary_inputs.target[0], zero_division=None) +def _sk_prec_recall_multiclass(preds, target, sk_fn=precision_score, average='micro'): + sk_preds = preds.view(-1).numpy() + sk_target = target.view(-1).numpy() + return sk_fn(y_true=sk_target, y_pred=sk_preds, average=average) -###################################################################################### -# Testing for MDMC inputs is partially skipped, because some cases appear where -# (with mdmc_average1 =! None, ignore_index=1, average='weighted') a sample in -# target contains only labels "1" - and as we are ignoring this index, weights of -# all labels will be zero. In this special edge case, sklearn handles the situation -# differently for each metric (recall, precision, fscore), which breaks ours handling -# everything in _reduce_scores (where the return value is 0 in this situation). -###################################################################################### +def _sk_prec_recall_multidim_multiclass_prob(preds, target, sk_fn=precision_score, average='micro'): + sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy() + sk_target = target.view(-1).numpy() + return sk_fn(y_true=sk_target, y_pred=sk_preds, average=average) + + +def _sk_prec_recall_multidim_multiclass(preds, target, sk_fn=precision_score, average='micro'): + sk_preds = preds.view(-1).numpy() + sk_target = target.view(-1).numpy() + + return sk_fn(y_true=sk_target, y_pred=sk_preds, average=average) + + +@pytest.mark.parametrize("ddp", [True, False]) +@pytest.mark.parametrize("dist_sync_on_step", [True, False]) +@pytest.mark.parametrize("average", ['micro', 'macro']) @pytest.mark.parametrize( - "metric_class, sk_fn, metric_fn", [(Precision, precision_score, precision), (Recall, recall_score, recall)] -) -@pytest.mark.parametrize("average", ["micro", "macro", None, "weighted", "samples"]) -@pytest.mark.parametrize("zero_division", [0, 1]) -@pytest.mark.parametrize("ignore_index", [None, 1]) -@pytest.mark.parametrize( - "preds, target, num_classes, is_multiclass, mdmc_average, sk_wrapper", + "preds, target, sk_metric, num_classes, multilabel", [ - (_binary_prob_inputs.preds, _binary_prob_inputs.target, 1, None, None, _sk_prec_recall), - (_binary_inputs.preds, _binary_inputs.target, 1, False, None, _sk_prec_recall), - (_ml_prob.preds, _ml_prob.target, NUM_CLASSES, None, None, _sk_prec_recall), - (_ml.preds, _ml.target, NUM_CLASSES, False, None, _sk_prec_recall), - (_mc_prob.preds, _mc_prob.target, NUM_CLASSES, None, None, _sk_prec_recall), - (_multiclass_inputs.preds, _multiclass_inputs.target, NUM_CLASSES, None, None, _sk_prec_recall), - (_mlmd_prob.preds, _mlmd_prob.target, EXTRA_DIM * NUM_CLASSES, None, None, _sk_prec_recall), - (_mlmd.preds, _mlmd.target, EXTRA_DIM * NUM_CLASSES, False, None, _sk_prec_recall), - (_mdmc.preds, _mdmc.target, NUM_CLASSES, None, "global", _sk_prec_recall_mdmc), - (_mdmc_prob.preds, _mdmc_prob.target, NUM_CLASSES, None, "global", _sk_prec_recall_mdmc), - (_mdmc.preds, _mdmc.target, NUM_CLASSES, None, "samplewise", _sk_prec_recall_mdmc), - (_mdmc_prob.preds, _mdmc_prob.target, NUM_CLASSES, None, "samplewise", _sk_prec_recall_mdmc), + (_binary_prob_inputs.preds, _binary_prob_inputs.target, _sk_prec_recall_binary_prob, 1, False), + (_binary_inputs.preds, _binary_inputs.target, _sk_prec_recall_binary, 1, False), + (_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target, _sk_prec_recall_multilabel_prob, NUM_CLASSES, True), + (_multilabel_inputs.preds, _multilabel_inputs.target, _sk_prec_recall_multilabel, NUM_CLASSES, True), + (_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target, _sk_prec_recall_multiclass_prob, NUM_CLASSES, False), + (_multiclass_inputs.preds, _multiclass_inputs.target, _sk_prec_recall_multiclass, NUM_CLASSES, False), + ( + _multidim_multiclass_prob_inputs.preds, + _multidim_multiclass_prob_inputs.target, + _sk_prec_recall_multidim_multiclass_prob, + NUM_CLASSES, + False, + ), + ( + _multidim_multiclass_inputs.preds, + _multidim_multiclass_inputs.target, + _sk_prec_recall_multidim_multiclass, + NUM_CLASSES, + False, + ), ], ) +@pytest.mark.parametrize( + "metric_class, sk_fn", [(Precision, precision_score), (Recall, recall_score)], +) class TestPrecisionRecall(MetricTester): - @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_precision_recall_class( - self, - ddp, - dist_sync_on_step, - preds, - target, - sk_wrapper, - metric_class, - metric_fn, - sk_fn, - is_multiclass, - num_classes, - average, - mdmc_average, - zero_division, - ignore_index, + def test_precision_recall( + self, ddp, dist_sync_on_step, preds, target, sk_metric, metric_class, sk_fn, num_classes, multilabel, average ): - if num_classes == 1 and average != "micro": - pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)") - - if average == "weighted" and ignore_index is not None and mdmc_average is not None: - pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average") - self.run_class_metric_test( ddp=ddp, preds=preds, target=target, metric_class=metric_class, - sk_metric=partial( - sk_wrapper, - sk_fn=sk_fn, - average=average, - num_classes=num_classes, - is_multiclass=is_multiclass, - zero_division=zero_division, - ignore_index=ignore_index, - mdmc_average=mdmc_average, - ), + sk_metric=partial(sk_metric, sk_fn=sk_fn, average=average), dist_sync_on_step=dist_sync_on_step, metric_args={ "num_classes": num_classes, "average": average, + "multilabel": multilabel, "threshold": THRESHOLD, - "is_multiclass": is_multiclass, - "zero_division": zero_division, - "ignore_index": ignore_index, - "mdmc_average": mdmc_average, - }, - check_dist_sync_on_step=True, - check_batch=True, - ) - - def test_precision_recall_fn( - self, - preds, - target, - sk_wrapper, - metric_class, - metric_fn, - sk_fn, - is_multiclass, - num_classes, - average, - mdmc_average, - zero_division, - ignore_index, - ): - if num_classes == 1 and average != "micro": - pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)") - - if average == "weighted" and ignore_index is not None and mdmc_average is not None: - pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average") - - self.run_functional_metric_test( - preds, - target, - metric_functional=metric_fn, - sk_metric=partial( - sk_wrapper, - sk_fn=sk_fn, - average=average, - num_classes=num_classes, - is_multiclass=is_multiclass, - zero_division=zero_division, - ignore_index=ignore_index, - mdmc_average=mdmc_average, - ), - metric_args={ - "num_classes": num_classes, - "average": average, - "threshold": THRESHOLD, - "is_multiclass": is_multiclass, - "zero_division": zero_division, - "ignore_index": ignore_index, - "mdmc_average": mdmc_average, }, - ) + check_dist_sync_on_step=False if average == 'macro' else True, + check_batch=False if average == 'macro' else True, + ) \ No newline at end of file From 3b702701dfa26c6faad60f3cc4721d8c8121c941 Mon Sep 17 00:00:00 2001 From: Tadej Date: Wed, 25 Nov 2020 00:54:27 +0100 Subject: [PATCH 16/61] Replace movedim with permute --- pytorch_lightning/metrics/functional/stat_scores.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/metrics/functional/stat_scores.py b/pytorch_lightning/metrics/functional/stat_scores.py index 63fcd57ff9bf7..1f5e3b50bdf7a 100644 --- a/pytorch_lightning/metrics/functional/stat_scores.py +++ b/pytorch_lightning/metrics/functional/stat_scores.py @@ -104,8 +104,12 @@ def _stat_scores_update( "you have to set mdmc_reduce to either 'samplewise' or 'global'" ) if mdmc_reduce == "global": - preds = torch.movedim(preds, 1, -1).reshape(-1, preds.shape[1]) - target = torch.movedim(target, 1, -1).reshape(-1, target.shape[1]) + shape_permute = list(range(preds.ndim)) + shape_permute[1] = shape_permute[-1] + shape_permute[2:] = range(1, len(shape_permute) - 1) + + preds = torch.permute(*shape_permute).reshape(-1, preds.shape[1]) + target = torch.permute(*shape_permute).reshape(-1, target.shape[1]) # Delete what is in ignore_index, if applicable (and classes don't matter): if ignore_index and reduce in ["micro", "samples"] and preds.shape[1] > 1: From f1ae7b2af062a7608ae61c0c9c1f4792491ef1ad Mon Sep 17 00:00:00 2001 From: Tadej Date: Wed, 25 Nov 2020 00:59:01 +0100 Subject: [PATCH 17/61] Add back tests --- .../classification/test_stat_scores.py | 209 ++++++++++++++++++ 1 file changed, 209 insertions(+) diff --git a/tests/metrics/classification/test_stat_scores.py b/tests/metrics/classification/test_stat_scores.py index e69de29bb2d1d..f606b8a8bf8e5 100644 --- a/tests/metrics/classification/test_stat_scores.py +++ b/tests/metrics/classification/test_stat_scores.py @@ -0,0 +1,209 @@ +from functools import partial + +import numpy as np +import pytest +import torch +from sklearn.metrics import multilabel_confusion_matrix + +from pytorch_lightning.metrics.classification.utils import _input_format_classification +from pytorch_lightning.metrics import StatScores +from pytorch_lightning.metrics.functional import stat_scores +from tests.metrics.classification.inputs import ( + _binary_inputs, + _binary_prob_inputs, + _multiclass_inputs, + _multiclass_prob_inputs as _mc_prob, + _multilabel_inputs, + _multilabel_prob_inputs as _ml_prob, + _multilabel_multidim_prob_inputs as _mlmd_prob, + _multilabel_multidim_inputs as _mlmd, + _multidim_multiclass_inputs as _mdmc, + _multidim_multiclass_prob_inputs as _mdmc_prob, +) +from tests.metrics.utils import NUM_CLASSES, THRESHOLD, EXTRA_DIM, MetricTester + +torch.manual_seed(42) + + +def _sk_stat_scores(preds, target, reduce, num_classes, is_multiclass, ignore_index, mdmc_reduce=None): + preds, target, _ = _input_format_classification( + preds, target, threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass + ) + sk_preds, sk_target = preds.numpy(), target.numpy() + + if reduce != "macro" and ignore_index: + if preds.shape[1] > 1 and 0 <= ignore_index < preds.shape[1]: + sk_preds = np.delete(sk_preds, ignore_index, 1) + sk_target = np.delete(sk_target, ignore_index, 1) + + if preds.shape[1] == 1 and reduce == "samples": + sk_target = sk_target.T + sk_preds = sk_preds.T + + samplewise = reduce == "samples" and preds.shape[1] != 1 + sk_stats = multilabel_confusion_matrix(sk_target, sk_preds, samplewise=samplewise) + + if preds.shape[1] == 1 and reduce != "samples": + sk_stats = sk_stats[[1]].reshape(-1, 4)[:, [3, 1, 0, 2]] + else: + sk_stats = sk_stats.reshape(-1, 4)[:, [3, 1, 0, 2]] + + if reduce == "micro": + sk_stats = sk_stats.sum(axis=0, keepdims=True) + + sk_stats = np.concatenate([sk_stats, sk_stats[:, [3]] + sk_stats[:, [0]]], 1) + + if reduce == "micro": + sk_stats = sk_stats[0] + + if reduce == "macro" and ignore_index: + if preds.shape[1] > 1 and 0 <= ignore_index < preds.shape[1]: + sk_stats[ignore_index, :] = -1 + + return sk_stats + + +def _sk_stat_scores_mdmc(preds, target, reduce, mdmc_reduce, num_classes, is_multiclass, ignore_index): + preds, target, _ = _input_format_classification( + preds, target, threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass + ) + + if mdmc_reduce == "global": + shape_permute = list(range(preds.ndim)) + shape_permute[1] = shape_permute[-1] + shape_permute[2:] = range(1, len(shape_permute) - 1) + + preds = torch.permute(*shape_permute).reshape(-1, preds.shape[1]) + target = torch.permute(*shape_permute).reshape(-1, target.shape[1]) + + return _sk_stat_scores(preds, target, reduce, None, False, ignore_index) + else: # mdmc_reduce == "samplewise" + scores = [] + + for i in range(preds.shape[0]): + pred_i = preds[i, ...].T + target_i = target[i, ...].T + scores_i = _sk_stat_scores(pred_i, target_i, reduce, None, False, ignore_index) + + scores.append(np.expand_dims(scores_i, 0)) + + return np.concatenate(scores) + + +@pytest.mark.parametrize( + "reduce, mdmc_reduce, num_classes, inputs", + [ + ["unknown", None, None, _binary_inputs], + ["micro", "unknown", None, _binary_inputs], + ["macro", None, None, _binary_inputs], + ["micro", None, None, _mdmc_prob], + ], +) +def test_wrong_params(reduce, mdmc_reduce, num_classes, inputs): + with pytest.raises(ValueError): + stat_scores( + inputs.preds[0], + inputs.target[0], + reduce, + mdmc_reduce, + num_classes=num_classes, + ) + + with pytest.raises(ValueError): + sts = StatScores(reduce=reduce, mdmc_reduce=mdmc_reduce, num_classes=num_classes) + sts(inputs.preds[0], inputs.target[0]) + + +@pytest.mark.parametrize("reduce", ["micro", "macro", "samples"]) +@pytest.mark.parametrize("ignore_index", [None, 1]) +@pytest.mark.parametrize( + "preds, target, sk_fn, mdmc_reduce, num_classes, is_multiclass", + [ + (_binary_prob_inputs.preds, _binary_prob_inputs.target, _sk_stat_scores, None, 1, None), + (_binary_inputs.preds, _binary_inputs.target, _sk_stat_scores, None, 1, False), + (_ml_prob.preds, _ml_prob.target, _sk_stat_scores, None, NUM_CLASSES, None), + (_multilabel_inputs.preds, _multilabel_inputs.target, _sk_stat_scores, None, NUM_CLASSES, False), + (_mc_prob.preds, _mc_prob.target, _sk_stat_scores, None, NUM_CLASSES, None), + (_multiclass_inputs.preds, _multiclass_inputs.target, _sk_stat_scores, None, NUM_CLASSES, None), + (_mlmd_prob.preds, _mlmd_prob.target, _sk_stat_scores, None, EXTRA_DIM * NUM_CLASSES, None), + (_mlmd.preds, _mlmd.target, _sk_stat_scores, None, EXTRA_DIM * NUM_CLASSES, False), + (_mdmc.preds, _mdmc.target, _sk_stat_scores_mdmc, "samplewise", NUM_CLASSES, None), + (_mdmc_prob.preds, _mdmc_prob.target, _sk_stat_scores_mdmc, "samplewise", NUM_CLASSES, None), + (_mdmc.preds, _mdmc.target, _sk_stat_scores_mdmc, "global", NUM_CLASSES, None), + (_mdmc_prob.preds, _mdmc_prob.target, _sk_stat_scores_mdmc, "global", NUM_CLASSES, None), + ], +) +class TestStatScores(MetricTester): + @pytest.mark.parametrize("ddp", [False, True]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + def test_stat_scores_class( + self, + ddp, + dist_sync_on_step, + sk_fn, + preds, + target, + reduce, + mdmc_reduce, + num_classes, + is_multiclass, + ignore_index, + ): + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=StatScores, + sk_metric=partial( + sk_fn, + reduce=reduce, + mdmc_reduce=mdmc_reduce, + num_classes=num_classes, + is_multiclass=is_multiclass, + ignore_index=ignore_index, + ), + dist_sync_on_step=dist_sync_on_step, + metric_args={ + "num_classes": num_classes, + "reduce": reduce, + "mdmc_reduce": mdmc_reduce, + "threshold": THRESHOLD, + "is_multiclass": is_multiclass, + "ignore_index": ignore_index, + }, + check_dist_sync_on_step=True, + check_batch=True, + ) + + def test_stat_scores_fn( + self, + sk_fn, + preds, + target, + reduce, + mdmc_reduce, + num_classes, + is_multiclass, + ignore_index, + ): + self.run_functional_metric_test( + preds, + target, + metric_functional=stat_scores, + sk_metric=partial( + sk_fn, + reduce=reduce, + mdmc_reduce=mdmc_reduce, + num_classes=num_classes, + is_multiclass=is_multiclass, + ignore_index=ignore_index, + ), + metric_args={ + "num_classes": num_classes, + "reduce": reduce, + "mdmc_reduce": mdmc_reduce, + "threshold": THRESHOLD, + "is_multiclass": is_multiclass, + "ignore_index": ignore_index, + }, + ) \ No newline at end of file From 04a5066573793328d526ee62b28a4e7b711f7351 Mon Sep 17 00:00:00 2001 From: Tadej Date: Wed, 25 Nov 2020 01:00:09 +0100 Subject: [PATCH 18/61] Add empty newlines --- pytorch_lightning/metrics/classification/precision_recall.py | 2 +- tests/metrics/classification/test_precision_recall.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/metrics/classification/precision_recall.py b/pytorch_lightning/metrics/classification/precision_recall.py index 9d666eabf53b8..40002af96f5f6 100644 --- a/pytorch_lightning/metrics/classification/precision_recall.py +++ b/pytorch_lightning/metrics/classification/precision_recall.py @@ -211,4 +211,4 @@ def compute(self): if self.average == 'micro': return self.true_positives.sum().float() / (self.actual_positives.sum() + METRIC_EPS) elif self.average == 'macro': - return (self.true_positives.float() / (self.actual_positives + METRIC_EPS)).mean() \ No newline at end of file + return (self.true_positives.float() / (self.actual_positives + METRIC_EPS)).mean() diff --git a/tests/metrics/classification/test_precision_recall.py b/tests/metrics/classification/test_precision_recall.py index 8ee114e2e1ae0..967bc60e28307 100644 --- a/tests/metrics/classification/test_precision_recall.py +++ b/tests/metrics/classification/test_precision_recall.py @@ -127,4 +127,4 @@ def test_precision_recall( }, check_dist_sync_on_step=False if average == 'macro' else True, check_batch=False if average == 'macro' else True, - ) \ No newline at end of file + ) From 203306303ba559f7ec69ae59189c20d9af8bb560 Mon Sep 17 00:00:00 2001 From: Tadej Date: Wed, 25 Nov 2020 01:06:21 +0100 Subject: [PATCH 19/61] Add precision recall back --- docs/source/metrics.rst | 11 +- .../classification/precision_recall.py | 324 +++++++++++------- .../metrics/functional/__init__.py | 4 +- .../metrics/functional/precision_recall.py | 276 +++++++++++++++ .../classification/test_precision_recall.py | 245 +++++++++---- 5 files changed, 653 insertions(+), 207 deletions(-) create mode 100644 pytorch_lightning/metrics/functional/precision_recall.py diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index e831bda41647c..40a2f914d981f 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -448,14 +448,7 @@ multiclass_roc [func] precision [func] ~~~~~~~~~~~~~~~~ -.. autofunction:: pytorch_lightning.metrics.functional.classification.precision - :noindex: - - -precision_recall [func] -~~~~~~~~~~~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.classification.precision_recall +.. autofunction:: pytorch_lightning.metrics.functional.precision :noindex: @@ -469,7 +462,7 @@ precision_recall_curve [func] recall [func] ~~~~~~~~~~~~~ -.. autofunction:: pytorch_lightning.metrics.functional.classification.recall +.. autofunction:: pytorch_lightning.metrics.functional.recall :noindex: diff --git a/pytorch_lightning/metrics/classification/precision_recall.py b/pytorch_lightning/metrics/classification/precision_recall.py index 40002af96f5f6..110e4c2614466 100644 --- a/pytorch_lightning/metrics/classification/precision_recall.py +++ b/pytorch_lightning/metrics/classification/precision_recall.py @@ -11,48 +11,80 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import math -import functools -from abc import ABC, abstractmethod -from typing import Any, Callable, Optional, Union -from collections.abc import Mapping, Sequence -from collections import namedtuple +from typing import Optional, Any, Callable import torch -from torch import nn -from pytorch_lightning.metrics.metric import Metric -from pytorch_lightning.metrics.utils import to_onehot, METRIC_EPS, _input_format_classification_one_hot +from pytorch_lightning.metrics.classification.stat_scores import StatScores +from pytorch_lightning.metrics.functional.precision_recall import _precision_compute, _recall_compute -class Precision(Metric): - """ - Computes the precision metric. - - Works with binary, multiclass, and multilabel data. - Accepts logits from a model output or integer class values in prediction. - Works with multi-dimensional preds and target. - - Forward accepts +class Precision(StatScores): + """Computes the precision score (the ratio ``tp / (tp + fp)``). - - ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes - - ``target`` (long tensor): ``(N, ...)`` + The reduction method (how the precision scores are aggregated) is controlled by the + ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the + multi-dimensional multi-class case. Accepts all inputs listed in :ref:`metrics:Input types`. - If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument. - This is the case for binary and multi-label logits. - - If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``. + In case where you need to ignore a class in computing the score, anI ``ignore_index`` + parameter is availible. Args: - num_classes: Number of classes in the dataset. - beta: Beta coefficient in the F measure. - threshold: - Threshold value for binary or multi-label logits. default: 0.5 - average: - * `'micro'` computes metric globally - * `'macro'` computes metric for each class and then takes the mean + Defines the reduction that is applied. Should be one of the following: + + - ``'micro'`` [default]: Calculate the metric globally, by counting the statistics + (tp, fp, tn, fn) accross all samples and classes. + - ``'macro'``: Calculate the metric for each class separately, and average the + metrics accross classes (with equal weights for each class). + - ``'weighted'``: Calculate the metric for each class separately, and average the + metrics accross classes, weighting each class by its support (``tp + fn``). + - ``'none'`` or ``None``: Calculate the metric for each class separately, and return + the metric for every class. + - ``'samples'``: Calculate the metric for each sample, and average the metrics + across samples (with equal weights for each sample). + + Note that what is considered a sample in the multi-dimensional multi-class case + depends on the value of ``mdmc_average``. + + mdmc_average: + Defines how averaging is done for multi-dimensional multi-class inputs (on top of the + ``average`` parameter). Should be one of the following: + + - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional + multi-class. + + - ``'samplewise'``: In this case, the statistics are computed separately for each + sample on the ``N`` axis, and then averaged over samples. + The computation for each sample is done by treating the flattened extra axes ``...`` + (see :ref:`metrics:Input types`) as the ``N`` dimension within the sample, + and computing the metric for the sample based on that. + + - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs (see :ref:`metrics:Input types`) + are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they + were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. + + num_classes: + Number of classes. Necessary for (multi-dimensional) multi-class or multi-label data. + + threshold: + Threshold probability value for transforming probability predictions to binary + (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + is_multiclass: + If ``False``, treat multi-class and multi-dim multi-class inputs with 1 or 2 classes as + binary and multi-label, respectively. If ``True``, treat binary and multi-label inputs + as multi-class or multi-dim multi-class with 2 classes, respectively. + Defaults to ``None``, which treats inputs as they appear. + ignore_index: + Integer specifying a target class to ignore. If given, this class index does not contribute + to the returned score, regardless of reduction method. Has no effect if given an int that + is not in the range ``[0, C-1]``, or if ``C=1``, where ``C`` is the number of classes. + + If an index is ignored, and ``average=None`` or ``'none'``, the score for the ignored class + will be returned as ``nan`` (to not break the indexing of other labels). + zero_division: + Score to use for classes/samples, whose score has 0 in the denominator. Has to be either + 0 [default] or 1. - multilabel: If predictions are from multilabel classification. compute_on_step: Forward only calls ``update()`` and return None if this is set to False. default: True dist_sync_on_step: @@ -60,89 +92,141 @@ class Precision(Metric): before returning the value at the step. default: False process_group: Specify the process group on which synchronization is called. default: None (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. When `None`, DDP + will be used to perform the allgather. default: None Example: - >>> from pytorch_lightning.metrics import Precision - >>> target = torch.tensor([0, 1, 2, 0, 1, 2]) - >>> preds = torch.tensor([0, 2, 1, 0, 0, 1]) - >>> precision = Precision(num_classes=3) + >>> from pytorch_lightning.metrics.classification import Precision + >>> preds = torch.tensor([2, 0, 2, 1]) + >>> target = torch.tensor([1, 1, 2, 0]) + >>> precision = Precision(average='macro', num_classes=3) >>> precision(preds, target) - tensor(0.3333) + tensor(0.1667) + >>> precision = Precision(average='micro') + >>> precision(preds, target) + tensor(0.2500) """ + def __init__( self, - num_classes: int = 1, + average: str = "micro", + mdmc_average: Optional[str] = None, threshold: float = 0.5, - average: str = 'micro', - multilabel: bool = False, + num_classes: Optional[int] = None, + is_multiclass: Optional[bool] = None, + ignore_index: Optional[int] = None, + zero_division: int = 0, compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, ): super().__init__( + reduce="macro" if average in ["weighted", "none", None] else average, + mdmc_reduce=mdmc_average, + threshold=threshold, + num_classes=num_classes, + is_multiclass=is_multiclass, + ignore_index=ignore_index, compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group, + dist_sync_fn=dist_sync_fn, ) - self.num_classes = num_classes - self.threshold = threshold - self.average = average - self.multilabel = multilabel - - assert self.average in ('micro', 'macro'), \ - "average passed to the function must be either `micro` or `macro`" - - self.add_state("true_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum") - self.add_state("predicted_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum") + if zero_division not in [0, 1]: + raise ValueError("zero_division has to be either 0 or 1") - def update(self, preds: torch.Tensor, target: torch.Tensor): - preds, target = _input_format_classification_one_hot( - self.num_classes, preds, target, self.threshold, self.multilabel - ) - - # multiply because we are counting (1, 1) pair for true positives - self.true_positives += torch.sum(preds * target, dim=1) - self.predicted_positives += torch.sum(preds, dim=1) + self.zero_division = zero_division + self.average = average - def compute(self): - if self.average == 'micro': - return self.true_positives.sum().float() / (self.predicted_positives.sum() + METRIC_EPS) - elif self.average == 'macro': - return (self.true_positives.float() / (self.predicted_positives + METRIC_EPS)).mean() + def compute(self) -> torch.Tensor: + """ + Computes the precision score based on inputs passed in to ``update`` previously. + Return: + The of the returned tensor depends on the ``average`` parameter -class Recall(Metric): - """ - Computes the recall metric. + - If ``average in ['micro', 'macro', 'weighted', 'samples']``, a one-element tensor will be returned + - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number + of classes + """ - Works with binary, multiclass, and multilabel data. - Accepts logits from a model output or integer class values in prediction. - Works with multi-dimensional preds and target. + return _precision_compute( + self.tp, self.fp, self.tn, self.fn, self.average, self.mdmc_reduce, self.zero_division + ) - Forward accepts - - ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes - - ``target`` (long tensor): ``(N, ...)`` +class Recall(StatScores): + """Computes the recall score (the ratio ``tp / (tp + fn)``). - If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument. - This is the case for binary and multi-label logits. + The reduction method (how the recall scores are aggregated) is controlled by the + ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the + multi-dimensional multi-class case. Accepts all inputs listed in :ref:`metrics:Input types`. - If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``. + In case where you need to ignore a class in computing the score, an ``ignore_index`` + parameter is availible. Args: - num_classes: Number of classes in the dataset. - beta: Beta coefficient in the F measure. - threshold: - Threshold value for binary or multi-label logits. default: 0.5 - average: - * `'micro'` computes metric globally - * `'macro'` computes metric for each class and then takes the mean + Defines the reduction that is applied. Should be one of the following: + + - ``'micro'`` [default]: Calculate the metric globally, by counting the statistics + (tp, fp, tn, fn) accross all samples and classes. + - ``'macro'``: Calculate the metric for each class separately, and average the + metrics accross classes (with equal weights for each class). + - ``'weighted'``: Calculate the metric for each class separately, and average the + metrics accross classes, weighting each class by its support (``tp + fn``). + - ``'none'`` or ``None``: Calculate the metric for each class separately, and return + the metric for every class. + - ``'samples'``: Calculate the metric for each sample, and average the metrics + across samples (with equal weights for each sample). + + Note that what is considered a sample in the multi-dimensional multi-class case + depends on the value of ``mdmc_average``. + + mdmc_average: + Defines how averaging is done for multi-dimensional multi-class inputs (on top of the + ``average`` parameter). Should be one of the following: + + - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional + multi-class. + + - ``'samplewise'``: In this case, the statistics are computed separately for each + sample on the ``N`` axis, and then averaged over samples. + The computation for each sample is done by treating the flattened extra axes ``...`` + (see :ref:`metrics:Input types`) as the ``N`` dimension within the sample, + and computing the metric for the sample based on that. + + - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs (see :ref:`metrics:Input types`) + are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they + were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. + + num_classes: + Number of classes. Necessary for (multi-dimensional) multi-class or multi-label data. + + threshold: + Threshold probability value for transforming probability predictions to binary + (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + is_multiclass: + If ``False``, treat multi-class and multi-dim multi-class inputs with 1 or 2 classes as + binary and multi-label, respectively. If ``True``, treat binary and multi-label inputs + as multi-class or multi-dim multi-class with 2 classes, respectively. + Defaults to ``None``, which treats inputs as they appear. + ignore_index: + Integer specifying a target class to ignore. If given, this class index does not contribute + to the returned score, regardless of reduction method. Has no effect if given an int that + is not in the range ``[0, C-1]``, or if ``C=1``, where ``C`` is the number of classes. + + If an index is ignored, and ``average=None`` or ``'none'``, the score for the ignored class + will be returned as ``nan`` (to not break the indexing of other labels). + zero_division: + Score to use for classes/samples, whose score has 0 in the denominator. Has to be either + 0 [default] or 1. - multilabel: If predictions are from multilabel classification. compute_on_step: Forward only calls ``update()`` and return None if this is set to False. default: True dist_sync_on_step: @@ -150,65 +234,67 @@ class Recall(Metric): before returning the value at the step. default: False process_group: Specify the process group on which synchronization is called. default: None (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. When `None`, DDP + will be used to perform the allgather. default: None Example: - >>> from pytorch_lightning.metrics import Recall - >>> target = torch.tensor([0, 1, 2, 0, 1, 2]) - >>> preds = torch.tensor([0, 2, 1, 0, 0, 1]) - >>> recall = Recall(num_classes=3) + >>> from pytorch_lightning.metrics.classification import Recall + >>> preds = torch.tensor([2, 0, 2, 1]) + >>> target = torch.tensor([1, 1, 2, 0]) + >>> recall = Recall(average='macro', num_classes=3) >>> recall(preds, target) tensor(0.3333) + >>> recall = Recall(average='micro') + >>> recall(preds, target) + tensor(0.2500) """ + def __init__( self, - num_classes: int = 1, + average: str = "micro", + mdmc_average: Optional[str] = None, threshold: float = 0.5, - average: str = 'micro', - multilabel: bool = False, + num_classes: Optional[int] = None, + is_multiclass: Optional[bool] = None, + ignore_index: Optional[int] = None, + zero_division: int = 0, compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, ): super().__init__( + reduce="macro" if average in ["weighted", "none", None] else average, + mdmc_reduce=mdmc_average, + threshold=threshold, + num_classes=num_classes, + is_multiclass=is_multiclass, + ignore_index=ignore_index, compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group, + dist_sync_fn=dist_sync_fn, ) - self.num_classes = num_classes - self.threshold = threshold - self.average = average - self.multilabel = multilabel - - assert self.average in ('micro', 'macro'), \ - "average passed to the function must be either `micro` or `macro`" + if zero_division not in [0, 1]: + raise ValueError("zero_division has to be either 0 or 1") - self.add_state("true_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum") - self.add_state("actual_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum") - - def update(self, preds: torch.Tensor, target: torch.Tensor): - """ - Update state with predictions and targets. + self.zero_division = zero_division + self.average = average - Args: - preds: Predictions from model - target: Ground truth values + def compute(self) -> torch.Tensor: """ - preds, target = _input_format_classification_one_hot( - self.num_classes, preds, target, self.threshold, self.multilabel - ) + Computes the recall score based on inputs passed in to ``update`` previously. - # multiply because we are counting (1, 1) pair for true positives - self.true_positives += torch.sum(preds * target, dim=1) - self.actual_positives += torch.sum(target, dim=1) + Return: + The of the returned tensor depends on the ``average`` parameter - def compute(self): + - If ``average in ['micro', 'macro', 'weighted', 'samples']``, a one-element tensor will be returned + - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number + of classes """ - Computes accuracy over state. - """ - if self.average == 'micro': - return self.true_positives.sum().float() / (self.actual_positives.sum() + METRIC_EPS) - elif self.average == 'macro': - return (self.true_positives.float() / (self.actual_positives + METRIC_EPS)).mean() + + return _recall_compute(self.tp, self.fp, self.tn, self.fn, self.average, self.mdmc_reduce, self.zero_division) diff --git a/pytorch_lightning/metrics/functional/__init__.py b/pytorch_lightning/metrics/functional/__init__.py index 7582a5f3d38ac..b3af66a129283 100644 --- a/pytorch_lightning/metrics/functional/__init__.py +++ b/pytorch_lightning/metrics/functional/__init__.py @@ -19,10 +19,7 @@ multiclass_precision_recall_curve, multiclass_roc, multiclass_auroc, - precision, - precision_recall, precision_recall_curve, - recall, roc, to_categorical, to_onehot, @@ -44,3 +41,4 @@ from pytorch_lightning.metrics.functional.confusion_matrix import confusion_matrix from pytorch_lightning.metrics.functional.f_beta import fbeta, f1 from pytorch_lightning.metrics.functional.stat_scores import stat_scores +from pytorch_lightning.metrics.functional.precision_recall import precision, recall diff --git a/pytorch_lightning/metrics/functional/precision_recall.py b/pytorch_lightning/metrics/functional/precision_recall.py new file mode 100644 index 0000000000000..e7fd2cca3fca8 --- /dev/null +++ b/pytorch_lightning/metrics/functional/precision_recall.py @@ -0,0 +1,276 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +import torch +from pytorch_lightning.metrics.functional.reduction import _reduce_scores +from pytorch_lightning.metrics.functional.stat_scores import _stat_scores_update + + +def _precision_compute( + tp: torch.Tensor, + fp: torch.Tensor, + tn: torch.Tensor, + fn: torch.Tensor, + average: str, + mdmc_average: Optional[str], + zero_division: int, +) -> torch.Tensor: + return _reduce_scores( + numerator=tp, + denominator=tp + fp, + weights=tp + fn, + average=average, + mdmc_average=mdmc_average, + zero_division=zero_division, + ) + + +def precision( + preds: torch.Tensor, + target: torch.Tensor, + average: str = "micro", + mdmc_average: Optional[str] = None, + threshold: float = 0.5, + num_classes: Optional[int] = None, + is_multiclass: Optional[bool] = None, + ignore_index: Optional[int] = None, + zero_division: int = 0, +) -> torch.Tensor: + """Computes the precision score (the ratio ``tp / (tp + fp)``). + + The reduction method (how the precision scores are aggregated) is controlled by the + ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the + multi-dimensional multi-class case. Accepts all inputs listed in :ref:`metrics:Input types`. + + In case where you need to ignore a class in computing the score, anI ``ignore_index`` + parameter is availible. + + Args: + preds: Predictions from model (probabilities or labels) + target: Ground truth values + average: + Defines the reduction that is applied. Should be one of the following: + + - ``'micro'`` [default]: Calculate the metric globally, by counting the statistics + (tp, fp, tn, fn) accross all samples and classes. + - ``'macro'``: Calculate the metric for each class separately, and average the + metrics accross classes (with equal weights for each class). + - ``'weighted'``: Calculate the metric for each class separately, and average the + metrics accross classes, weighting each class by its support (``tp + fn``). + - ``'none'`` or ``None``: Calculate the metric for each class separately, and return + the metric for every class. + - ``'samples'``: Calculate the metric for each sample, and average the metrics + across samples (with equal weights for each sample). + + Note that what is considered a sample in the multi-dimensional multi-class case + depends on the value of ``mdmc_average``. + + mdmc_average: + Defines how averaging is done for multi-dimensional multi-class inputs (on top of the + ``average`` parameter). Should be one of the following: + + - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional + multi-class. + + - ``'samplewise'``: In this case, the statistics are computed separately for each + sample on the ``N`` axis, and then averaged over samples. + The computation for each sample is done by treating the flattened extra axes ``...`` + (see :ref:`metrics:Input types`) as the ``N`` dimension within the sample, + and computing the metric for the sample based on that. + + - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs (see :ref:`metrics:Input types`) + are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they + were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. + + num_classes: + Number of classes. Necessary for (multi-dimensional) multi-class or multi-label data. + + threshold: + Threshold probability value for transforming probability predictions to binary + (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + is_multiclass: + If ``False``, treat multi-class and multi-dim multi-class inputs with 1 or 2 classes as + binary and multi-label, respectively. If ``True``, treat binary and multi-label inputs + as multi-class or multi-dim multi-class with 2 classes, respectively. + Defaults to ``None``, which treats inputs as they appear. + ignore_index: + Integer specifying a target class to ignore. If given, this class index does not contribute + to the returned score, regardless of reduction method. Has no effect if given an int that + is not in the range ``[0, C-1]``, or if ``C=1``, where ``C`` is the number of classes. + + If an index is ignored, and ``average=None`` or ``'none'``, the score for the ignored class + will be returned as ``nan`` (to not break the indexing of other labels). + zero_division: + Score to use for classes/samples, whose score has 0 in the denominator. Has to be either + 0 [default] or 1. + + Return: + The of the returned tensor depends on the ``average`` parameter + + - If ``average in ['micro', 'macro', 'weighted', 'samples']``, a one-element tensor will be returned + - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number + of classes + + Example: + + >>> from pytorch_lightning.metrics.functional import precision + >>> preds = torch.tensor([2, 0, 2, 1]) + >>> target = torch.tensor([1, 1, 2, 0]) + >>> precision(preds, target, average='macro', num_classes=3) + tensor(0.1667) + >>> precision(preds, target, average='micro') + tensor(0.2500) + + """ + + reduce = "macro" if average in ["weighted", "none", None] else average + + if zero_division not in [0, 1]: + raise ValueError("zero_division has to be either 0 or 1") + + tp, fp, tn, fn = _stat_scores_update( + preds, target, reduce, mdmc_average, threshold, num_classes, is_multiclass, ignore_index + ) + + return _precision_compute(tp, fp, tn, fn, average, mdmc_average, zero_division) + + +def _recall_compute( + tp: torch.Tensor, + fp: torch.Tensor, + tn: torch.Tensor, + fn: torch.Tensor, + average: str, + mdmc_average: Optional[str], + zero_division: int, +) -> torch.Tensor: + return _reduce_scores( + numerator=tp, + denominator=tp + fn, + weights=tp + fn, + average=average, + mdmc_average=mdmc_average, + zero_division=zero_division, + ) + + +def recall( + preds: torch.Tensor, + target: torch.Tensor, + average: str = "micro", + mdmc_average: Optional[str] = None, + threshold: float = 0.5, + num_classes: Optional[int] = None, + is_multiclass: Optional[bool] = None, + ignore_index: Optional[int] = None, + zero_division: int = 0, +) -> torch.Tensor: + """Computes the recall score (the ratio ``tp / (tp + fn)``). + + The reduction method (how the recall scores are aggregated) is controlled by the + ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the + multi-dimensional multi-class case. Accepts all inputs listed in :ref:`metrics:Input types`. + + In case where you need to ignore a class in computing the score, an ``ignore_index`` + parameter is availible. + + Args: + preds: Predictions from model (probabilities, or labels) + target: Ground truth values + average: + Defines the reduction that is applied. Should be one of the following: + + - ``'micro'`` [default]: Calculate the metric globally, by counting the statistics + (tp, fp, tn, fn) accross all samples and classes. + - ``'macro'``: Calculate the metric for each class separately, and average the + metrics accross classes (with equal weights for each class). + - ``'weighted'``: Calculate the metric for each class separately, and average the + metrics accross classes, weighting each class by its support (``tp + fn``). + - ``'none'`` or ``None``: Calculate the metric for each class separately, and return + the metric for every class. + - ``'samples'``: Calculate the metric for each sample, and average the metrics + across samples (with equal weights for each sample). + + Note that what is considered a sample in the multi-dimensional multi-class case + depends on the value of ``mdmc_average``. + + mdmc_average: + Defines how averaging is done for multi-dimensional multi-class inputs (on top of the + ``average`` parameter). Should be one of the following: + + - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional + multi-class. + + - ``'samplewise'``: In this case, the statistics are computed separately for each + sample on the ``N`` axis, and then averaged over samples. + The computation for each sample is done by treating the flattened extra axes ``...`` + (see :ref:`metrics:Input types`) as the ``N`` dimension within the sample, + and computing the metric for the sample based on that. + + - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs (see :ref:`metrics:Input types`) + are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they + were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. + + num_classes: + Number of classes. Necessary for (multi-dimensional) multi-class or multi-label data. + + threshold: + Threshold probability value for transforming probability predictions to binary + (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + is_multiclass: + If ``False``, treat multi-class and multi-dim multi-class inputs with 1 or 2 classes as + binary and multi-label, respectively. If ``True``, treat binary and multi-label inputs + as multi-class or multi-dim multi-class with 2 classes, respectively. + Defaults to ``None``, which treats inputs as they appear. + ignore_index: + Integer specifying a target class to ignore. If given, this class index does not contribute + to the returned score, regardless of reduction method. Has no effect if given an int that + is not in the range ``[0, C-1]``, or if ``C=1``, where ``C`` is the number of classes. + + If an index is ignored, and ``average=None`` or ``'none'``, the score for the ignored class + will be returned as ``nan`` (to not break the indexing of other labels). + zero_division: + Score to use for classes/samples, whose score has 0 in the denominator. Has to be either + 0 [default] or 1. + + Return: + The of the returned tensor depends on the ``average`` parameter + + - If ``average in ['micro', 'macro', 'weighted', 'samples']``, a one-element tensor will be returned + - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number + of classes + + Example: + + >>> from pytorch_lightning.metrics.functional import recall + >>> preds = torch.tensor([2, 0, 2, 1]) + >>> target = torch.tensor([1, 1, 2, 0]) + >>> recall(preds, target, average='macro', num_classes=3) + tensor(0.3333) + >>> recall(preds, target, average='micro') + tensor(0.2500) + + """ + + reduce = "macro" if average in ["weighted", "none", None] else average + + if zero_division not in [0, 1]: + raise ValueError("zero_division has to be either 0 or 1") + + tp, fp, tn, fn = _stat_scores_update( + preds, target, reduce, mdmc_average, threshold, num_classes, is_multiclass, ignore_index + ) + + return _recall_compute(tp, fp, tn, fn, average, mdmc_average, zero_division) diff --git a/tests/metrics/classification/test_precision_recall.py b/tests/metrics/classification/test_precision_recall.py index 967bc60e28307..f43a7f41f0bec 100644 --- a/tests/metrics/classification/test_precision_recall.py +++ b/tests/metrics/classification/test_precision_recall.py @@ -5,126 +5,219 @@ import torch from sklearn.metrics import precision_score, recall_score +from pytorch_lightning.metrics.classification.utils import _input_format_classification from pytorch_lightning.metrics import Precision, Recall +from pytorch_lightning.metrics.functional import precision, recall from tests.metrics.classification.inputs import ( _binary_inputs, _binary_prob_inputs, _multiclass_inputs, - _multiclass_prob_inputs, - _multidim_multiclass_inputs, - _multidim_multiclass_prob_inputs, - _multilabel_inputs, - _multilabel_prob_inputs, + _multiclass_prob_inputs as _mc_prob, + _multidim_multiclass_inputs as _mdmc, + _multidim_multiclass_prob_inputs as _mdmc_prob, + _multilabel_inputs as _ml, + _multilabel_prob_inputs as _ml_prob, + _multilabel_multidim_prob_inputs as _mlmd_prob, + _multilabel_multidim_inputs as _mlmd, ) -from tests.metrics.utils import NUM_CLASSES, THRESHOLD, MetricTester +from tests.metrics.utils import EXTRA_DIM, NUM_CLASSES, THRESHOLD, MetricTester torch.manual_seed(42) -def _sk_prec_recall_binary_prob(preds, target, sk_fn=precision_score, average='micro'): - sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) - sk_target = target.view(-1).numpy() +def _sk_prec_recall( + preds, target, sk_fn, num_classes, average, is_multiclass, zero_division, ignore_index, mdmc_average=None +): + if average == "none": + average = None + if num_classes == 1: + average = "binary" - return sk_fn(y_true=sk_target, y_pred=sk_preds, average='binary') + labels = list(range(num_classes)) + try: + labels.remove(ignore_index) + except ValueError: + pass + sk_preds, sk_target, _ = _input_format_classification( + preds, target, THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass + ) + sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy() -def _sk_prec_recall_binary(preds, target, sk_fn=precision_score, average='micro'): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() + sk_scores = sk_fn(sk_target, sk_preds, average=average, zero_division=zero_division, labels=labels) - return sk_fn(y_true=sk_target, y_pred=sk_preds, average='binary') + if len(labels) != num_classes and not average: + sk_scores = np.insert(sk_scores, ignore_index, np.nan) + return sk_scores -def _sk_prec_recall_multilabel_prob(preds, target, sk_fn=precision_score, average='micro'): - sk_preds = (preds.view(-1, NUM_CLASSES).numpy() >= THRESHOLD).astype(np.uint8) - sk_target = target.view(-1, NUM_CLASSES).numpy() - return sk_fn(y_true=sk_target, y_pred=sk_preds, average=average) +def _sk_prec_recall_mdmc( + preds, target, sk_fn, num_classes, average, is_multiclass, zero_division, ignore_index, mdmc_average +): + preds, target, _ = _input_format_classification( + preds, target, threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass + ) + if mdmc_average == "global": + preds = torch.movedim(preds, 1, -1).reshape(-1, preds.shape[1]) + target = torch.movedim(target, 1, -1).reshape(-1, target.shape[1]) -def _sk_prec_recall_multilabel(preds, target, sk_fn=precision_score, average='micro'): - sk_preds = preds.view(-1, NUM_CLASSES).numpy() - sk_target = target.view(-1, NUM_CLASSES).numpy() + return _sk_prec_recall(preds, target, sk_fn, num_classes, average, False, zero_division, ignore_index) + else: # mdmc_average == "samplewise" + scores = [] - return sk_fn(y_true=sk_target, y_pred=sk_preds, average=average) + for i in range(preds.shape[0]): + pred_i = preds[i, ...].T + target_i = target[i, ...].T + scores_i = _sk_prec_recall( + pred_i, target_i, sk_fn, num_classes, average, False, zero_division, ignore_index + ) + scores.append(np.expand_dims(scores_i, 0)) -def _sk_prec_recall_multiclass_prob(preds, target, sk_fn=precision_score, average='micro'): - sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy() - sk_target = target.view(-1).numpy() + return np.concatenate(scores).mean() - return sk_fn(y_true=sk_target, y_pred=sk_preds, average=average) +@pytest.mark.parametrize("metric, fn_metric", [(Precision, precision), (Recall, recall)]) +def test_wrong_params(metric, fn_metric): + with pytest.raises(ValueError): + metric(zero_division=None) -def _sk_prec_recall_multiclass(preds, target, sk_fn=precision_score, average='micro'): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() + with pytest.raises(ValueError): + fn_metric(_binary_inputs.preds[0], _binary_inputs.target[0], zero_division=None) - return sk_fn(y_true=sk_target, y_pred=sk_preds, average=average) +###################################################################################### +# Testing for MDMC inputs is partially skipped, because some cases appear where +# (with mdmc_average1 =! None, ignore_index=1, average='weighted') a sample in +# target contains only labels "1" - and as we are ignoring this index, weights of +# all labels will be zero. In this special edge case, sklearn handles the situation +# differently for each metric (recall, precision, fscore), which breaks ours handling +# everything in _reduce_scores (where the return value is 0 in this situation). +###################################################################################### -def _sk_prec_recall_multidim_multiclass_prob(preds, target, sk_fn=precision_score, average='micro'): - sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy() - sk_target = target.view(-1).numpy() - return sk_fn(y_true=sk_target, y_pred=sk_preds, average=average) - - -def _sk_prec_recall_multidim_multiclass(preds, target, sk_fn=precision_score, average='micro'): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - - return sk_fn(y_true=sk_target, y_pred=sk_preds, average=average) - - -@pytest.mark.parametrize("ddp", [True, False]) -@pytest.mark.parametrize("dist_sync_on_step", [True, False]) -@pytest.mark.parametrize("average", ['micro', 'macro']) @pytest.mark.parametrize( - "preds, target, sk_metric, num_classes, multilabel", - [ - (_binary_prob_inputs.preds, _binary_prob_inputs.target, _sk_prec_recall_binary_prob, 1, False), - (_binary_inputs.preds, _binary_inputs.target, _sk_prec_recall_binary, 1, False), - (_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target, _sk_prec_recall_multilabel_prob, NUM_CLASSES, True), - (_multilabel_inputs.preds, _multilabel_inputs.target, _sk_prec_recall_multilabel, NUM_CLASSES, True), - (_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target, _sk_prec_recall_multiclass_prob, NUM_CLASSES, False), - (_multiclass_inputs.preds, _multiclass_inputs.target, _sk_prec_recall_multiclass, NUM_CLASSES, False), - ( - _multidim_multiclass_prob_inputs.preds, - _multidim_multiclass_prob_inputs.target, - _sk_prec_recall_multidim_multiclass_prob, - NUM_CLASSES, - False, - ), - ( - _multidim_multiclass_inputs.preds, - _multidim_multiclass_inputs.target, - _sk_prec_recall_multidim_multiclass, - NUM_CLASSES, - False, - ), - ], + "metric_class, sk_fn, metric_fn", [(Precision, precision_score, precision), (Recall, recall_score, recall)] ) +@pytest.mark.parametrize("average", ["micro", "macro", None, "weighted", "samples"]) +@pytest.mark.parametrize("zero_division", [0, 1]) +@pytest.mark.parametrize("ignore_index", [None, 1]) @pytest.mark.parametrize( - "metric_class, sk_fn", [(Precision, precision_score), (Recall, recall_score)], + "preds, target, num_classes, is_multiclass, mdmc_average, sk_wrapper", + [ + (_binary_prob_inputs.preds, _binary_prob_inputs.target, 1, None, None, _sk_prec_recall), + (_binary_inputs.preds, _binary_inputs.target, 1, False, None, _sk_prec_recall), + (_ml_prob.preds, _ml_prob.target, NUM_CLASSES, None, None, _sk_prec_recall), + (_ml.preds, _ml.target, NUM_CLASSES, False, None, _sk_prec_recall), + (_mc_prob.preds, _mc_prob.target, NUM_CLASSES, None, None, _sk_prec_recall), + (_multiclass_inputs.preds, _multiclass_inputs.target, NUM_CLASSES, None, None, _sk_prec_recall), + (_mlmd_prob.preds, _mlmd_prob.target, EXTRA_DIM * NUM_CLASSES, None, None, _sk_prec_recall), + (_mlmd.preds, _mlmd.target, EXTRA_DIM * NUM_CLASSES, False, None, _sk_prec_recall), + (_mdmc.preds, _mdmc.target, NUM_CLASSES, None, "global", _sk_prec_recall_mdmc), + (_mdmc_prob.preds, _mdmc_prob.target, NUM_CLASSES, None, "global", _sk_prec_recall_mdmc), + (_mdmc.preds, _mdmc.target, NUM_CLASSES, None, "samplewise", _sk_prec_recall_mdmc), + (_mdmc_prob.preds, _mdmc_prob.target, NUM_CLASSES, None, "samplewise", _sk_prec_recall_mdmc), + ], ) class TestPrecisionRecall(MetricTester): - def test_precision_recall( - self, ddp, dist_sync_on_step, preds, target, sk_metric, metric_class, sk_fn, num_classes, multilabel, average + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + def test_precision_recall_class( + self, + ddp, + dist_sync_on_step, + preds, + target, + sk_wrapper, + metric_class, + metric_fn, + sk_fn, + is_multiclass, + num_classes, + average, + mdmc_average, + zero_division, + ignore_index, ): + if num_classes == 1 and average != "micro": + pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)") + + if average == "weighted" and ignore_index is not None and mdmc_average is not None: + pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average") + self.run_class_metric_test( ddp=ddp, preds=preds, target=target, metric_class=metric_class, - sk_metric=partial(sk_metric, sk_fn=sk_fn, average=average), + sk_metric=partial( + sk_wrapper, + sk_fn=sk_fn, + average=average, + num_classes=num_classes, + is_multiclass=is_multiclass, + zero_division=zero_division, + ignore_index=ignore_index, + mdmc_average=mdmc_average, + ), dist_sync_on_step=dist_sync_on_step, metric_args={ "num_classes": num_classes, "average": average, - "multilabel": multilabel, "threshold": THRESHOLD, + "is_multiclass": is_multiclass, + "zero_division": zero_division, + "ignore_index": ignore_index, + "mdmc_average": mdmc_average, + }, + check_dist_sync_on_step=True, + check_batch=True, + ) + + def test_precision_recall_fn( + self, + preds, + target, + sk_wrapper, + metric_class, + metric_fn, + sk_fn, + is_multiclass, + num_classes, + average, + mdmc_average, + zero_division, + ignore_index, + ): + if num_classes == 1 and average != "micro": + pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)") + + if average == "weighted" and ignore_index is not None and mdmc_average is not None: + pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average") + + self.run_functional_metric_test( + preds, + target, + metric_functional=metric_fn, + sk_metric=partial( + sk_wrapper, + sk_fn=sk_fn, + average=average, + num_classes=num_classes, + is_multiclass=is_multiclass, + zero_division=zero_division, + ignore_index=ignore_index, + mdmc_average=mdmc_average, + ), + metric_args={ + "num_classes": num_classes, + "average": average, + "threshold": THRESHOLD, + "is_multiclass": is_multiclass, + "zero_division": zero_division, + "ignore_index": ignore_index, + "mdmc_average": mdmc_average, }, - check_dist_sync_on_step=False if average == 'macro' else True, - check_batch=False if average == 'macro' else True, ) From 9dc7bea917dbe974f76af02a2d09eca8462ed922 Mon Sep 17 00:00:00 2001 From: Tadej Date: Wed, 25 Nov 2020 01:06:50 +0100 Subject: [PATCH 20/61] Add empty line --- tests/metrics/classification/test_stat_scores.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/metrics/classification/test_stat_scores.py b/tests/metrics/classification/test_stat_scores.py index f606b8a8bf8e5..a1cf7ab4dcf6e 100644 --- a/tests/metrics/classification/test_stat_scores.py +++ b/tests/metrics/classification/test_stat_scores.py @@ -206,4 +206,4 @@ def test_stat_scores_fn( "is_multiclass": is_multiclass, "ignore_index": ignore_index, }, - ) \ No newline at end of file + ) From a9640f6c77b76fb7f6236e3e872e926d0fe7fb93 Mon Sep 17 00:00:00 2001 From: Tadej Date: Wed, 25 Nov 2020 01:22:49 +0100 Subject: [PATCH 21/61] Fix permute --- pytorch_lightning/metrics/functional/stat_scores.py | 4 ++-- tests/metrics/classification/test_stat_scores.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/metrics/functional/stat_scores.py b/pytorch_lightning/metrics/functional/stat_scores.py index 1f5e3b50bdf7a..6e4a74804a9ed 100644 --- a/pytorch_lightning/metrics/functional/stat_scores.py +++ b/pytorch_lightning/metrics/functional/stat_scores.py @@ -108,8 +108,8 @@ def _stat_scores_update( shape_permute[1] = shape_permute[-1] shape_permute[2:] = range(1, len(shape_permute) - 1) - preds = torch.permute(*shape_permute).reshape(-1, preds.shape[1]) - target = torch.permute(*shape_permute).reshape(-1, target.shape[1]) + preds = preds.permute(*shape_permute).reshape(-1, preds.shape[1]) + target = target.permute(*shape_permute).reshape(-1, target.shape[1]) # Delete what is in ignore_index, if applicable (and classes don't matter): if ignore_index and reduce in ["micro", "samples"] and preds.shape[1] > 1: diff --git a/tests/metrics/classification/test_stat_scores.py b/tests/metrics/classification/test_stat_scores.py index a1cf7ab4dcf6e..f90150c7ce912 100644 --- a/tests/metrics/classification/test_stat_scores.py +++ b/tests/metrics/classification/test_stat_scores.py @@ -73,8 +73,8 @@ def _sk_stat_scores_mdmc(preds, target, reduce, mdmc_reduce, num_classes, is_mul shape_permute[1] = shape_permute[-1] shape_permute[2:] = range(1, len(shape_permute) - 1) - preds = torch.permute(*shape_permute).reshape(-1, preds.shape[1]) - target = torch.permute(*shape_permute).reshape(-1, target.shape[1]) + preds = preds.permute(*shape_permute).reshape(-1, preds.shape[1]) + target = target.permute(*shape_permute).reshape(-1, target.shape[1]) return _sk_stat_scores(preds, target, reduce, None, False, ignore_index) else: # mdmc_reduce == "samplewise" From 692392cca08196925f11054dff691feb4d00e4ef Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Wed, 25 Nov 2020 02:02:57 +0100 Subject: [PATCH 22/61] Fix some issues with old versions of PyTorch --- pytorch_lightning/metrics/classification/stat_scores.py | 4 ++-- pytorch_lightning/metrics/functional/stat_scores.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/metrics/classification/stat_scores.py b/pytorch_lightning/metrics/classification/stat_scores.py index bfb4058d7e898..2f674a79755e4 100644 --- a/pytorch_lightning/metrics/classification/stat_scores.py +++ b/pytorch_lightning/metrics/classification/stat_scores.py @@ -153,9 +153,9 @@ def __init__( elif reduce == "macro": default, reduce_fn = torch.zeros((num_classes,), dtype=torch.int), "sum" elif reduce == "samples": - default, reduce_fn = torch.empty(0), _dim_zero_cat_and_put_back + default, reduce_fn = torch.empty(0, dtype=torch.int), _dim_zero_cat_and_put_back else: - default, reduce_fn = torch.empty(0), _dim_zero_cat_and_put_back + default, reduce_fn = torch.empty(0, dtype=torch.int), _dim_zero_cat_and_put_back for s in ("tp", "fp", "tn", "fn"): self.add_state(s, default=default.detach().clone(), dist_reduce_fx=reduce_fn) diff --git a/pytorch_lightning/metrics/functional/stat_scores.py b/pytorch_lightning/metrics/functional/stat_scores.py index 6e4a74804a9ed..218710273a5d8 100644 --- a/pytorch_lightning/metrics/functional/stat_scores.py +++ b/pytorch_lightning/metrics/functional/stat_scores.py @@ -148,7 +148,7 @@ def _stat_scores_compute(tp: torch.Tensor, fp: torch.Tensor, tn: torch.Tensor, f outputs = torch.cat(outputs, -1).long() # To standardzie ignore_index statistics as -1 - outputs = torch.where(outputs < 0, -1, outputs) + outputs = torch.where(outputs < 0, torch.tensor(-1, device=outputs.device), outputs) return outputs From a04a71ea195c601d59b22c7295c6d1389d7155fe Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Wed, 25 Nov 2020 09:39:18 +0100 Subject: [PATCH 23/61] Style changes in error messages --- .../metrics/classification/utils.py | 79 ++++++++++--------- 1 file changed, 40 insertions(+), 39 deletions(-) diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index f58e61dcd3127..398acbf2bc5fc 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -62,39 +62,40 @@ def _check_classification_inputs( """ if target.is_floating_point(): - raise ValueError("target has to be an integer tensor") + raise ValueError("`target` has to be an integer tensor") elif target.min() < 0: - raise ValueError("target has to be a non-negative tensor") + raise ValueError("`target` has to be a non-negative tensor") preds_float = preds.is_floating_point() if not preds_float and preds.min() < 0: - raise ValueError("if preds are integers, they have to be non-negative") + raise ValueError("if `preds` are integers, they have to be non-negative") if not preds.shape[0] == target.shape[0]: - raise ValueError("preds and target should have the same first dimension.") + raise ValueError("`preds` and `target` should have the same first dimension.") if preds_float: if preds.min() < 0 or preds.max() > 1: - raise ValueError("preds should be probabilities, but values were detected outside of [0,1] range") + raise ValueError("`preds` should be probabilities, but values were detected outside of [0,1] range") if threshold > 1 or threshold < 0: raise ValueError("Threshold should be a probability in [0,1]") if is_multiclass is False and target.max() > 1: - raise ValueError("If you set is_multiclass=False, then target should not exceed 1.") + raise ValueError("If you set `is_multiclass=False`, then `target` should not exceed 1.") if is_multiclass is False and not preds_float and preds.max() > 1: - raise ValueError("If you set is_multiclass=False and preds are integers, then preds should not exceed 1.") + raise ValueError("If you set `is_multiclass=False` and `preds` are integers, then `preds` should not exceed 1.") # Check that shape/types fall into one of the cases if preds.ndim == target.ndim: if preds.shape != target.shape: raise ValueError( - "preds and targets should have the same shape", - f" got preds shape = {preds.shape} and target shape = {target.shape}.", + "`preds` and `target` should have the same shape", + f" got `preds shape = {preds.shape} and `target` shape = {target.shape}.", ) if preds_float and target.max() > 1: - raise ValueError("if preds and target are of shape (N, ...) and preds are floats, target should be binary") + raise ValueError("if `preds` and `target` are of shape (N, ...)" + " and `preds` are floats, `target` should be binary") # Get the case if preds.ndim == 1 and preds_float: @@ -110,12 +111,12 @@ def _check_classification_inputs( elif preds.ndim == target.ndim + 1: if not preds_float: - raise ValueError("if preds have one dimension more than target, preds should be a float tensor") + raise ValueError("if `preds` have one dimension more than `target`, `preds` should be a float tensor") if not preds.shape[:-1] == target.shape: if preds.shape[2:] != target.shape[1:]: raise ValueError( - "if preds if preds have one dimension more than target, the shape of preds should be" - "either of shape (N, C, ...) or (N, ..., C), and of targets of shape (N, ...)" + "if `preds` have one dimension more than `target`, the shape of `preds` should be" + " either of shape (N, C, ...) or (N, ..., C), and of `target` of shape (N, ...)" ) extra_dim_size = preds.shape[-1 if preds.shape[:-1] == target.shape else 1] @@ -126,64 +127,64 @@ def _check_classification_inputs( case = "multi-dim multi-class" else: raise ValueError( - "preds and target should both have the (same) shape (N, ...), or target (N, ...)" - " and preds (N, C, ...) or (N, ..., C)" + "`preds` and `target` should both have the (same) shape (N, ...), or `target` (N, ...)" + " and `preds` (N, C, ...) or (N, ..., C)" ) if preds.shape != target.shape and is_multiclass is False and extra_dim_size != 2: raise ValueError( - "You have set is_multiclass=False, but have more than 2 classes in your data," - " based on the C dimension of preds." + "You have set `is_multiclass=False`, but have more than 2 classes in your data," + " based on the C dimension of `preds`." ) # Check that num_classes is consistent if not num_classes: if preds.shape != target.shape and target.max() >= extra_dim_size: - raise ValueError("The highest label in targets should be smaller than the size of C dimension") + raise ValueError("The highest label in `target` should be smaller than the size of C dimension") else: if case == "binary": if num_classes > 2: - raise ValueError("Your data is binary, but num_classes is larger than 2.") + raise ValueError("Your data is binary, but `num_classes` is larger than 2.") elif num_classes == 2 and not is_multiclass: raise ValueError( - "Your data is binary and num_classes=2, but is_multiclass is not True." - "Set it to True if you want to transform binary data to multi-class format." + "Your data is binary and `num_classes=2`, but `is_multiclass` is not True." + " Set it to True if you want to transform binary data to multi-class format." ) elif num_classes == 1 and is_multiclass: raise ValueError( - "You have binary data and have set is_multiclass=True, but num_classes is 1." - "Either leave is_multiclass unset or set it to 2 to transform binary data to multi-class format." + "You have binary data and have set `is_multiclass=True`, but `num_classes` is 1." + " Either leave `is_multiclass` unset or set it to 2 to transform binary data to multi-class format." ) elif "multi-class" in case: if num_classes == 1 and is_multiclass is not False: raise ValueError( - "You have set num_classes=1, but predictions are integers." - "If you want to convert (multi-dimensional) multi-class data with 2 classes" - "to binary/multi-label, set is_multiclass=False." + "You have set `num_classes=1`, but predictions are integers." + " If you want to convert (multi-dimensional) multi-class data with 2 classes" + " to binary/multi-label, set `is_multiclass=False`." ) elif num_classes > 1: if is_multiclass is False: if implied_classes != num_classes: raise ValueError( - "You have set is_multiclass=False, but the implied number of classes " - "(from shape of inputs) does not match num_classes. If you are trying to" - "transform multi-dim multi-class data with 2 classes to multi-label, num_classes" - "should be either None or the product of the size of extra dimensions (...)." - "See Input Types in Metrics documentation." + "You have set `is_multiclass=False`, but the implied number of classes " + " (from shape of inputs) does not match `num_classes`. If you are trying to" + " transform multi-dim multi-class data with 2 classes to multi-label, `num_classes`" + " should be either None or the product of the size of extra dimensions (...)." + " See Input Types in Metrics documentation." ) if num_classes <= target.max(): - raise ValueError("The highest label in targets should be smaller than num_classes") + raise ValueError("The highest label in `target` should be smaller than `num_classes`") if num_classes <= preds.max(): - raise ValueError("The highest label in preds should be smaller than num_classes") + raise ValueError("The highest label in `preds` should be smaller than `num_classes`") if preds.shape != target.shape and num_classes != extra_dim_size: - raise ValueError("The size of C dimension of preds does not match num_classes") + raise ValueError("The size of C dimension of `preds` does not match `num_classes`") elif case == "multi-label": if is_multiclass and num_classes != 2: raise ValueError( - "Your have set is_multiclass=True, but num_classes is not equal to 2." - "If you are trying to transform multi-label data to 2 class multi-dimensional" - "multi-class, you should set num_classes to either 2 or None." + "Your have set `is_multiclass=True`, but `num_classes` is not equal to 2." + " If you are trying to transform multi-label data to 2 class multi-dimensional" + " multi-class, you should set `num_classes` to either 2 or None." ) if not is_multiclass and num_classes != implied_classes: raise ValueError("The implied number of classes (from shape of inputs) does not match num_classes.") @@ -192,8 +193,8 @@ def _check_classification_inputs( if top_k > 1: if preds.shape == target.shape: raise ValueError( - "You have set top_k above 1, but your data is not (multi-dimensional) multi-class" - "with probability predictions." + "You have set `top_k` above 1, but your data is not (multi-dimensional) multi-class" + " with probability predictions." ) From eaac5d74cffb2980450a95e15e9ce13f13802605 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Wed, 25 Nov 2020 09:42:52 +0100 Subject: [PATCH 24/61] More error message style improvements --- .../metrics/classification/utils.py | 29 ++++++++++--------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index 398acbf2bc5fc..54bcd840f3621 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -62,23 +62,23 @@ def _check_classification_inputs( """ if target.is_floating_point(): - raise ValueError("`target` has to be an integer tensor") + raise ValueError("`target` has to be an integer tensor.") elif target.min() < 0: - raise ValueError("`target` has to be a non-negative tensor") + raise ValueError("`target` has to be a non-negative tensor.") preds_float = preds.is_floating_point() if not preds_float and preds.min() < 0: - raise ValueError("if `preds` are integers, they have to be non-negative") + raise ValueError("if `preds` are integers, they have to be non-negative.") if not preds.shape[0] == target.shape[0]: raise ValueError("`preds` and `target` should have the same first dimension.") if preds_float: if preds.min() < 0 or preds.max() > 1: - raise ValueError("`preds` should be probabilities, but values were detected outside of [0,1] range") + raise ValueError("`preds` should be probabilities, but values were detected outside of [0,1] range.") if threshold > 1 or threshold < 0: - raise ValueError("Threshold should be a probability in [0,1]") + raise ValueError("`threshold` should be a probability in [0,1].") if is_multiclass is False and target.max() > 1: raise ValueError("If you set `is_multiclass=False`, then `target` should not exceed 1.") @@ -94,8 +94,9 @@ def _check_classification_inputs( f" got `preds shape = {preds.shape} and `target` shape = {target.shape}.", ) if preds_float and target.max() > 1: - raise ValueError("if `preds` and `target` are of shape (N, ...)" - " and `preds` are floats, `target` should be binary") + raise ValueError( + "if `preds` and `target` are of shape (N, ...) and `preds` are floats, `target` should be binary." + ) # Get the case if preds.ndim == 1 and preds_float: @@ -111,12 +112,12 @@ def _check_classification_inputs( elif preds.ndim == target.ndim + 1: if not preds_float: - raise ValueError("if `preds` have one dimension more than `target`, `preds` should be a float tensor") + raise ValueError("if `preds` have one dimension more than `target`, `preds` should be a float tensor.") if not preds.shape[:-1] == target.shape: if preds.shape[2:] != target.shape[1:]: raise ValueError( "if `preds` have one dimension more than `target`, the shape of `preds` should be" - " either of shape (N, C, ...) or (N, ..., C), and of `target` of shape (N, ...)" + " either of shape (N, C, ...) or (N, ..., C), and of `target` of shape (N, ...)." ) extra_dim_size = preds.shape[-1 if preds.shape[:-1] == target.shape else 1] @@ -128,7 +129,7 @@ def _check_classification_inputs( else: raise ValueError( "`preds` and `target` should both have the (same) shape (N, ...), or `target` (N, ...)" - " and `preds` (N, C, ...) or (N, ..., C)" + " and `preds` (N, C, ...) or (N, ..., C)." ) if preds.shape != target.shape and is_multiclass is False and extra_dim_size != 2: @@ -140,7 +141,7 @@ def _check_classification_inputs( # Check that num_classes is consistent if not num_classes: if preds.shape != target.shape and target.max() >= extra_dim_size: - raise ValueError("The highest label in `target` should be smaller than the size of C dimension") + raise ValueError("The highest label in `target` should be smaller than the size of C dimension.") else: if case == "binary": if num_classes > 2: @@ -173,11 +174,11 @@ def _check_classification_inputs( " See Input Types in Metrics documentation." ) if num_classes <= target.max(): - raise ValueError("The highest label in `target` should be smaller than `num_classes`") + raise ValueError("The highest label in `target` should be smaller than `num_classes`.") if num_classes <= preds.max(): - raise ValueError("The highest label in `preds` should be smaller than `num_classes`") + raise ValueError("The highest label in `preds` should be smaller than `num_classes`.") if preds.shape != target.shape and num_classes != extra_dim_size: - raise ValueError("The size of C dimension of `preds` does not match `num_classes`") + raise ValueError("The size of C dimension of `preds` does not match `num_classes`.") elif case == "multi-label": if is_multiclass and num_classes != 2: From c1108f0cf81359008f5b462c3ba0bd851fd50a0d Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Wed, 25 Nov 2020 09:44:28 +0100 Subject: [PATCH 25/61] Fix typo in docs --- docs/source/metrics.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 407b64d3d2948..831ac67922114 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -230,7 +230,7 @@ into these categories (``N`` stands for the batch size and ``C`` for number of c All dimensions of size 1 (except ``N``) are "squeezed out" at the beginning, so that, for example, a tensor of shape ``(N, 1)`` is treated as ``(N, )``. -When predictions or targets are integers, it is assumed that class labels start at , i.e. +When predictions or targets are integers, it is assumed that class labels start at 0, i.e. the possible class labels are 0, 1, 2, 3, etc. Below are some examples of different input types .. code-block:: python From 277769b7c4e3aa84a83b3e12eb724ea3318e44c0 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Wed, 25 Nov 2020 09:50:02 +0100 Subject: [PATCH 26/61] Add more descriptive variable names in utils --- pytorch_lightning/metrics/utils.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/metrics/utils.py b/pytorch_lightning/metrics/utils.py index 0f71c531fe6ac..b7f8d492ce01d 100644 --- a/pytorch_lightning/metrics/utils.py +++ b/pytorch_lightning/metrics/utils.py @@ -37,14 +37,14 @@ def _flatten(x): def to_onehot( - tensor: torch.Tensor, + label_tensor: torch.Tensor, num_classes: int, ) -> torch.Tensor: """ Converts a dense label tensor to one-hot format Args: - tensor: dense label tensor, with shape [N, d1, d2, ...] + label_tensor: dense label tensor, with shape [N, d1, d2, ...] num_classes: number of classes C Output: @@ -57,18 +57,18 @@ def to_onehot( [0, 0, 1, 0], [0, 0, 0, 1]]) """ - dtype, device, shape = tensor.dtype, tensor.device, tensor.shape + dtype, device, shape = label_tensor.dtype, label_tensor.device, label_tensor.shape tensor_onehot = torch.zeros(shape[0], num_classes, *shape[1:], dtype=dtype, device=device) - index = tensor.long().unsqueeze(1).expand_as(tensor_onehot) + index = label_tensor.long().unsqueeze(1).expand_as(tensor_onehot) return tensor_onehot.scatter_(1, index, 1.0) -def select_topk(tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch.Tensor: +def select_topk(prob_tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch.Tensor: """ Convert a probability tensor to binary by selecting top-k highest entries. Args: - tensor: dense tensor of shape ``[..., C, ...]``, where ``C`` is in the + prob_tensor: dense tensor of shape ``[..., C, ...]``, where ``C`` is in the position defined by the ``dim`` argument topk: number of highest entries to turn into 1s dim: dimension on which to compare entries @@ -82,8 +82,8 @@ def select_topk(tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch.Tens tensor([[0, 1, 1], [1, 1, 0]], dtype=torch.int32) """ - zeros = torch.zeros_like(tensor, device=tensor.device) - topk_tensor = zeros.scatter(1, tensor.topk(k=topk, dim=dim).indices, 1.0) + zeros = torch.zeros_like(prob_tensor, device=prob_tensor.device) + topk_tensor = zeros.scatter(1, prob_tensor.topk(k=topk, dim=dim).indices, 1.0) return topk_tensor.int() From 484929861c54922cf8f748eaf677efae4cff5fcb Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Wed, 25 Nov 2020 09:59:31 +0100 Subject: [PATCH 27/61] Change internal var names --- tests/metrics/classification/test_inputs.py | 118 ++++++++++---------- 1 file changed, 57 insertions(+), 61 deletions(-) diff --git a/tests/metrics/classification/test_inputs.py b/tests/metrics/classification/test_inputs.py index 19828dc07eb34..6ec6c8dbbc498 100644 --- a/tests/metrics/classification/test_inputs.py +++ b/tests/metrics/classification/test_inputs.py @@ -43,47 +43,43 @@ T = torch.Tensor -def idn(x): +def _idn(x): return x -def usq(x): +def _usq(x): return x.unsqueeze(-1) -def toint(x): - return x.int() - - -def thrs(x): +def _thrs(x): return x >= THRESHOLD -def rshp1(x): +def _rshp1(x): return x.reshape(x.shape[0], -1) -def rshp2(x): +def _rshp2(x): return x.reshape(x.shape[0], x.shape[1], -1) -def onehot(x): +def _onehot(x): return to_onehot(x, NUM_CLASSES) -def onehot2(x): +def _onehot2(x): return to_onehot(x, 2) -def top1(x): +def _top1(x): return select_topk(x, 1) -def top2(x): +def _top2(x): return select_topk(x, 2) -def mvdim(x): +def _mvdim(x): """ Equivalent of torch.movedim(x, -1, 1) """ shape_permute = list(range(x.ndim)) shape_permute[1] = shape_permute[-1] @@ -93,44 +89,44 @@ def mvdim(x): # To avoid ugly black line wrapping -def ml_preds_tr(x): - return rshp1(toint(thrs(x))) +def _ml_preds_tr(x): + return _rshp1(_thrs(x).int()) -def onehot_rshp1(x): - return onehot(rshp1(x)) +def _onehot_rshp1(x): + return _onehot(_rshp1(x)) -def onehot2_rshp1(x): - return onehot2(rshp1(x)) +def _onehot2_rshp1(x): + return _onehot2(_rshp1(x)) -def top1_rshp2(x): - return top1(rshp2(x)) +def _top1_rshp2(x): + return _top1(_rshp2(x)) -def top2_rshp2(x): - return top2(rshp2(x)) +def _top2_rshp2(x): + return _top2(_rshp2(x)) -def mdmc1_top1_tr(x): - return top1(rshp2(mvdim(x))) +def _mdmc1_top1_tr(x): + return _top1(_rshp2(_mvdim(x))) -def mdmc1_top2_tr(x): - return top2(rshp2(mvdim(x))) +def _mdmc1_top2_tr(x): + return _top2(_rshp2(_mvdim(x))) -def probs_to_mc_preds_tr(x): - return toint(onehot2(thrs(x))) +def _probs_to_mc_preds_tr(x): + return _onehot2(_thrs(x)).int() -def mlmd_prob_to_mc_preds_tr(x): - return onehot2(rshp1(toint(thrs(x)))) +def _mlmd_prob_to_mc_preds_tr(x): + return _onehot2(_rshp1(_thrs(x).int())) -def mdmc_prob_to_ml_preds_tr(x): - return top1(mvdim(x))[:, 1] +def _mdmc_prob_to__ml_preds_tr(x): + return _top1(_mvdim(x))[:, 1] ######################## @@ -143,44 +139,44 @@ def mdmc_prob_to_ml_preds_tr(x): [ ############################# # Test usual expected cases - (_bin, THRESHOLD, None, False, 1, "multi-class", usq, usq), - (_bin_prob, THRESHOLD, None, None, 1, "binary", lambda x: usq(toint(thrs(x))), usq), - (_ml_prob, THRESHOLD, None, None, 1, "multi-label", lambda x: toint(thrs(x)), idn), - (_ml, THRESHOLD, None, False, 1, "multi-dim multi-class", idn, idn), - (_ml_prob, THRESHOLD, None, None, 1, "multi-label", ml_preds_tr, rshp1), - (_mlmd, THRESHOLD, None, False, 1, "multi-dim multi-class", rshp1, rshp1), - (_mc, THRESHOLD, NUM_CLASSES, None, 1, "multi-class", onehot, onehot), - (_mc_prob, THRESHOLD, None, None, 1, "multi-class", top1, onehot), - (_mc_prob, THRESHOLD, None, None, 2, "multi-class", top2, onehot), - (_mdmc, THRESHOLD, NUM_CLASSES, None, 1, "multi-dim multi-class", onehot, onehot), - (_mdmc_prob, THRESHOLD, None, None, 1, "multi-dim multi-class", top1_rshp2, onehot), - (_mdmc_prob, THRESHOLD, None, None, 2, "multi-dim multi-class", top2_rshp2, onehot), - (_mdmc_prob_many_dims, THRESHOLD, None, None, 1, "multi-dim multi-class", top1_rshp2, onehot_rshp1), - (_mdmc_prob_many_dims, THRESHOLD, None, None, 2, "multi-dim multi-class", top2_rshp2, onehot_rshp1), + (_bin, THRESHOLD, None, False, 1, "multi-class", _usq, _usq), + (_bin_prob, THRESHOLD, None, None, 1, "binary", lambda x: _usq(_thrs(x).int()), _usq), + (_ml_prob, THRESHOLD, None, None, 1, "multi-label", lambda x: _thrs(x).int(), _idn), + (_ml, THRESHOLD, None, False, 1, "multi-dim multi-class", _idn, _idn), + (_ml_prob, THRESHOLD, None, None, 1, "multi-label", _ml_preds_tr, _rshp1), + (_mlmd, THRESHOLD, None, False, 1, "multi-dim multi-class", _rshp1, _rshp1), + (_mc, THRESHOLD, NUM_CLASSES, None, 1, "multi-class", _onehot, _onehot), + (_mc_prob, THRESHOLD, None, None, 1, "multi-class", _top1, _onehot), + (_mc_prob, THRESHOLD, None, None, 2, "multi-class", _top2, _onehot), + (_mdmc, THRESHOLD, NUM_CLASSES, None, 1, "multi-dim multi-class", _onehot, _onehot), + (_mdmc_prob, THRESHOLD, None, None, 1, "multi-dim multi-class", _top1_rshp2, _onehot), + (_mdmc_prob, THRESHOLD, None, None, 2, "multi-dim multi-class", _top2_rshp2, _onehot), + (_mdmc_prob_many_dims, THRESHOLD, None, None, 1, "multi-dim multi-class", _top1_rshp2, _onehot_rshp1), + (_mdmc_prob_many_dims, THRESHOLD, None, None, 2, "multi-dim multi-class", _top2_rshp2, _onehot_rshp1), # Test with C dim in last place - (_mdmc_prob1, THRESHOLD, None, None, 1, "multi-dim multi-class", mdmc1_top1_tr, onehot), - (_mdmc_prob1, THRESHOLD, None, None, 2, "multi-dim multi-class", mdmc1_top2_tr, onehot), - (_mdmc_prob_many_dims1, THRESHOLD, None, None, 1, "multi-dim multi-class", mdmc1_top1_tr, onehot_rshp1), - (_mdmc_prob_many_dims1, THRESHOLD, None, None, 2, "multi-dim multi-class", mdmc1_top2_tr, onehot_rshp1), + (_mdmc_prob1, THRESHOLD, None, None, 1, "multi-dim multi-class", _mdmc1_top1_tr, _onehot), + (_mdmc_prob1, THRESHOLD, None, None, 2, "multi-dim multi-class", _mdmc1_top2_tr, _onehot), + (_mdmc_prob_many_dims1, THRESHOLD, None, None, 1, "multi-dim multi-class", _mdmc1_top1_tr, _onehot_rshp1), + (_mdmc_prob_many_dims1, THRESHOLD, None, None, 2, "multi-dim multi-class", _mdmc1_top2_tr, _onehot_rshp1), ########################### # Test some special cases # Binary as multiclass - (_bin, THRESHOLD, None, None, 1, "multi-class", onehot2, onehot2), + (_bin, THRESHOLD, None, None, 1, "multi-class", _onehot2, _onehot2), # Binary probs as multiclass - (_bin_prob, THRESHOLD, None, True, 1, "binary", probs_to_mc_preds_tr, onehot2), + (_bin_prob, THRESHOLD, None, True, 1, "binary", _probs_to_mc_preds_tr, _onehot2), # Multilabel as multiclass - (_ml, THRESHOLD, None, True, 1, "multi-dim multi-class", onehot2, onehot2), + (_ml, THRESHOLD, None, True, 1, "multi-dim multi-class", _onehot2, _onehot2), # Multilabel probs as multiclass - (_ml_prob, THRESHOLD, None, True, 1, "multi-label", probs_to_mc_preds_tr, onehot2), + (_ml_prob, THRESHOLD, None, True, 1, "multi-label", _probs_to_mc_preds_tr, _onehot2), # Multidim multilabel as multiclass - (_mlmd, THRESHOLD, None, True, 1, "multi-dim multi-class", onehot2_rshp1, onehot2_rshp1), + (_mlmd, THRESHOLD, None, True, 1, "multi-dim multi-class", _onehot2_rshp1, _onehot2_rshp1), # Multidim multilabel probs as multiclass - (_mlmd_prob, THRESHOLD, None, True, 1, "multi-label", mlmd_prob_to_mc_preds_tr, onehot2_rshp1), + (_mlmd_prob, THRESHOLD, None, True, 1, "multi-label", _mlmd_prob_to_mc_preds_tr, _onehot2_rshp1), # Multiclass prob with 2 classes as binary - (_mc_prob_2cls, THRESHOLD, None, False, 1, "multi-class", lambda x: top1(x)[:, [1]], usq), + (_mc_prob_2cls, THRESHOLD, None, False, 1, "multi-class", lambda x: _top1(x)[:, [1]], _usq), # Multi-dim multi-class with 2 classes as multi-label - (_mdmc_prob_2cls, THRESHOLD, None, False, 1, "multi-dim multi-class", lambda x: top1(x)[:, 1], idn), - (_mdmc_prob_2cls1, THRESHOLD, None, False, 1, "multi-dim multi-class", mdmc_prob_to_ml_preds_tr, idn), + (_mdmc_prob_2cls, THRESHOLD, None, False, 1, "multi-dim multi-class", lambda x: _top1(x)[:, 1], _idn), + (_mdmc_prob_2cls1, THRESHOLD, None, False, 1, "multi-dim multi-class", _mdmc_prob_to__ml_preds_tr, _idn), ], ) def test_usual_cases(inputs, threshold, num_classes, is_multiclass, top_k, exp_mode, post_preds, post_target): From cac0b85fa0193800b2018e039e81f8fea30fcc83 Mon Sep 17 00:00:00 2001 From: Tadej Date: Wed, 30 Dec 2020 21:14:43 +0100 Subject: [PATCH 28/61] Revert unwanted changes --- pytorch_lightning/metrics/classification/accuracy.py | 7 ------- pytorch_lightning/metrics/functional/reduction.py | 3 --- tests/metrics/classification/test_inputs.py | 4 ++++ tests/metrics/utils.py | 1 + 4 files changed, 5 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/metrics/classification/accuracy.py b/pytorch_lightning/metrics/classification/accuracy.py index 0bdd225d6bd07..e50b948f389f3 100644 --- a/pytorch_lightning/metrics/classification/accuracy.py +++ b/pytorch_lightning/metrics/classification/accuracy.py @@ -29,12 +29,6 @@ class Accuracy(Metric): Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions. - This metric generalizes to subset accuracy for multilabel data, and similarly for - multi-dimensional multi-class data: for the sample to be counted as correct, the the - class has to be correctly predicted across all extra dimension for each sample in the - ``N`` dimension. Consider using :class:`~pytorch_lightning.metrics.classification.HammingLoss` - is this is not what you want. - For multi-class and multi-dimensional multi-class data with probability predictions, the parameter ``top_k`` generalizes this metric to a Top-K accuracy metric: for each sample the top-K highest probability items are considered to find the correct label. @@ -47,7 +41,6 @@ class has to be correctly predicted across all extra dimension for each sample i Accepts all input types listed in :ref:`metrics:Input types`. Args: - threshold: Threshold probability value for transforming probability predictions to binary (0,1) predictions, in the case of binary or multi-label inputs. diff --git a/pytorch_lightning/metrics/functional/reduction.py b/pytorch_lightning/metrics/functional/reduction.py index 36ed715568a47..c116b16d363a9 100644 --- a/pytorch_lightning/metrics/functional/reduction.py +++ b/pytorch_lightning/metrics/functional/reduction.py @@ -11,9 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional - -import numpy as np import torch from pytorch_lightning.metrics.utils import reduce as __reduce, class_reduce as __cr diff --git a/tests/metrics/classification/test_inputs.py b/tests/metrics/classification/test_inputs.py index 51e92b538a811..8cfe1dd46ec50 100644 --- a/tests/metrics/classification/test_inputs.py +++ b/tests/metrics/classification/test_inputs.py @@ -175,7 +175,11 @@ def test_usual_cases(inputs, num_classes, is_multiclass, top_k, exp_mode, post_p preds=inputs.preds[0][[0], ...], target=inputs.target[0][[0], ...], threshold=THRESHOLD, + num_classes=num_classes, + is_multiclass=is_multiclass, + top_k=top_k, ) + assert mode == exp_mode assert torch.equal(preds_out, post_preds(inputs.preds[0][[0], ...]).int()) assert torch.equal(target_out, post_target(inputs.target[0][[0], ...]).int()) diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py index 1673ec3e2f2ad..4bd6608ce3fcf 100644 --- a/tests/metrics/utils.py +++ b/tests/metrics/utils.py @@ -170,6 +170,7 @@ def setup_class(self): """Setup the metric class. This will spawn the pool of workers that are used for metric testing and setup_ddp """ + self.poolSize = NUM_PROCESSES self.pool = Pool(processes=self.poolSize) self.pool.starmap(setup_ddp, [(rank, self.poolSize) for rank in range(self.poolSize)]) From d043384aeb61a6a3b1b59e2be631bfb758d55e3f Mon Sep 17 00:00:00 2001 From: Tadej Date: Wed, 30 Dec 2020 21:15:55 +0100 Subject: [PATCH 29/61] Revert unwanted changes pt 2 --- .../metrics/classification/utils.py | 392 ------------------ 1 file changed, 392 deletions(-) delete mode 100644 pytorch_lightning/metrics/classification/utils.py diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py deleted file mode 100644 index 54bcd840f3621..0000000000000 --- a/pytorch_lightning/metrics/classification/utils.py +++ /dev/null @@ -1,392 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Tuple, Optional - -import torch - -from pytorch_lightning.metrics.utils import to_onehot, select_topk - - -def _check_classification_inputs( - preds: torch.Tensor, - target: torch.Tensor, - threshold: float, - num_classes: Optional[int] = None, - is_multiclass: bool = False, - top_k: int = 1, -) -> None: - """Performs error checking on inputs for classification. - - This ensures that preds and target take one of the shape/type combinations that are - specified in ``_input_format_classification`` docstring. It also checks the cases of - over-rides with ``is_multiclass`` by checking (for multi-class and multi-dim multi-class - cases) that there are only up to 2 distinct labels. - - In case where preds are floats (probabilities), it is checked whether they are in [0,1] interval. - - When ``num_classes`` is given, it is checked that it is consitent with input cases (binary, - multi-label, ...), and that, if availible, the implied number of classes in the ``C`` - dimension is consistent with it (as well as that max label in target is smaller than it). - - When ``num_classes`` is not specified in these cases, consistency of the highest target - value against ``C`` dimension is checked for (multi-dimensional) multi-class cases. - - If ``top_k`` is larger than one, then an error is raised if the inputs are not (multi-dim) - multi-class with probability predictions. - - Preds and target tensors are expected to be squeezed already - all dimensions should be - greater than 1, except perhaps the first one (N). - - Args: - preds: tensor with predictions - target: tensor with ground truth labels, always integers - threshold: - Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 - num_classes: number of classes - is_multiclass: if True, treat binary and multi-label inputs as multi-class or multi-dim - multi-class with 2 classes, respectively. If False, treat multi-class and multi-dim - multi-class inputs with 1 or 2 classes as binary and multi-label, respectively. - Defaults to None, which treats inputs as they appear. - """ - - if target.is_floating_point(): - raise ValueError("`target` has to be an integer tensor.") - elif target.min() < 0: - raise ValueError("`target` has to be a non-negative tensor.") - - preds_float = preds.is_floating_point() - if not preds_float and preds.min() < 0: - raise ValueError("if `preds` are integers, they have to be non-negative.") - - if not preds.shape[0] == target.shape[0]: - raise ValueError("`preds` and `target` should have the same first dimension.") - - if preds_float: - if preds.min() < 0 or preds.max() > 1: - raise ValueError("`preds` should be probabilities, but values were detected outside of [0,1] range.") - - if threshold > 1 or threshold < 0: - raise ValueError("`threshold` should be a probability in [0,1].") - - if is_multiclass is False and target.max() > 1: - raise ValueError("If you set `is_multiclass=False`, then `target` should not exceed 1.") - - if is_multiclass is False and not preds_float and preds.max() > 1: - raise ValueError("If you set `is_multiclass=False` and `preds` are integers, then `preds` should not exceed 1.") - - # Check that shape/types fall into one of the cases - if preds.ndim == target.ndim: - if preds.shape != target.shape: - raise ValueError( - "`preds` and `target` should have the same shape", - f" got `preds shape = {preds.shape} and `target` shape = {target.shape}.", - ) - if preds_float and target.max() > 1: - raise ValueError( - "if `preds` and `target` are of shape (N, ...) and `preds` are floats, `target` should be binary." - ) - - # Get the case - if preds.ndim == 1 and preds_float: - case = "binary" - elif preds.ndim == 1 and not preds_float: - case = "multi-class" - elif preds.ndim > 1 and preds_float: - case = "multi-label" - else: - case = "multi-dim multi-class" - - implied_classes = torch.prod(torch.Tensor(list(preds.shape[1:]))) - - elif preds.ndim == target.ndim + 1: - if not preds_float: - raise ValueError("if `preds` have one dimension more than `target`, `preds` should be a float tensor.") - if not preds.shape[:-1] == target.shape: - if preds.shape[2:] != target.shape[1:]: - raise ValueError( - "if `preds` have one dimension more than `target`, the shape of `preds` should be" - " either of shape (N, C, ...) or (N, ..., C), and of `target` of shape (N, ...)." - ) - - extra_dim_size = preds.shape[-1 if preds.shape[:-1] == target.shape else 1] - - if preds.ndim == 2: - case = "multi-class" - else: - case = "multi-dim multi-class" - else: - raise ValueError( - "`preds` and `target` should both have the (same) shape (N, ...), or `target` (N, ...)" - " and `preds` (N, C, ...) or (N, ..., C)." - ) - - if preds.shape != target.shape and is_multiclass is False and extra_dim_size != 2: - raise ValueError( - "You have set `is_multiclass=False`, but have more than 2 classes in your data," - " based on the C dimension of `preds`." - ) - - # Check that num_classes is consistent - if not num_classes: - if preds.shape != target.shape and target.max() >= extra_dim_size: - raise ValueError("The highest label in `target` should be smaller than the size of C dimension.") - else: - if case == "binary": - if num_classes > 2: - raise ValueError("Your data is binary, but `num_classes` is larger than 2.") - elif num_classes == 2 and not is_multiclass: - raise ValueError( - "Your data is binary and `num_classes=2`, but `is_multiclass` is not True." - " Set it to True if you want to transform binary data to multi-class format." - ) - elif num_classes == 1 and is_multiclass: - raise ValueError( - "You have binary data and have set `is_multiclass=True`, but `num_classes` is 1." - " Either leave `is_multiclass` unset or set it to 2 to transform binary data to multi-class format." - ) - elif "multi-class" in case: - if num_classes == 1 and is_multiclass is not False: - raise ValueError( - "You have set `num_classes=1`, but predictions are integers." - " If you want to convert (multi-dimensional) multi-class data with 2 classes" - " to binary/multi-label, set `is_multiclass=False`." - ) - elif num_classes > 1: - if is_multiclass is False: - if implied_classes != num_classes: - raise ValueError( - "You have set `is_multiclass=False`, but the implied number of classes " - " (from shape of inputs) does not match `num_classes`. If you are trying to" - " transform multi-dim multi-class data with 2 classes to multi-label, `num_classes`" - " should be either None or the product of the size of extra dimensions (...)." - " See Input Types in Metrics documentation." - ) - if num_classes <= target.max(): - raise ValueError("The highest label in `target` should be smaller than `num_classes`.") - if num_classes <= preds.max(): - raise ValueError("The highest label in `preds` should be smaller than `num_classes`.") - if preds.shape != target.shape and num_classes != extra_dim_size: - raise ValueError("The size of C dimension of `preds` does not match `num_classes`.") - - elif case == "multi-label": - if is_multiclass and num_classes != 2: - raise ValueError( - "Your have set `is_multiclass=True`, but `num_classes` is not equal to 2." - " If you are trying to transform multi-label data to 2 class multi-dimensional" - " multi-class, you should set `num_classes` to either 2 or None." - ) - if not is_multiclass and num_classes != implied_classes: - raise ValueError("The implied number of classes (from shape of inputs) does not match num_classes.") - - # Check that if top_k > 1, we have (multi-class) multi-dim with probabilities - if top_k > 1: - if preds.shape == target.shape: - raise ValueError( - "You have set `top_k` above 1, but your data is not (multi-dimensional) multi-class" - " with probability predictions." - ) - - -def _input_format_classification( - preds: torch.Tensor, - target: torch.Tensor, - threshold: float = 0.5, - top_k: int = 1, - num_classes: Optional[int] = None, - is_multiclass: Optional[bool] = None, -) -> Tuple[torch.Tensor, torch.Tensor, str]: - """Convert preds and target tensors into common format. - - Preds and targets are supposed to fall into one of these categories (and are - validated to make sure this is the case): - - * Both preds and target are of shape ``(N,)``, and both are integers (multi-class) - * Both preds and target are of shape ``(N,)``, and target is binary, while preds - are a float (binary) - * preds are of shape ``(N, C)`` and are floats, and target is of shape ``(N,)`` and - is integer (multi-class) - * preds and target are of shape ``(N, ...)``, target is binary and preds is a float - (multi-label) - * preds are of shape ``(N, ..., C)`` or ``(N, C, ...)`` and are floats, target is of - shape ``(N, ...)`` and is integer (multi-dimensional multi-class) - * preds and target are of shape ``(N, ...)`` both are integers (multi-dimensional - multi-class) - - To avoid ambiguities, all dimensions of size 1, except the first one, are squeezed out. - - The returned output tensors will be binary tensors of the same shape, either ``(N, C)`` - of ``(N, C, X)``, the details for each case are described below. The function also returns - a ``mode`` string, which describes which of the above cases the inputs belonged to - regardless - of whether this was "overridden" by other settings (like ``is_multiclass``). - - In binary case, targets are normally returned as ``(N,1)`` tensor, while preds are transformed - into a binary tensor (elements become 1 if the probability is greater than or equal to - ``threshold`` or 0 otherwise). If ``is_multiclass=True``, then then both targets are preds - become ``(N, 2)`` tensors by a one-hot transformation; with the thresholding being applied to - preds first. - - In multi-class case, normally both preds and targets become ``(N, C)`` binary tensors; targets - by a one-hot transformation and preds by selecting ``top_k`` largest entries (if their original - shape was ``(N,C)``). However, if ``is_multiclass=False``, then targets and preds will be - returned as ``(N,1)`` tensor. - - In multi-label case, normally targets and preds are returned as ``(N, C)`` binary tensors, with - preds being binarized as in the binary case. Here the ``C`` dimension is obtained by flattening - all dimensions after the first one. However if ``is_multiclass=True``, then both are returned as - ``(N, 2, C)``, by an equivalent transformation as in the binary case. - - In multi-dimensional multi-class case, normally both target and preds are returned as - ``(N, C, X)`` tensors, with ``X`` resulting from flattening of all dimensions except ``N`` and - ``C``. The transformations performed here are equivalent to the multi-class case. However, if - ``is_multiclass=False`` (and there are up to two classes), then the data is returned as - ``(N, X)`` binary tensors (multi-label). - - Also, in multi-dimensional multi-class case, if the position of the ``C`` - dimension is ambiguous (e.g. if targets are a ``(7, 3)`` tensor, while predictions are a - ``(7, 3, 3)`` tensor), it will be assumed that the ``C`` dimension is the second dimension. - If this is not the case, you should move it from the last to second place using - ``torch.movedim(preds, -1, 1)``, or using ``preds.permute``, if you are using an older - version of Pytorch. - - Note that where a one-hot transformation needs to be performed and the number of classes - is not implicitly given by a ``C`` dimension, the new ``C`` dimension will either be - equal to ``num_classes``, if it is given, or the maximum label value in preds and - target. - - Args: - preds: tensor with predictions - target: tensor with ground truth labels, always integers - threshold: - Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 - num_classes: number of classes - top_k: number of highest probability entries for each sample to convert to 1s, relevant - only for (multi-dimensional) multi-class cases. - is_multiclass: if True, treat binary and multi-label inputs as multi-class or multi-dim - multi-class with 2 classes, respectively. If False, treat multi-class and multi-dim - multi-class inputs with 1 or 2 classes as binary and multi-label, respectively. - Defaults to None, which treats inputs as they appear. - - Returns: - preds: binary tensor of shape (N, C) or (N, C, X) - target: binary tensor of shape (N, C) or (N, C, X) - """ - preds, target = preds.clone().detach(), target.clone().detach() - - # Remove excess dimensions - if preds.shape[0] == 1: - preds, target = preds.squeeze().unsqueeze(0), target.squeeze().unsqueeze(0) - else: - preds, target = preds.squeeze(), target.squeeze() - - _check_classification_inputs( - preds, - target, - threshold=threshold, - num_classes=num_classes, - is_multiclass=is_multiclass, - top_k=top_k, - ) - - preds_float = preds.is_floating_point() - - if preds.ndim == target.ndim == 1 and preds_float: - mode = "binary" - preds = (preds >= threshold).int() - - if is_multiclass: - target = to_onehot(target, 2) - preds = to_onehot(preds, 2) - else: - preds = preds.unsqueeze(-1) - target = target.unsqueeze(-1) - - elif preds.ndim == target.ndim and preds_float: - mode = "multi-label" - preds = (preds >= threshold).int() - - if is_multiclass: - preds = to_onehot(preds, 2).reshape(preds.shape[0], 2, -1) - target = to_onehot(target, 2).reshape(target.shape[0], 2, -1) - else: - preds = preds.reshape(preds.shape[0], -1) - target = target.reshape(target.shape[0], -1) - - elif preds.ndim == target.ndim + 1 == 2: - mode = "multi-class" - if not num_classes: - num_classes = preds.shape[1] - - target = to_onehot(target, num_classes) - preds = select_topk(preds, top_k) - - # If is_multiclass=False, force to binary - if is_multiclass is False: - target = target[:, [1]] - preds = preds[:, [1]] - - elif preds.ndim == target.ndim == 1 and not preds_float: - mode = "multi-class" - - if not num_classes: - num_classes = max(preds.max(), target.max()) + 1 - - # If is_multiclass=False, force to binary - if is_multiclass is False: - preds = preds.unsqueeze(1) - target = target.unsqueeze(1) - else: - preds = to_onehot(preds, num_classes) - target = to_onehot(target, num_classes) - - # Multi-dim multi-class (N, ...) with integers - elif preds.shape == target.shape and not preds_float: - mode = "multi-dim multi-class" - - if not num_classes: - num_classes = max(preds.max(), target.max()) + 1 - - # If is_multiclass=False, force to multi-label - if is_multiclass is False: - preds = preds.reshape(preds.shape[0], -1) - target = target.reshape(target.shape[0], -1) - else: - target = to_onehot(target, num_classes) - target = target.reshape(target.shape[0], target.shape[1], -1) - preds = to_onehot(preds, num_classes) - preds = preds.reshape(preds.shape[0], preds.shape[1], -1) - - # Multi-dim multi-class (N, C, ...) and (N, ..., C) - else: - mode = "multi-dim multi-class" - if preds.shape[:-1] == target.shape: - shape_permute = list(range(preds.ndim)) - shape_permute[1] = shape_permute[-1] - shape_permute[2:] = range(1, len(shape_permute) - 1) - - preds = preds.permute(*shape_permute) - - num_classes = preds.shape[1] - - if is_multiclass is False: - target = target.reshape(target.shape[0], -1) - preds = select_topk(preds, 1)[:, 1, ...] - preds = preds.reshape(preds.shape[0], -1) - else: - target = to_onehot(target, num_classes) - target = target.reshape(target.shape[0], target.shape[1], -1) - preds = select_topk(preds, top_k).reshape(preds.shape[0], preds.shape[1], -1) - - return preds, target, mode From d6bc69b77f8efeb177329116d20765404be3f3fc Mon Sep 17 00:00:00 2001 From: Tadej Date: Sun, 3 Jan 2021 13:33:21 +0100 Subject: [PATCH 30/61] Update metrics interface --- .../classification/precision_recall.py | 143 ++++++++++-------- .../metrics/functional/precision_recall.py | 118 ++++++++------- 2 files changed, 138 insertions(+), 123 deletions(-) diff --git a/pytorch_lightning/metrics/classification/precision_recall.py b/pytorch_lightning/metrics/classification/precision_recall.py index 7660083fae6dd..ad3816ccb5c3d 100644 --- a/pytorch_lightning/metrics/classification/precision_recall.py +++ b/pytorch_lightning/metrics/classification/precision_recall.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Any, Callable +from typing import Optional, Any, Callable, Union import torch from pytorch_lightning.metrics.classification.stat_scores import StatScores @@ -31,15 +31,11 @@ class Precision(StatScores): ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the multi-dimensional multi-class case. Accepts all inputs listed in :ref:`metrics:Input types`. - In case where you need to ignore a class in computing the score, anI ``ignore_index`` - parameter is availible. - Args: average: Defines the reduction that is applied. Should be one of the following: - - ``'micro'`` [default]: Calculate the metric globally, by counting the statistics - (tp, fp, tn, fn) accross all samples and classes. + - ``'micro'`` [default]: Calculate the metric globally, accross all samples and classes. - ``'macro'``: Calculate the metric for each class separately, and average the metrics accross classes (with equal weights for each class). - ``'weighted'``: Calculate the metric for each class separately, and average the @@ -69,38 +65,37 @@ class Precision(StatScores): are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. + zero_division: + Score to use in the case of a 0 in the denominator in the calculation. + + ignore_index: + Integer specifying a target class to ignore. If given, this class index does not contribute + to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` + or ``'none'``, the score for the ignored class will be returned as ``nan``. + num_classes: - Number of classes. Necessary for (multi-dimensional) multi-class or multi-label data. + Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. threshold: Threshold probability value for transforming probability predictions to binary (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 is_multiclass: - If ``False``, treat multi-class and multi-dim multi-class inputs with 1 or 2 classes as - binary and multi-label, respectively. If ``True``, treat binary and multi-label inputs - as multi-class or multi-dim multi-class with 2 classes, respectively. - Defaults to ``None``, which treats inputs as they appear. - ignore_index: - Integer specifying a target class to ignore. If given, this class index does not contribute - to the returned score, regardless of reduction method. Has no effect if given an int that - is not in the range ``[0, C-1]``, or if ``C=1``, where ``C`` is the number of classes. - - If an index is ignored, and ``average=None`` or ``'none'``, the score for the ignored class - will be returned as ``nan`` (to not break the indexing of other labels). - zero_division: - Score to use for classes/samples, whose score has 0 in the denominator. Has to be either - 0 [default] or 1. + Used only in certain special cases, where you want to treat inputs as a different type + than what they appear to be. See the parameter's + :ref:`documentation section ` + for a more detailed explanation and examples. compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. default: True + Forward only calls ``update()`` and return ``None`` if this is set to ``False``. dist_sync_on_step: Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. default: False + before returning the value at the step process_group: - Specify the process group on which synchronization is called. default: None (which selects the entire world) + Specify the process group on which synchronization is called. + default: ``None`` (which selects the entire world) dist_sync_fn: - Callback that performs the allgather operation on the metric state. When `None`, DDP - will be used to perform the allgather. default: None + Callback that performs the allgather operation on the metric state. When ``None``, DDP + will be used to perform the allgather. Example: @@ -120,16 +115,23 @@ def __init__( self, average: str = "micro", mdmc_average: Optional[str] = None, + zero_division: Union[float, int] = 0, + ignore_index: Optional[int] = None, threshold: float = 0.5, num_classes: Optional[int] = None, is_multiclass: Optional[bool] = None, - ignore_index: Optional[int] = None, - zero_division: int = 0, compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): + allowed_average = ["micro", "macro", "weighted", "none", None] + if average not in allowed_average: + raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") + + if not isinstance(zero_division, (float, int)): + raise ValueError(f"The `zero_division` has to be a number, got {zero_division}.") + super().__init__( reduce="macro" if average in ["weighted", "none", None] else average, mdmc_reduce=mdmc_average, @@ -143,9 +145,6 @@ def __init__( dist_sync_fn=dist_sync_fn, ) - if zero_division not in [0, 1]: - raise ValueError("zero_division has to be either 0 or 1") - self.zero_division = zero_division self.average = average @@ -154,16 +153,21 @@ def compute(self) -> torch.Tensor: Computes the precision score based on inputs passed in to ``update`` previously. Return: - The of the returned tensor depends on the ``average`` parameter + The shape of the returned tensor depends on the ``average`` parameter - If ``average in ['micro', 'macro', 'weighted', 'samples']``, a one-element tensor will be returned - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number of classes """ + if isinstance(self.tp, list): + tp = torch.cat(self.tp) + fp = torch.cat(self.fp) + tn = torch.cat(self.tn) + fn = torch.cat(self.fn) + else: + tp, fp, tn, fn = self.tp, self.fp, self.tn, self.fn - return _precision_compute( - self.tp, self.fp, self.tn, self.fn, self.average, self.mdmc_reduce, self.zero_division - ) + return _precision_compute(tp, fp, tn, fn, self.average, self.mdmc_reduce, self.zero_division) class Recall(StatScores): @@ -179,15 +183,11 @@ class Recall(StatScores): ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the multi-dimensional multi-class case. Accepts all inputs listed in :ref:`metrics:Input types`. - In case where you need to ignore a class in computing the score, an ``ignore_index`` - parameter is availible. - Args: average: Defines the reduction that is applied. Should be one of the following: - - ``'micro'`` [default]: Calculate the metric globally, by counting the statistics - (tp, fp, tn, fn) accross all samples and classes. + - ``'micro'`` [default]: Calculate the metric globally, accross all samples and classes. - ``'macro'``: Calculate the metric for each class separately, and average the metrics accross classes (with equal weights for each class). - ``'weighted'``: Calculate the metric for each class separately, and average the @@ -217,38 +217,37 @@ class Recall(StatScores): are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. + zero_division: + Score to use in the case of a 0 in the denominator in the calculation. + + ignore_index: + Integer specifying a target class to ignore. If given, this class index does not contribute + to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` + or ``'none'``, the score for the ignored class will be returned as ``nan``. + num_classes: - Number of classes. Necessary for (multi-dimensional) multi-class or multi-label data. + Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. threshold: Threshold probability value for transforming probability predictions to binary (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 is_multiclass: - If ``False``, treat multi-class and multi-dim multi-class inputs with 1 or 2 classes as - binary and multi-label, respectively. If ``True``, treat binary and multi-label inputs - as multi-class or multi-dim multi-class with 2 classes, respectively. - Defaults to ``None``, which treats inputs as they appear. - ignore_index: - Integer specifying a target class to ignore. If given, this class index does not contribute - to the returned score, regardless of reduction method. Has no effect if given an int that - is not in the range ``[0, C-1]``, or if ``C=1``, where ``C`` is the number of classes. - - If an index is ignored, and ``average=None`` or ``'none'``, the score for the ignored class - will be returned as ``nan`` (to not break the indexing of other labels). - zero_division: - Score to use for classes/samples, whose score has 0 in the denominator. Has to be either - 0 [default] or 1. + Used only in certain special cases, where you want to treat inputs as a different type + than what they appear to be. See the parameter's + :ref:`documentation section ` + for a more detailed explanation and examples. compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. default: True + Forward only calls ``update()`` and return ``None`` if this is set to ``False``. dist_sync_on_step: Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. default: False + before returning the value at the step process_group: - Specify the process group on which synchronization is called. default: None (which selects the entire world) + Specify the process group on which synchronization is called. + default: ``None`` (which selects the entire world) dist_sync_fn: - Callback that performs the allgather operation on the metric state. When `None`, DDP - will be used to perform the allgather. default: None + Callback that performs the allgather operation on the metric state. When ``None``, DDP + will be used to perform the allgather. Example: @@ -272,12 +271,19 @@ def __init__( num_classes: Optional[int] = None, is_multiclass: Optional[bool] = None, ignore_index: Optional[int] = None, - zero_division: int = 0, + zero_division: Union[float, int] = 0, compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): + allowed_average = ["micro", "macro", "weighted", "none", None] + if average not in allowed_average: + raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") + + if not isinstance(zero_division, (float, int)): + raise ValueError(f"The `zero_division` has to be a number, got {zero_division}.") + super().__init__( reduce="macro" if average in ["weighted", "none", None] else average, mdmc_reduce=mdmc_average, @@ -291,9 +297,6 @@ def __init__( dist_sync_fn=dist_sync_fn, ) - if zero_division not in [0, 1]: - raise ValueError("zero_division has to be either 0 or 1") - self.zero_division = zero_division self.average = average @@ -302,11 +305,19 @@ def compute(self) -> torch.Tensor: Computes the recall score based on inputs passed in to ``update`` previously. Return: - The of the returned tensor depends on the ``average`` parameter + The shape of the returned tensor depends on the ``average`` parameter - If ``average in ['micro', 'macro', 'weighted', 'samples']``, a one-element tensor will be returned - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number of classes """ - return _recall_compute(self.tp, self.fp, self.tn, self.fn, self.average, self.mdmc_reduce, self.zero_division) + if isinstance(self.tp, list): + tp = torch.cat(self.tp) + fp = torch.cat(self.fp) + tn = torch.cat(self.tn) + fn = torch.cat(self.fn) + else: + tp, fp, tn, fn = self.tp, self.fp, self.tn, self.fn + + return _recall_compute(tp, fp, tn, fn, self.average, self.mdmc_reduce, self.zero_division) diff --git a/pytorch_lightning/metrics/functional/precision_recall.py b/pytorch_lightning/metrics/functional/precision_recall.py index e7fd2cca3fca8..05350301a484a 100644 --- a/pytorch_lightning/metrics/functional/precision_recall.py +++ b/pytorch_lightning/metrics/functional/precision_recall.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Optional, Union import torch from pytorch_lightning.metrics.functional.reduction import _reduce_scores @@ -25,7 +25,7 @@ def _precision_compute( fn: torch.Tensor, average: str, mdmc_average: Optional[str], - zero_division: int, + zero_division: Union[float, int], ) -> torch.Tensor: return _reduce_scores( numerator=tp, @@ -46,25 +46,27 @@ def precision( num_classes: Optional[int] = None, is_multiclass: Optional[bool] = None, ignore_index: Optional[int] = None, - zero_division: int = 0, + zero_division: Union[float, int] = 0, ) -> torch.Tensor: - """Computes the precision score (the ratio ``tp / (tp + fp)``). + r""" + Computes `Precision `_: + + .. math:: \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}} + + Where :math:`\text{TP}` and :math:`\text{FP}` represent the number of true positives and + false positives respecitively. The reduction method (how the precision scores are aggregated) is controlled by the ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the multi-dimensional multi-class case. Accepts all inputs listed in :ref:`metrics:Input types`. - In case where you need to ignore a class in computing the score, anI ``ignore_index`` - parameter is availible. - Args: preds: Predictions from model (probabilities or labels) target: Ground truth values average: Defines the reduction that is applied. Should be one of the following: - - ``'micro'`` [default]: Calculate the metric globally, by counting the statistics - (tp, fp, tn, fn) accross all samples and classes. + - ``'micro'`` [default]: Calculate the metric globally, accross all samples and classes. - ``'macro'``: Calculate the metric for each class separately, and average the metrics accross classes (with equal weights for each class). - ``'weighted'``: Calculate the metric for each class separately, and average the @@ -94,30 +96,28 @@ def precision( are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. + zero_division: + Score to use in the case of a 0 in the denominator in the calculation. + + ignore_index: + Integer specifying a target class to ignore. If given, this class index does not contribute + to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` + or ``'none'``, the score for the ignored class will be returned as ``nan``. + num_classes: - Number of classes. Necessary for (multi-dimensional) multi-class or multi-label data. + Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. threshold: Threshold probability value for transforming probability predictions to binary (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 is_multiclass: - If ``False``, treat multi-class and multi-dim multi-class inputs with 1 or 2 classes as - binary and multi-label, respectively. If ``True``, treat binary and multi-label inputs - as multi-class or multi-dim multi-class with 2 classes, respectively. - Defaults to ``None``, which treats inputs as they appear. - ignore_index: - Integer specifying a target class to ignore. If given, this class index does not contribute - to the returned score, regardless of reduction method. Has no effect if given an int that - is not in the range ``[0, C-1]``, or if ``C=1``, where ``C`` is the number of classes. - - If an index is ignored, and ``average=None`` or ``'none'``, the score for the ignored class - will be returned as ``nan`` (to not break the indexing of other labels). - zero_division: - Score to use for classes/samples, whose score has 0 in the denominator. Has to be either - 0 [default] or 1. + Used only in certain special cases, where you want to treat inputs as a different type + than what they appear to be. See the parameter's + :ref:`documentation section ` + for a more detailed explanation and examples. Return: - The of the returned tensor depends on the ``average`` parameter + The shape of the returned tensor depends on the ``average`` parameter - If ``average in ['micro', 'macro', 'weighted', 'samples']``, a one-element tensor will be returned - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number @@ -134,12 +134,14 @@ def precision( tensor(0.2500) """ + allowed_average = ["micro", "macro", "weighted", "none", None] + if average not in allowed_average: + raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") - reduce = "macro" if average in ["weighted", "none", None] else average - - if zero_division not in [0, 1]: - raise ValueError("zero_division has to be either 0 or 1") + if not isinstance(zero_division, (float, int)): + raise ValueError(f"The `zero_division` has to be a number, got {zero_division}.") + reduce = "macro" if average in ["weighted", "none", None] else average tp, fp, tn, fn = _stat_scores_update( preds, target, reduce, mdmc_average, threshold, num_classes, is_multiclass, ignore_index ) @@ -154,7 +156,7 @@ def _recall_compute( fn: torch.Tensor, average: str, mdmc_average: Optional[str], - zero_division: int, + zero_division: Union[float, int], ) -> torch.Tensor: return _reduce_scores( numerator=tp, @@ -175,25 +177,27 @@ def recall( num_classes: Optional[int] = None, is_multiclass: Optional[bool] = None, ignore_index: Optional[int] = None, - zero_division: int = 0, + zero_division: Union[float, int] = 0, ) -> torch.Tensor: - """Computes the recall score (the ratio ``tp / (tp + fn)``). + r""" + Computes `Recall `_: + + .. math:: \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}} + + Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and + false negatives respecitively. The reduction method (how the recall scores are aggregated) is controlled by the ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the multi-dimensional multi-class case. Accepts all inputs listed in :ref:`metrics:Input types`. - In case where you need to ignore a class in computing the score, an ``ignore_index`` - parameter is availible. - Args: preds: Predictions from model (probabilities, or labels) target: Ground truth values average: Defines the reduction that is applied. Should be one of the following: - - ``'micro'`` [default]: Calculate the metric globally, by counting the statistics - (tp, fp, tn, fn) accross all samples and classes. + - ``'micro'`` [default]: Calculate the metric globally, accross all samples and classes. - ``'macro'``: Calculate the metric for each class separately, and average the metrics accross classes (with equal weights for each class). - ``'weighted'``: Calculate the metric for each class separately, and average the @@ -223,30 +227,28 @@ def recall( are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. + zero_division: + Score to use in the case of a 0 in the denominator in the calculation. + + ignore_index: + Integer specifying a target class to ignore. If given, this class index does not contribute + to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` + or ``'none'``, the score for the ignored class will be returned as ``nan``. + num_classes: - Number of classes. Necessary for (multi-dimensional) multi-class or multi-label data. + Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. threshold: Threshold probability value for transforming probability predictions to binary (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 is_multiclass: - If ``False``, treat multi-class and multi-dim multi-class inputs with 1 or 2 classes as - binary and multi-label, respectively. If ``True``, treat binary and multi-label inputs - as multi-class or multi-dim multi-class with 2 classes, respectively. - Defaults to ``None``, which treats inputs as they appear. - ignore_index: - Integer specifying a target class to ignore. If given, this class index does not contribute - to the returned score, regardless of reduction method. Has no effect if given an int that - is not in the range ``[0, C-1]``, or if ``C=1``, where ``C`` is the number of classes. - - If an index is ignored, and ``average=None`` or ``'none'``, the score for the ignored class - will be returned as ``nan`` (to not break the indexing of other labels). - zero_division: - Score to use for classes/samples, whose score has 0 in the denominator. Has to be either - 0 [default] or 1. + Used only in certain special cases, where you want to treat inputs as a different type + than what they appear to be. See the parameter's + :ref:`documentation section ` + for a more detailed explanation and examples. Return: - The of the returned tensor depends on the ``average`` parameter + The shape of the returned tensor depends on the ``average`` parameter - If ``average in ['micro', 'macro', 'weighted', 'samples']``, a one-element tensor will be returned - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number @@ -263,12 +265,14 @@ def recall( tensor(0.2500) """ + allowed_average = ["micro", "macro", "weighted", "none", None] + if average not in allowed_average: + raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") - reduce = "macro" if average in ["weighted", "none", None] else average - - if zero_division not in [0, 1]: - raise ValueError("zero_division has to be either 0 or 1") + if not isinstance(zero_division, (float, int)): + raise ValueError(f"The `zero_division` has to be a number, got {zero_division}.") + reduce = "macro" if average in ["weighted", "none", None] else average tp, fp, tn, fn = _stat_scores_update( preds, target, reduce, mdmc_average, threshold, num_classes, is_multiclass, ignore_index ) From d6559f25c098e2f53a0ead5ff276d2abada08693 Mon Sep 17 00:00:00 2001 From: Tadej Date: Sun, 3 Jan 2021 15:16:39 +0100 Subject: [PATCH 31/61] Add top_k parameter --- .../classification/precision_recall.py | 27 ++++++-- .../metrics/functional/precision_recall.py | 68 ++++++++++++++++--- .../classification/test_precision_recall.py | 8 +-- 3 files changed, 85 insertions(+), 18 deletions(-) diff --git a/pytorch_lightning/metrics/classification/precision_recall.py b/pytorch_lightning/metrics/classification/precision_recall.py index ad3816ccb5c3d..5e53d6f909102 100644 --- a/pytorch_lightning/metrics/classification/precision_recall.py +++ b/pytorch_lightning/metrics/classification/precision_recall.py @@ -79,6 +79,13 @@ class Precision(StatScores): threshold: Threshold probability value for transforming probability predictions to binary (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + top_k: + Number of highest probability entries for each sample to convert to 1s - relevant + only for inputs with probability predictions. If this parameter is set for multi-label + inputs, it will take precedence over ``threshold``. For (multi-dim) multi-class inputs, + this parameter defaults to 1. + + Should be left unset (``None``) for inputs with label predictions. is_multiclass: Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter's @@ -117,8 +124,9 @@ def __init__( mdmc_average: Optional[str] = None, zero_division: Union[float, int] = 0, ignore_index: Optional[int] = None, - threshold: float = 0.5, num_classes: Optional[int] = None, + threshold: float = 0.5, + top_k: Optional[int] = None, is_multiclass: Optional[bool] = None, compute_on_step: bool = True, dist_sync_on_step: bool = False, @@ -136,6 +144,7 @@ def __init__( reduce="macro" if average in ["weighted", "none", None] else average, mdmc_reduce=mdmc_average, threshold=threshold, + top_k=top_k, num_classes=num_classes, is_multiclass=is_multiclass, ignore_index=ignore_index, @@ -231,6 +240,14 @@ class Recall(StatScores): threshold: Threshold probability value for transforming probability predictions to binary (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + top_k: + Number of highest probability entries for each sample to convert to 1s - relevant + only for inputs with probability predictions. If this parameter is set for multi-label + inputs, it will take precedence over ``threshold``. For (multi-dim) multi-class inputs, + this parameter defaults to 1. + + Should be left unset (``None``) for inputs with label predictions. + is_multiclass: Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter's @@ -267,11 +284,12 @@ def __init__( self, average: str = "micro", mdmc_average: Optional[str] = None, - threshold: float = 0.5, + zero_division: Union[float, int] = 0, + ignore_index: Optional[int] = None, num_classes: Optional[int] = None, + threshold: float = 0.5, + top_k: Optional[int] = None, is_multiclass: Optional[bool] = None, - ignore_index: Optional[int] = None, - zero_division: Union[float, int] = 0, compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, @@ -288,6 +306,7 @@ def __init__( reduce="macro" if average in ["weighted", "none", None] else average, mdmc_reduce=mdmc_average, threshold=threshold, + top_k=top_k, num_classes=num_classes, is_multiclass=is_multiclass, ignore_index=ignore_index, diff --git a/pytorch_lightning/metrics/functional/precision_recall.py b/pytorch_lightning/metrics/functional/precision_recall.py index 05350301a484a..9832bc5d35707 100644 --- a/pytorch_lightning/metrics/functional/precision_recall.py +++ b/pytorch_lightning/metrics/functional/precision_recall.py @@ -42,11 +42,12 @@ def precision( target: torch.Tensor, average: str = "micro", mdmc_average: Optional[str] = None, - threshold: float = 0.5, + zero_division: Union[float, int] = 0, + ignore_index: Optional[int] = None, num_classes: Optional[int] = None, + threshold: float = 0.5, + top_k: Optional[int] = None, is_multiclass: Optional[bool] = None, - ignore_index: Optional[int] = None, - zero_division: Union[float, int] = 0, ) -> torch.Tensor: r""" Computes `Precision `_: @@ -110,6 +111,13 @@ def precision( threshold: Threshold probability value for transforming probability predictions to binary (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + top_k: + Number of highest probability entries for each sample to convert to 1s - relevant + only for inputs with probability predictions. If this parameter is set for multi-label + inputs, it will take precedence over ``threshold``. For (multi-dim) multi-class inputs, + this parameter defaults to 1. + + Should be left unset (``None``) for inputs with label predictions. is_multiclass: Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter's @@ -141,9 +149,27 @@ def precision( if not isinstance(zero_division, (float, int)): raise ValueError(f"The `zero_division` has to be a number, got {zero_division}.") + allowed_mdmc_average = [None, "samplewise", "global"] + if mdmc_average not in allowed_mdmc_average: + raise ValueError(f"The `mdmc_average` has to be one of {allowed_mdmc_average}, got {mdmc_average}.") + + if average in ["macro", "weighted", "none", None] and (not num_classes or num_classes < 1): + raise ValueError(f"When you set `average` as {average}, you have to provide the number of classes.") + + if num_classes and ignore_index is not None and (not 0 <= ignore_index < num_classes or num_classes == 1): + raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes") + reduce = "macro" if average in ["weighted", "none", None] else average tp, fp, tn, fn = _stat_scores_update( - preds, target, reduce, mdmc_average, threshold, num_classes, is_multiclass, ignore_index + preds, + target, + reduce=reduce, + mdmc_reduce=mdmc_average, + threshold=threshold, + num_classes=num_classes, + top_k=top_k, + is_multiclass=is_multiclass, + ignore_index=ignore_index, ) return _precision_compute(tp, fp, tn, fn, average, mdmc_average, zero_division) @@ -173,11 +199,12 @@ def recall( target: torch.Tensor, average: str = "micro", mdmc_average: Optional[str] = None, - threshold: float = 0.5, + zero_division: Union[float, int] = 0, + ignore_index: Optional[int] = None, num_classes: Optional[int] = None, + threshold: float = 0.5, + top_k: Optional[int] = None, is_multiclass: Optional[bool] = None, - ignore_index: Optional[int] = None, - zero_division: Union[float, int] = 0, ) -> torch.Tensor: r""" Computes `Recall `_: @@ -241,6 +268,13 @@ def recall( threshold: Threshold probability value for transforming probability predictions to binary (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + top_k: + Number of highest probability entries for each sample to convert to 1s - relevant + only for inputs with probability predictions. If this parameter is set for multi-label + inputs, it will take precedence over ``threshold``. For (multi-dim) multi-class inputs, + this parameter defaults to 1. + + Should be left unset (``None``) for inputs with label predictions. is_multiclass: Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter's @@ -272,9 +306,27 @@ def recall( if not isinstance(zero_division, (float, int)): raise ValueError(f"The `zero_division` has to be a number, got {zero_division}.") + allowed_mdmc_average = [None, "samplewise", "global"] + if mdmc_average not in allowed_mdmc_average: + raise ValueError(f"The `mdmc_average` has to be one of {allowed_mdmc_average}, got {mdmc_average}.") + + if average in ["macro", "weighted", "none", None] and (not num_classes or num_classes < 1): + raise ValueError(f"When you set `average` as {average}, you have to provide the number of classes.") + + if num_classes and ignore_index is not None and (not 0 <= ignore_index < num_classes or num_classes == 1): + raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes") + reduce = "macro" if average in ["weighted", "none", None] else average tp, fp, tn, fn = _stat_scores_update( - preds, target, reduce, mdmc_average, threshold, num_classes, is_multiclass, ignore_index + preds, + target, + reduce=reduce, + mdmc_reduce=mdmc_average, + threshold=threshold, + num_classes=num_classes, + top_k=top_k, + is_multiclass=is_multiclass, + ignore_index=ignore_index, ) return _recall_compute(tp, fp, tn, fn, average, mdmc_average, zero_division) diff --git a/tests/metrics/classification/test_precision_recall.py b/tests/metrics/classification/test_precision_recall.py index 5c60fb7300c7f..8f5f64d3e7f32 100644 --- a/tests/metrics/classification/test_precision_recall.py +++ b/tests/metrics/classification/test_precision_recall.py @@ -17,10 +17,8 @@ _multidim_multiclass_prob_inputs as _mdmc_prob, _multilabel_inputs as _ml, _multilabel_prob_inputs as _ml_prob, - _multilabel_multidim_prob_inputs as _mlmd_prob, - _multilabel_multidim_inputs as _mlmd, ) -from tests.metrics.utils import EXTRA_DIM, NUM_CLASSES, THRESHOLD, MetricTester +from tests.metrics.utils import NUM_CLASSES, THRESHOLD, MetricTester torch.manual_seed(42) @@ -127,8 +125,6 @@ def test_wrong_params(metric, fn_metric): (_ml.preds, _ml.target, NUM_CLASSES, False, None, _sk_prec_recall), (_mc_prob.preds, _mc_prob.target, NUM_CLASSES, None, None, _sk_prec_recall), (_multiclass_inputs.preds, _multiclass_inputs.target, NUM_CLASSES, None, None, _sk_prec_recall), - (_mlmd_prob.preds, _mlmd_prob.target, EXTRA_DIM * NUM_CLASSES, None, None, _sk_prec_recall), - (_mlmd.preds, _mlmd.target, EXTRA_DIM * NUM_CLASSES, False, None, _sk_prec_recall), (_mdmc.preds, _mdmc.target, NUM_CLASSES, None, "global", _sk_prec_recall_mdmc), (_mdmc_prob.preds, _mdmc_prob.target, NUM_CLASSES, None, "global", _sk_prec_recall_mdmc), (_mdmc.preds, _mdmc.target, NUM_CLASSES, None, "samplewise", _sk_prec_recall_mdmc), @@ -136,7 +132,7 @@ def test_wrong_params(metric, fn_metric): ], ) class TestPrecisionRecall(MetricTester): - @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("ddp", [False]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_precision_recall_class( self, From 0b8a2fd7e12e1d528e8d955d79a532481c95096f Mon Sep 17 00:00:00 2001 From: Tadej Date: Wed, 6 Jan 2021 00:39:38 +0100 Subject: [PATCH 32/61] Add back reduce function --- docs/source/metrics.rst | 4 +- .../metrics/classification/helpers.py | 69 ++++++++++++++++++- .../classification/precision_recall.py | 12 ++-- .../metrics/classification/stat_scores.py | 2 +- .../metrics/functional/__init__.py | 3 +- .../metrics/functional/precision_recall.py | 22 +++--- .../metrics/functional/stat_scores.py | 6 +- .../classification/test_stat_scores.py | 2 +- 8 files changed, 92 insertions(+), 28 deletions(-) diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 3dbf68583a1e6..9d58ec196551b 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -252,8 +252,8 @@ the possible class labels are 0, 1, 2, 3, etc. Below are some examples of differ ml_target = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 0]]) -Using the ``is_multiclass`` parameter -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Using the is_multiclass parameter +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ In some cases, you might have inputs which appear to be (multi-dimensional) multi-class but are actually binary/multi-label - for example, if both predictions and targets are diff --git a/pytorch_lightning/metrics/classification/helpers.py b/pytorch_lightning/metrics/classification/helpers.py index b9b8d7902976b..c032c7b1b539b 100644 --- a/pytorch_lightning/metrics/classification/helpers.py +++ b/pytorch_lightning/metrics/classification/helpers.py @@ -13,6 +13,7 @@ # limitations under the License. from typing import Tuple, Optional +import numpy as np import torch from pytorch_lightning.metrics.utils import to_onehot, select_topk @@ -249,7 +250,7 @@ def _check_classification_inputs( is_multiclass: Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter's - :ref:`documentation section ` + :ref:`documentation section ` for a more detailed explanation and examples. @@ -375,7 +376,7 @@ def _input_format_classification( is_multiclass: Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter's - :ref:`documentation section ` + :ref:`documentation section ` for a more detailed explanation and examples. @@ -437,3 +438,67 @@ def _input_format_classification( preds, target = preds.squeeze(-1), target.squeeze(-1) return preds.int(), target.int(), case + + +def _reduce_stat_scores( + numerator: torch.Tensor, + denominator: torch.Tensor, + weights: Optional[torch.Tensor], + average: str, + mdmc_average: Optional[str], + zero_division: int, +) -> torch.Tensor: + """ + Reduces scores of type ``numerator/denominator`` or + ``weights * (numerator/denominator)``, if ``average='weighted'``. + + Args: + numerator: A tensor with numerator numbers. + denominator: A tensor with denominator numbers. If a denominator is + negative, the class will be ignored (if averaging), or its score + will be returned as ``nan`` (if ``average=None``). + If the denominator is zero, then ``zero_division`` score will be + used for those elements. + weights: + A tensor of weights to be used if ``average='weighted'``. + average: + The method to average the scores. Should be one of ``'micro'``, ``'macro'``, + ``'weighted'``, ``'none'``, ``None`` or ``'samples'``. The behavior + corresponds to `sklearn averaging methods `__. + mdmc_average: + The method to average the scores if inputs were multi-dimensional multi-class. + Should be either ``'global'`` or ``'samplewise'``. If inputs were not + multi-dimensional multi-class, it should be ``None`` (default). + zero_division: + The value to use for the score if denominator equals zero. + """ + numerator, denominator = numerator.float(), denominator.float() + zero_div_mask = denominator == 0 + ignore_mask = denominator < 0 + + if weights is None: + weights = torch.ones_like(denominator) + else: + weights = weights.float() + + numerator = torch.where(zero_div_mask, torch.tensor(float(zero_division), device=numerator.device), numerator) + denominator = torch.where(zero_div_mask | ignore_mask, torch.tensor(1.0, device=denominator.device), denominator) + weights = torch.where(ignore_mask, torch.tensor(0.0, device=weights.device), weights) + weights = weights / weights.sum(dim=-1) + + scores = weights * (numerator / denominator) + + # This is in case where sum(weights) = 0, which happens if we ignore the only present class + scores = torch.where(torch.isnan(scores), torch.tensor(float(zero_division), device=scores.device), scores) + + if mdmc_average == "samplewise": + scores = scores.mean(dim=0) + ignore_mask = ignore_mask.sum(dim=0).bool() + + if average in ["none", None]: + scores = torch.where(ignore_mask, torch.tensor(np.nan, device=scores.device), scores) + else: + scores = scores.sum() + + return scores diff --git a/pytorch_lightning/metrics/classification/precision_recall.py b/pytorch_lightning/metrics/classification/precision_recall.py index 5e53d6f909102..3fb9e04e1b6ad 100644 --- a/pytorch_lightning/metrics/classification/precision_recall.py +++ b/pytorch_lightning/metrics/classification/precision_recall.py @@ -78,7 +78,7 @@ class Precision(StatScores): threshold: Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + (0,1) predictions, in the case of binary or multi-label inputs. top_k: Number of highest probability entries for each sample to convert to 1s - relevant only for inputs with probability predictions. If this parameter is set for multi-label @@ -89,7 +89,7 @@ class Precision(StatScores): is_multiclass: Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter's - :ref:`documentation section ` + :ref:`documentation section ` for a more detailed explanation and examples. compute_on_step: @@ -133,7 +133,7 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): - allowed_average = ["micro", "macro", "weighted", "none", None] + allowed_average = ["micro", "macro", "weighted", "samples", "none", None] if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") @@ -239,7 +239,7 @@ class Recall(StatScores): threshold: Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + (0,1) predictions, in the case of binary or multi-label inputs. top_k: Number of highest probability entries for each sample to convert to 1s - relevant only for inputs with probability predictions. If this parameter is set for multi-label @@ -251,7 +251,7 @@ class Recall(StatScores): is_multiclass: Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter's - :ref:`documentation section ` + :ref:`documentation section ` for a more detailed explanation and examples. compute_on_step: @@ -295,7 +295,7 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): - allowed_average = ["micro", "macro", "weighted", "none", None] + allowed_average = ["micro", "macro", "weighted", "samples", "none", None] if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") diff --git a/pytorch_lightning/metrics/classification/stat_scores.py b/pytorch_lightning/metrics/classification/stat_scores.py index d7b33ce1f8099..dbc9ab2bd714b 100644 --- a/pytorch_lightning/metrics/classification/stat_scores.py +++ b/pytorch_lightning/metrics/classification/stat_scores.py @@ -85,7 +85,7 @@ class StatScores(Metric): is_multiclass: Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter's - :ref:`documentation section ` + :ref:`documentation section ` for a more detailed explanation and examples. compute_on_step: diff --git a/pytorch_lightning/metrics/functional/__init__.py b/pytorch_lightning/metrics/functional/__init__.py index 3aa533504be6b..5a50d81688f28 100644 --- a/pytorch_lightning/metrics/functional/__init__.py +++ b/pytorch_lightning/metrics/functional/__init__.py @@ -21,9 +21,7 @@ get_num_classes, iou, multiclass_auroc, - precision, precision_recall, - recall, stat_scores_multiple_classes, to_categorical, to_onehot, @@ -38,6 +36,7 @@ from pytorch_lightning.metrics.functional.mean_squared_error import mean_squared_error # noqa: F401 from pytorch_lightning.metrics.functional.mean_squared_log_error import mean_squared_log_error # noqa: F401 from pytorch_lightning.metrics.functional.nlp import bleu_score # noqa: F401 +from pytorch_lightning.metrics.functional.precision_recall import precision, recall # noqa: F401 from pytorch_lightning.metrics.functional.precision_recall_curve import precision_recall_curve # noqa: F401 from pytorch_lightning.metrics.functional.psnr import psnr # noqa: F401 from pytorch_lightning.metrics.functional.roc import roc # noqa: F401 diff --git a/pytorch_lightning/metrics/functional/precision_recall.py b/pytorch_lightning/metrics/functional/precision_recall.py index 9832bc5d35707..2f47870211911 100644 --- a/pytorch_lightning/metrics/functional/precision_recall.py +++ b/pytorch_lightning/metrics/functional/precision_recall.py @@ -14,7 +14,7 @@ from typing import Optional, Union import torch -from pytorch_lightning.metrics.functional.reduction import _reduce_scores +from pytorch_lightning.metrics.classification.helpers import _reduce_stat_scores from pytorch_lightning.metrics.functional.stat_scores import _stat_scores_update @@ -27,10 +27,10 @@ def _precision_compute( mdmc_average: Optional[str], zero_division: Union[float, int], ) -> torch.Tensor: - return _reduce_scores( + return _reduce_stat_scores( numerator=tp, denominator=tp + fp, - weights=tp + fn, + weights=None if average != "weighted" else tp + fn, average=average, mdmc_average=mdmc_average, zero_division=zero_division, @@ -110,7 +110,7 @@ def precision( threshold: Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + (0,1) predictions, in the case of binary or multi-label inputs. top_k: Number of highest probability entries for each sample to convert to 1s - relevant only for inputs with probability predictions. If this parameter is set for multi-label @@ -121,7 +121,7 @@ def precision( is_multiclass: Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter's - :ref:`documentation section ` + :ref:`documentation section ` for a more detailed explanation and examples. Return: @@ -142,7 +142,7 @@ def precision( tensor(0.2500) """ - allowed_average = ["micro", "macro", "weighted", "none", None] + allowed_average = ["micro", "macro", "weighted", "samples", "none", None] if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") @@ -184,10 +184,10 @@ def _recall_compute( mdmc_average: Optional[str], zero_division: Union[float, int], ) -> torch.Tensor: - return _reduce_scores( + return _reduce_stat_scores( numerator=tp, denominator=tp + fn, - weights=tp + fn, + weights=None if average != "weighted" else tp + fn, average=average, mdmc_average=mdmc_average, zero_division=zero_division, @@ -267,7 +267,7 @@ def recall( threshold: Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + (0,1) predictions, in the case of binary or multi-label inputs top_k: Number of highest probability entries for each sample to convert to 1s - relevant only for inputs with probability predictions. If this parameter is set for multi-label @@ -278,7 +278,7 @@ def recall( is_multiclass: Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter's - :ref:`documentation section ` + :ref:`documentation section ` for a more detailed explanation and examples. Return: @@ -299,7 +299,7 @@ def recall( tensor(0.2500) """ - allowed_average = ["micro", "macro", "weighted", "none", None] + allowed_average = ["micro", "macro", "weighted", "samples", "none", None] if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") diff --git a/pytorch_lightning/metrics/functional/stat_scores.py b/pytorch_lightning/metrics/functional/stat_scores.py index 27d46ee31c39c..a9570cabb3966 100644 --- a/pytorch_lightning/metrics/functional/stat_scores.py +++ b/pytorch_lightning/metrics/functional/stat_scores.py @@ -102,14 +102,14 @@ def _stat_scores_update( target = torch.transpose(target, 1, 2).reshape(-1, target.shape[1]) # Delete what is in ignore_index, if applicable (and classes don't matter): - if ignore_index and reduce != "macro": + if ignore_index is not None and reduce != "macro": preds = _del_column(preds, ignore_index) target = _del_column(target, ignore_index) tp, fp, tn, fn = _stat_scores(preds, target, reduce=reduce) # Take care of ignore_index - if ignore_index and reduce == "macro": + if ignore_index is not None and reduce == "macro": tp[..., ignore_index] = -1 fp[..., ignore_index] = -1 tn[..., ignore_index] = -1 @@ -210,7 +210,7 @@ def stat_scores( is_multiclass: Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter's - :ref:`documentation section ` + :ref:`documentation section ` for a more detailed explanation and examples. Return: diff --git a/tests/metrics/classification/test_stat_scores.py b/tests/metrics/classification/test_stat_scores.py index 62fc8096089a0..0838f4aa54b0b 100644 --- a/tests/metrics/classification/test_stat_scores.py +++ b/tests/metrics/classification/test_stat_scores.py @@ -123,7 +123,7 @@ def test_wrong_threshold(): StatScores(threshold=1.5) -@pytest.mark.parametrize("ignore_index", [None, 1]) +@pytest.mark.parametrize("ignore_index", [None, 0]) @pytest.mark.parametrize("reduce", ["micro", "macro", "samples"]) @pytest.mark.parametrize( "preds, target, sk_fn, mdmc_reduce, num_classes, is_multiclass, top_k", From 0314a62a9d02239e3e8fb3c82968a54e79f6373a Mon Sep 17 00:00:00 2001 From: Tadej Date: Fri, 8 Jan 2021 20:12:06 +0100 Subject: [PATCH 33/61] Add stuff --- .../metrics/classification/helpers.py | 7 +- .../classification/precision_recall.py | 30 ++++--- .../metrics/functional/precision_recall.py | 30 ++++--- .../classification/test_precision_recall.py | 84 ++++++++----------- .../classification/test_stat_scores.py | 39 +++++++-- 5 files changed, 107 insertions(+), 83 deletions(-) diff --git a/pytorch_lightning/metrics/classification/helpers.py b/pytorch_lightning/metrics/classification/helpers.py index c032c7b1b539b..721254cdd6f8a 100644 --- a/pytorch_lightning/metrics/classification/helpers.py +++ b/pytorch_lightning/metrics/classification/helpers.py @@ -485,11 +485,13 @@ def _reduce_stat_scores( numerator = torch.where(zero_div_mask, torch.tensor(float(zero_division), device=numerator.device), numerator) denominator = torch.where(zero_div_mask | ignore_mask, torch.tensor(1.0, device=denominator.device), denominator) weights = torch.where(ignore_mask, torch.tensor(0.0, device=weights.device), weights) - weights = weights / weights.sum(dim=-1) + + if average not in ["micro", "none", None]: + weights = weights / weights.sum(dim=-1, keepdim=True) scores = weights * (numerator / denominator) - # This is in case where sum(weights) = 0, which happens if we ignore the only present class + # This is in case where sum(weights) = 0, which happens if we ignore the only present class with average='weighted' scores = torch.where(torch.isnan(scores), torch.tensor(float(zero_division), device=scores.device), scores) if mdmc_average == "samplewise": @@ -501,4 +503,5 @@ def _reduce_stat_scores( else: scores = scores.sum() + # raise ValueError return scores diff --git a/pytorch_lightning/metrics/classification/precision_recall.py b/pytorch_lightning/metrics/classification/precision_recall.py index 3fb9e04e1b6ad..3556bd60197f1 100644 --- a/pytorch_lightning/metrics/classification/precision_recall.py +++ b/pytorch_lightning/metrics/classification/precision_recall.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Any, Callable, Union +from typing import Optional, Any, Callable import torch from pytorch_lightning.metrics.classification.stat_scores import StatScores @@ -25,7 +25,8 @@ class Precision(StatScores): .. math:: \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}} Where :math:`\text{TP}` and :math:`\text{FP}` represent the number of true positives and - false positives respecitively. + false positives respecitively. With the use of ``top_k`` parameter, this metric can + generalize to Precision@K. The reduction method (how the precision scores are aggregated) is controlled by the ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the @@ -66,7 +67,8 @@ class Precision(StatScores): were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. zero_division: - Score to use in the case of a 0 in the denominator in the calculation. + Score to use in the case of a 0 in the denominator in the calculation. Should be either + 0 or 1. ignore_index: Integer specifying a target class to ignore. If given, this class index does not contribute @@ -78,7 +80,7 @@ class Precision(StatScores): threshold: Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs. + (0,1) predictions, in the case of binary or multi-label inputs. top_k: Number of highest probability entries for each sample to convert to 1s - relevant only for inputs with probability predictions. If this parameter is set for multi-label @@ -122,7 +124,7 @@ def __init__( self, average: str = "micro", mdmc_average: Optional[str] = None, - zero_division: Union[float, int] = 0, + zero_division: int = 0, ignore_index: Optional[int] = None, num_classes: Optional[int] = None, threshold: float = 0.5, @@ -137,8 +139,8 @@ def __init__( if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") - if not isinstance(zero_division, (float, int)): - raise ValueError(f"The `zero_division` has to be a number, got {zero_division}.") + if zero_division not in [0,1]: + raise ValueError(f"The `zero_division` has to be either 0 or 1.") super().__init__( reduce="macro" if average in ["weighted", "none", None] else average, @@ -186,7 +188,8 @@ class Recall(StatScores): .. math:: \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}} Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and - false negatives respecitively. + false negatives respecitively. With the use of ``top_k`` parameter, this metric can + generalize to Recall@K. The reduction method (how the recall scores are aggregated) is controlled by the ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the @@ -227,7 +230,8 @@ class Recall(StatScores): were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. zero_division: - Score to use in the case of a 0 in the denominator in the calculation. + Score to use in the case of a 0 in the denominator in the calculation. Should be either + 0 or 1. ignore_index: Integer specifying a target class to ignore. If given, this class index does not contribute @@ -239,7 +243,7 @@ class Recall(StatScores): threshold: Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs. + (0,1) predictions, in the case of binary or multi-label inputs. top_k: Number of highest probability entries for each sample to convert to 1s - relevant only for inputs with probability predictions. If this parameter is set for multi-label @@ -284,7 +288,7 @@ def __init__( self, average: str = "micro", mdmc_average: Optional[str] = None, - zero_division: Union[float, int] = 0, + zero_division: int = 0, ignore_index: Optional[int] = None, num_classes: Optional[int] = None, threshold: float = 0.5, @@ -299,8 +303,8 @@ def __init__( if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") - if not isinstance(zero_division, (float, int)): - raise ValueError(f"The `zero_division` has to be a number, got {zero_division}.") + if zero_division not in [0,1]: + raise ValueError(f"The `zero_division` has to be either 0 or 1.") super().__init__( reduce="macro" if average in ["weighted", "none", None] else average, diff --git a/pytorch_lightning/metrics/functional/precision_recall.py b/pytorch_lightning/metrics/functional/precision_recall.py index 2f47870211911..86ccc109d3c02 100644 --- a/pytorch_lightning/metrics/functional/precision_recall.py +++ b/pytorch_lightning/metrics/functional/precision_recall.py @@ -25,7 +25,7 @@ def _precision_compute( fn: torch.Tensor, average: str, mdmc_average: Optional[str], - zero_division: Union[float, int], + zero_division: int, ) -> torch.Tensor: return _reduce_stat_scores( numerator=tp, @@ -42,7 +42,7 @@ def precision( target: torch.Tensor, average: str = "micro", mdmc_average: Optional[str] = None, - zero_division: Union[float, int] = 0, + zero_division: int = 0, ignore_index: Optional[int] = None, num_classes: Optional[int] = None, threshold: float = 0.5, @@ -55,7 +55,8 @@ def precision( .. math:: \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}} Where :math:`\text{TP}` and :math:`\text{FP}` represent the number of true positives and - false positives respecitively. + false positives respecitively. With the use of ``top_k`` parameter, this metric can + generalize to Precision@K. The reduction method (how the precision scores are aggregated) is controlled by the ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the @@ -98,7 +99,8 @@ def precision( were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. zero_division: - Score to use in the case of a 0 in the denominator in the calculation. + Score to use in the case of a 0 in the denominator in the calculation. Should be either + 0 or 1. ignore_index: Integer specifying a target class to ignore. If given, this class index does not contribute @@ -110,7 +112,7 @@ def precision( threshold: Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs. + (0,1) predictions, in the case of binary or multi-label inputs. top_k: Number of highest probability entries for each sample to convert to 1s - relevant only for inputs with probability predictions. If this parameter is set for multi-label @@ -146,8 +148,8 @@ def precision( if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") - if not isinstance(zero_division, (float, int)): - raise ValueError(f"The `zero_division` has to be a number, got {zero_division}.") + if zero_division not in [0,1]: + raise ValueError(f"The `zero_division` has to be either 0 or 1.") allowed_mdmc_average = [None, "samplewise", "global"] if mdmc_average not in allowed_mdmc_average: @@ -182,7 +184,7 @@ def _recall_compute( fn: torch.Tensor, average: str, mdmc_average: Optional[str], - zero_division: Union[float, int], + zero_division: int, ) -> torch.Tensor: return _reduce_stat_scores( numerator=tp, @@ -199,7 +201,7 @@ def recall( target: torch.Tensor, average: str = "micro", mdmc_average: Optional[str] = None, - zero_division: Union[float, int] = 0, + zero_division: int = 0, ignore_index: Optional[int] = None, num_classes: Optional[int] = None, threshold: float = 0.5, @@ -212,7 +214,8 @@ def recall( .. math:: \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}} Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and - false negatives respecitively. + false negatives respecitively. With the use of ``top_k`` parameter, this metric can + generalize to Recall@K. The reduction method (how the recall scores are aggregated) is controlled by the ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the @@ -255,7 +258,8 @@ def recall( were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. zero_division: - Score to use in the case of a 0 in the denominator in the calculation. + Score to use in the case of a 0 in the denominator in the calculation. Should be either + 0 or 1. ignore_index: Integer specifying a target class to ignore. If given, this class index does not contribute @@ -303,8 +307,8 @@ def recall( if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") - if not isinstance(zero_division, (float, int)): - raise ValueError(f"The `zero_division` has to be a number, got {zero_division}.") + if zero_division not in [0,1]: + raise ValueError(f"The `zero_division` has to be either 0 or 1.") allowed_mdmc_average = [None, "samplewise", "global"] if mdmc_average not in allowed_mdmc_average: diff --git a/tests/metrics/classification/test_precision_recall.py b/tests/metrics/classification/test_precision_recall.py index 8f5f64d3e7f32..ce0588f2a347e 100644 --- a/tests/metrics/classification/test_precision_recall.py +++ b/tests/metrics/classification/test_precision_recall.py @@ -5,7 +5,7 @@ import torch from sklearn.metrics import precision_score, recall_score -from pytorch_lightning.metrics.classification.utils import _input_format_classification +from pytorch_lightning.metrics.classification.helpers import _input_format_classification from pytorch_lightning.metrics import Precision, Recall from pytorch_lightning.metrics.functional import precision, recall from tests.metrics.classification.inputs import ( @@ -23,9 +23,7 @@ torch.manual_seed(42) -def _sk_prec_recall( - preds, target, sk_fn, num_classes, average, is_multiclass, zero_division, ignore_index, mdmc_average=None -): +def _sk_prec_recall(preds, target, sk_fn, num_classes, average, is_multiclass, ignore_index, mdmc_average=None): if average == "none": average = None if num_classes == 1: @@ -42,7 +40,7 @@ def _sk_prec_recall( ) sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy() - sk_scores = sk_fn(sk_target, sk_preds, average=average, zero_division=zero_division, labels=labels) + sk_scores = sk_fn(sk_target, sk_preds, average=average, zero_division=0, labels=labels) if len(labels) != num_classes and not average: sk_scores = np.insert(sk_scores, ignore_index, np.nan) @@ -50,9 +48,7 @@ def _sk_prec_recall( return sk_scores -def _sk_prec_recall_mdmc( - preds, target, sk_fn, num_classes, average, is_multiclass, zero_division, ignore_index, mdmc_average -): +def _sk_prec_recall_mdmc(preds, target, sk_fn, num_classes, average, is_multiclass, ignore_index, mdmc_average): preds, target, _ = _input_format_classification( preds, target, threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass ) @@ -61,61 +57,49 @@ def _sk_prec_recall_mdmc( preds = torch.movedim(preds, 1, -1).reshape(-1, preds.shape[1]) target = torch.movedim(target, 1, -1).reshape(-1, target.shape[1]) - return _sk_prec_recall(preds, target, sk_fn, num_classes, average, False, zero_division, ignore_index) - else: # mdmc_average == "samplewise" + return _sk_prec_recall(preds, target, sk_fn, num_classes, average, False, ignore_index) + elif mdmc_average == "samplewise": scores = [] for i in range(preds.shape[0]): pred_i = preds[i, ...].T target_i = target[i, ...].T - scores_i = _sk_prec_recall( - pred_i, target_i, sk_fn, num_classes, average, False, zero_division, ignore_index - ) + scores_i = _sk_prec_recall(pred_i, target_i, sk_fn, num_classes, average, False, ignore_index) scores.append(np.expand_dims(scores_i, 0)) - return np.concatenate(scores).mean() + return np.concatenate(scores).mean(axis=0) @pytest.mark.parametrize("metric, fn_metric", [(Precision, precision), (Recall, recall)]) -def test_wrong_params(metric, fn_metric): +@pytest.mark.parametrize( + "zero_division, average, mdmc_average", + [ + (0.5, "micro", None), + (None, "micro", None), + (0, "wrong", None), + (0, "micro", "wrong"), + ], +) +def test_wrong_params(metric, fn_metric, zero_division, average, mdmc_average): with pytest.raises(ValueError): - metric(zero_division=None) + metric(zero_division=zero_division, average=average, mdmc_average=mdmc_average) with pytest.raises(ValueError): - fn_metric(_binary_inputs.preds[0], _binary_inputs.target[0], zero_division=None) - - -###################################################################################### -# Testing for MDMC inputs is partially skipped, because some cases appear where -# (with mdmc_average1 =! None, ignore_index=1, average='weighted') a sample in -# target contains only labels "1" - and as we are ignoring this index, weights of -# all labels will be zero. In this special edge case, sklearn handles the situation -# differently for each metric (recall, precision, fscore), which breaks ours handling -# everything in _reduce_scores (where the return value is 0 in this situation). -###################################################################################### + fn_metric( + _binary_inputs.preds[0], + _binary_inputs.target[0], + zero_division=zero_division, + average=average, + mdmc_average=mdmc_average, + ) @pytest.mark.parametrize( - "preds, target, sk_metric, num_classes, multilabel", - [ - (_binary_prob_inputs.preds, _binary_prob_inputs.target, _sk_prec_recall_binary_prob, 1, False), - (_binary_inputs.preds, _binary_inputs.target, _sk_prec_recall_binary, 1, False), - (_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target, - _sk_prec_recall_multilabel_prob, NUM_CLASSES, True), - (_multilabel_inputs.preds, _multilabel_inputs.target, _sk_prec_recall_multilabel, NUM_CLASSES, True), - (_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target, - _sk_prec_recall_multiclass_prob, NUM_CLASSES, False), - (_multiclass_inputs.preds, _multiclass_inputs.target, _sk_prec_recall_multiclass, NUM_CLASSES, False), - (_multidim_multiclass_prob_inputs.preds, _multidim_multiclass_prob_inputs.target, - _sk_prec_recall_multidim_multiclass_prob, NUM_CLASSES, False), - (_multidim_multiclass_inputs.preds, _multidim_multiclass_inputs.target, - _sk_prec_recall_multidim_multiclass, NUM_CLASSES, False), - ], + "metric_class, metric_fn, sk_fn", [(Recall, recall, recall_score), (Precision, precision, precision_score)] ) @pytest.mark.parametrize("average", ["micro", "macro", None, "weighted", "samples"]) -@pytest.mark.parametrize("zero_division", [0, 1]) -@pytest.mark.parametrize("ignore_index", [None, 1]) +@pytest.mark.parametrize("ignore_index", [None, 0]) @pytest.mark.parametrize( "preds, target, num_classes, is_multiclass, mdmc_average, sk_wrapper", [ @@ -148,12 +132,14 @@ def test_precision_recall_class( num_classes, average, mdmc_average, - zero_division, ignore_index, ): if num_classes == 1 and average != "micro": pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)") + if ignore_index is not None and preds.ndim == 2: + pytest.skip("Skipping ignore_index test with binary inputs.") + if average == "weighted" and ignore_index is not None and mdmc_average is not None: pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average") @@ -168,7 +154,6 @@ def test_precision_recall_class( average=average, num_classes=num_classes, is_multiclass=is_multiclass, - zero_division=zero_division, ignore_index=ignore_index, mdmc_average=mdmc_average, ), @@ -178,7 +163,6 @@ def test_precision_recall_class( "average": average, "threshold": THRESHOLD, "is_multiclass": is_multiclass, - "zero_division": zero_division, "ignore_index": ignore_index, "mdmc_average": mdmc_average, }, @@ -198,12 +182,14 @@ def test_precision_recall_fn( num_classes, average, mdmc_average, - zero_division, ignore_index, ): if num_classes == 1 and average != "micro": pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)") + if ignore_index is not None and preds.ndim == 2: + pytest.skip("Skipping ignore_index test with binary inputs.") + if average == "weighted" and ignore_index is not None and mdmc_average is not None: pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average") @@ -217,7 +203,6 @@ def test_precision_recall_fn( average=average, num_classes=num_classes, is_multiclass=is_multiclass, - zero_division=zero_division, ignore_index=ignore_index, mdmc_average=mdmc_average, ), @@ -226,7 +211,6 @@ def test_precision_recall_fn( "average": average, "threshold": THRESHOLD, "is_multiclass": is_multiclass, - "zero_division": zero_division, "ignore_index": ignore_index, "mdmc_average": mdmc_average, }, diff --git a/tests/metrics/classification/test_stat_scores.py b/tests/metrics/classification/test_stat_scores.py index 0838f4aa54b0b..86444461d41a0 100644 --- a/tests/metrics/classification/test_stat_scores.py +++ b/tests/metrics/classification/test_stat_scores.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Callable, Optional +from typing import Callable, Optional, List import numpy as np import pytest @@ -30,7 +30,7 @@ def _sk_stat_scores(preds, target, reduce, num_classes, is_multiclass, ignore_in ) sk_preds, sk_target = preds.numpy(), target.numpy() - if reduce != "macro" and ignore_index and preds.shape[1] > 1: + if reduce != "macro" and ignore_index is not None and preds.shape[1] > 1: sk_preds = np.delete(sk_preds, ignore_index, 1) sk_target = np.delete(sk_target, ignore_index, 1) @@ -55,7 +55,7 @@ def _sk_stat_scores(preds, target, reduce, num_classes, is_multiclass, ignore_in if reduce == "micro": sk_stats = sk_stats[0] - if reduce == "macro" and ignore_index and preds.shape[1]: + if reduce == "macro" and ignore_index is not None and preds.shape[1]: sk_stats[ignore_index, :] = -1 return sk_stats @@ -160,7 +160,7 @@ def test_stat_scores_class( ignore_index: Optional[int], top_k: Optional[int], ): - if ignore_index and preds.ndim == 2: + if ignore_index is not None and preds.ndim == 2: pytest.skip("Skipping ignore_index test with binary inputs.") self.run_class_metric_test( @@ -203,7 +203,7 @@ def test_stat_scores_fn( ignore_index: Optional[int], top_k: Optional[int], ): - if ignore_index and preds.ndim == 2: + if ignore_index is not None and preds.ndim == 2: pytest.skip("Skipping ignore_index test with binary inputs.") self.run_functional_metric_test( @@ -229,3 +229,32 @@ def test_stat_scores_fn( "top_k": top_k, }, ) + + +_mc_k_target = torch.tensor([0, 1, 2]) +_mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]]) +_ml_k_target = torch.tensor([[0, 1, 0], [1, 1, 0], [0, 0, 0]]) +_ml_k_preds = torch.tensor([[0.9, 0.2, 0.75], [0.1, 0.7, 0.8], [0.6, 0.1, 0.7]]) + + +@pytest.mark.parametrize( + "k, preds, target, reduce, expected", + [ + (1, _mc_k_preds, _mc_k_target, "micro", torch.tensor([2, 1, 5, 1, 3])), + (2, _mc_k_preds, _mc_k_target, "micro", torch.tensor([3, 3, 3, 0, 3])), + (1, _ml_k_preds, _ml_k_target, "micro", torch.tensor([0, 3, 3, 3, 3])), + (2, _ml_k_preds, _ml_k_target, "micro", torch.tensor([1, 5, 1, 2, 3])), + (1, _mc_k_preds, _mc_k_target, "macro", torch.tensor([[0, 1, 1], [0, 1, 0], [2, 1, 2], [1, 0, 0], [1, 1, 1]])), + (2, _mc_k_preds, _mc_k_target, "macro", torch.tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1], [0, 0, 0], [1, 1, 1]])), + (1, _ml_k_preds, _ml_k_target, "macro", torch.tensor([[0, 0, 0], [1, 0, 2], [1, 1, 1], [1, 2, 0], [1, 2, 0]])), + (2, _ml_k_preds, _ml_k_target, "macro", torch.tensor([[0, 1, 0], [2, 0, 3], [0, 1, 0], [1, 1, 0], [1, 2, 0]])), + ], +) +def test_top_k(k: int, preds: torch.Tensor, target: torch.Tensor, reduce: str, expected: torch.Tensor): + """ A simple test to check that top_k works as expected """ + + class_metric = StatScores(top_k=k, reduce=reduce, num_classes=3) + class_metric.update(preds, target) + + assert torch.equal(class_metric.compute(), expected.T) + assert torch.equal(stat_scores(preds, target, top_k=k, reduce=reduce, num_classes=3), expected.T) From e96a4423fb8db4ff291d8cc1f2c271b499e9d76c Mon Sep 17 00:00:00 2001 From: Tadej Date: Fri, 8 Jan 2021 20:18:06 +0100 Subject: [PATCH 34/61] PEP3 --- .../metrics/classification/precision_recall.py | 4 ++-- pytorch_lightning/metrics/functional/precision_recall.py | 6 +++--- tests/metrics/classification/test_stat_scores.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/metrics/classification/precision_recall.py b/pytorch_lightning/metrics/classification/precision_recall.py index b794ad51d5ac5..8ecbefe58d38c 100644 --- a/pytorch_lightning/metrics/classification/precision_recall.py +++ b/pytorch_lightning/metrics/classification/precision_recall.py @@ -140,7 +140,7 @@ def __init__( raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") if zero_division not in [0, 1]: - raise ValueError(f"The `zero_division` has to be either 0 or 1.") + raise ValueError("The `zero_division` has to be either 0 or 1.") super().__init__( reduce="macro" if average in ["weighted", "none", None] else average, @@ -304,7 +304,7 @@ def __init__( raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") if zero_division not in [0, 1]: - raise ValueError(f"The `zero_division` has to be either 0 or 1.") + raise ValueError("The `zero_division` has to be either 0 or 1.") super().__init__( reduce="macro" if average in ["weighted", "none", None] else average, diff --git a/pytorch_lightning/metrics/functional/precision_recall.py b/pytorch_lightning/metrics/functional/precision_recall.py index 86ccc109d3c02..0a76624c6451a 100644 --- a/pytorch_lightning/metrics/functional/precision_recall.py +++ b/pytorch_lightning/metrics/functional/precision_recall.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Union +from typing import Optional import torch from pytorch_lightning.metrics.classification.helpers import _reduce_stat_scores @@ -149,7 +149,7 @@ def precision( raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") if zero_division not in [0,1]: - raise ValueError(f"The `zero_division` has to be either 0 or 1.") + raise ValueError("The `zero_division` has to be either 0 or 1.") allowed_mdmc_average = [None, "samplewise", "global"] if mdmc_average not in allowed_mdmc_average: @@ -312,7 +312,7 @@ def recall( allowed_mdmc_average = [None, "samplewise", "global"] if mdmc_average not in allowed_mdmc_average: - raise ValueError(f"The `mdmc_average` has to be one of {allowed_mdmc_average}, got {mdmc_average}.") + raise ValueError("The `mdmc_average` has to be one of {allowed_mdmc_average}, got {mdmc_average}.") if average in ["macro", "weighted", "none", None] and (not num_classes or num_classes < 1): raise ValueError(f"When you set `average` as {average}, you have to provide the number of classes.") diff --git a/tests/metrics/classification/test_stat_scores.py b/tests/metrics/classification/test_stat_scores.py index 86444461d41a0..59a95ae7b609f 100644 --- a/tests/metrics/classification/test_stat_scores.py +++ b/tests/metrics/classification/test_stat_scores.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Callable, Optional, List +from typing import Callable, Optional import numpy as np import pytest From b6de27a5cff9a1caf30082bbe99bf68474a9e7d2 Mon Sep 17 00:00:00 2001 From: Tadej Date: Fri, 8 Jan 2021 20:32:06 +0100 Subject: [PATCH 35/61] Add depreciation --- docs/source/metrics.rst | 7 +++++ .../metrics/functional/classification.py | 31 +++++++++++++++++++ tests/deprecated_api/test_remove_1-4.py | 15 +++++++++ 3 files changed, 53 insertions(+) diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 7c126c4232347..aaac03d6b91f1 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -601,6 +601,13 @@ precision [func] :noindex: +precision_recall [func] +~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: pytorch_lightning.metrics.functional.classification.precision_recall + :noindex: + + precision_recall_curve [func] ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py index 788b6ad3f3fab..6421986cf5e2d 100644 --- a/pytorch_lightning/metrics/functional/classification.py +++ b/pytorch_lightning/metrics/functional/classification.py @@ -226,6 +226,11 @@ def precision_recall( """ Computes precision and recall for different thresholds + .. warning :: Deprecated in favor of using + :func:`~pytorch_lightning.metrics.functional.recall` and + :func:`~pytorch_lightning.metrics.functional.precision` separately. + Will be removed in v1.4.0. + Args: pred: estimated probabilities target: ground-truth labels @@ -252,6 +257,12 @@ def precision_recall( (tensor(0.5000), tensor(0.3333)) """ + rank_zero_warn( + "This `precision_recall` was deprecated in v1.2.0 in favor of" + " `using `precision` and `recall` separately." + " It will be removed in v1.4.0", DeprecationWarning + ) + tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred=pred, target=target, num_classes=num_classes) precision = class_reduce(tps, tps + fps, sups, class_reduction=class_reduction) @@ -272,6 +283,10 @@ def precision( """ Computes precision score. + .. warning :: Deprecated in favor of + :func:`~pytorch_lightning.metrics.functional.recall`. Will be removed in + v1.4.0. + Args: pred: estimated probabilities target: ground-truth labels @@ -294,6 +309,12 @@ def precision( tensor(0.7500) """ + rank_zero_warn( + "This `precision` was deprecated in v1.2.0 in favor of" + " `from pytorch_lightning.metrics.functional import precision`." + " It will be removed in v1.4.0", DeprecationWarning + ) + return precision_recall(pred=pred, target=target, num_classes=num_classes, class_reduction=class_reduction)[0] @@ -307,6 +328,10 @@ def recall( """ Computes recall score. + .. warning :: Deprecated in favor of + :func:`~pytorch_lightning.metrics.functional.recall`. Will be removed in + v1.4.0. + Args: pred: estimated probabilities target: ground-truth labels @@ -328,6 +353,12 @@ def recall( >>> recall(x, y) tensor(0.7500) """ + rank_zero_warn( + "This `recall` was deprecated in v1.2.0 in favor of" + " `from pytorch_lightning.metrics.functional import recall`." + " It will be removed in v1.4.0", DeprecationWarning + ) + return precision_recall(pred=pred, target=target, num_classes=num_classes, class_reduction=class_reduction)[1] diff --git a/tests/deprecated_api/test_remove_1-4.py b/tests/deprecated_api/test_remove_1-4.py index db514cd5dde46..e2b966860d612 100644 --- a/tests/deprecated_api/test_remove_1-4.py +++ b/tests/deprecated_api/test_remove_1-4.py @@ -85,3 +85,18 @@ def test_v1_4_0_deprecated_metrics(): with pytest.deprecated_call(match='will be removed in v1.4'): iou(torch.randint(0, 2, (10, 3, 3)), torch.randint(0, 2, (10, 3, 3))) + + from pytorch_lightning.metrics.functional.classification import recall + with pytest.deprecated_call(match='will be removed in v1.4'): + recall(torch.randint(0, 2, (10, 3, 3)), + torch.randint(0, 2, (10, 3, 3))) + + from pytorch_lightning.metrics.functional.classification import precision + with pytest.deprecated_call(match='will be removed in v1.4'): + precision(torch.randint(0, 2, (10, 3, 3)), + torch.randint(0, 2, (10, 3, 3))) + + from pytorch_lightning.metrics.functional.classification import precision_recall + with pytest.deprecated_call(match='will be removed in v1.4'): + precision_recall(torch.randint(0, 2, (10, 3, 3)), + torch.randint(0, 2, (10, 3, 3))) From 24adfe84719b16ae1e69a4f254875a55b804ad0e Mon Sep 17 00:00:00 2001 From: Tadej Date: Fri, 8 Jan 2021 20:38:09 +0100 Subject: [PATCH 36/61] PEP8 --- pytorch_lightning/metrics/functional/classification.py | 10 +++++----- .../metrics/functional/precision_recall.py | 6 +++--- tests/deprecated_api/test_remove_1-4.py | 6 +++--- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py index 6421986cf5e2d..f42af42818e69 100644 --- a/pytorch_lightning/metrics/functional/classification.py +++ b/pytorch_lightning/metrics/functional/classification.py @@ -226,9 +226,9 @@ def precision_recall( """ Computes precision and recall for different thresholds - .. warning :: Deprecated in favor of using - :func:`~pytorch_lightning.metrics.functional.recall` and - :func:`~pytorch_lightning.metrics.functional.precision` separately. + .. warning :: Deprecated in favor of using + :func:`~pytorch_lightning.metrics.functional.recall` and + :func:`~pytorch_lightning.metrics.functional.precision` separately. Will be removed in v1.4.0. Args: @@ -314,7 +314,7 @@ def precision( " `from pytorch_lightning.metrics.functional import precision`." " It will be removed in v1.4.0", DeprecationWarning ) - + return precision_recall(pred=pred, target=target, num_classes=num_classes, class_reduction=class_reduction)[0] @@ -358,7 +358,7 @@ def recall( " `from pytorch_lightning.metrics.functional import recall`." " It will be removed in v1.4.0", DeprecationWarning ) - + return precision_recall(pred=pred, target=target, num_classes=num_classes, class_reduction=class_reduction)[1] diff --git a/pytorch_lightning/metrics/functional/precision_recall.py b/pytorch_lightning/metrics/functional/precision_recall.py index 0a76624c6451a..246706ff145fb 100644 --- a/pytorch_lightning/metrics/functional/precision_recall.py +++ b/pytorch_lightning/metrics/functional/precision_recall.py @@ -148,7 +148,7 @@ def precision( if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") - if zero_division not in [0,1]: + if zero_division not in [0, 1]: raise ValueError("The `zero_division` has to be either 0 or 1.") allowed_mdmc_average = [None, "samplewise", "global"] @@ -307,8 +307,8 @@ def recall( if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") - if zero_division not in [0,1]: - raise ValueError(f"The `zero_division` has to be either 0 or 1.") + if zero_division not in [0, 1]: + raise ValueError("The `zero_division` has to be either 0 or 1.") allowed_mdmc_average = [None, "samplewise", "global"] if mdmc_average not in allowed_mdmc_average: diff --git a/tests/deprecated_api/test_remove_1-4.py b/tests/deprecated_api/test_remove_1-4.py index e2b966860d612..06da28f962e21 100644 --- a/tests/deprecated_api/test_remove_1-4.py +++ b/tests/deprecated_api/test_remove_1-4.py @@ -89,14 +89,14 @@ def test_v1_4_0_deprecated_metrics(): from pytorch_lightning.metrics.functional.classification import recall with pytest.deprecated_call(match='will be removed in v1.4'): recall(torch.randint(0, 2, (10, 3, 3)), - torch.randint(0, 2, (10, 3, 3))) + torch.randint(0, 2, (10, 3, 3))) from pytorch_lightning.metrics.functional.classification import precision with pytest.deprecated_call(match='will be removed in v1.4'): precision(torch.randint(0, 2, (10, 3, 3)), - torch.randint(0, 2, (10, 3, 3))) + torch.randint(0, 2, (10, 3, 3))) from pytorch_lightning.metrics.functional.classification import precision_recall with pytest.deprecated_call(match='will be removed in v1.4'): precision_recall(torch.randint(0, 2, (10, 3, 3)), - torch.randint(0, 2, (10, 3, 3))) + torch.randint(0, 2, (10, 3, 3))) From 660d4b166498ae4cae61112c7deefde271b70cac Mon Sep 17 00:00:00 2001 From: Tadej Date: Fri, 8 Jan 2021 20:50:41 +0100 Subject: [PATCH 37/61] Deprecate param --- docs/source/metrics.rst | 6 ++--- .../metrics/functional/precision_recall.py | 26 +++++++++++++++++++ tests/deprecated_api/test_remove_1-4.py | 12 +++++++++ 3 files changed, 41 insertions(+), 3 deletions(-) diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index aaac03d6b91f1..c5fdecc1d0992 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -601,10 +601,10 @@ precision [func] :noindex: -precision_recall [func] -~~~~~~~~~~~~~~~~~~~~~~~ +precision_recall [func] +~~~~~~~~~~~~~~~~~~~~~~~ -.. autofunction:: pytorch_lightning.metrics.functional.classification.precision_recall +.. autofunction:: pytorch_lightning.metrics.functional.classification.precision_recall :noindex: diff --git a/pytorch_lightning/metrics/functional/precision_recall.py b/pytorch_lightning/metrics/functional/precision_recall.py index 246706ff145fb..2aaafb5643530 100644 --- a/pytorch_lightning/metrics/functional/precision_recall.py +++ b/pytorch_lightning/metrics/functional/precision_recall.py @@ -16,6 +16,7 @@ import torch from pytorch_lightning.metrics.classification.helpers import _reduce_stat_scores from pytorch_lightning.metrics.functional.stat_scores import _stat_scores_update +from pytorch_lightning.utilities import rank_zero_warn def _precision_compute( @@ -48,6 +49,7 @@ def precision( threshold: float = 0.5, top_k: Optional[int] = None, is_multiclass: Optional[bool] = None, + class_reduction: Optional[str] = None, ) -> torch.Tensor: r""" Computes `Precision `_: @@ -81,6 +83,9 @@ def precision( Note that what is considered a sample in the multi-dimensional multi-class case depends on the value of ``mdmc_average``. + class_reduction: + .. warning :: This parameter is deprecated, use ``average``. Will be removed in v1.4.0. + mdmc_average: Defines how averaging is done for multi-dimensional multi-class inputs (on top of the ``average`` parameter). Should be one of the following: @@ -126,6 +131,9 @@ def precision( :ref:`documentation section ` for a more detailed explanation and examples. + class_reduction: + .. warning :: This parameter is deprecated, use ``average``. Will be removed in v1.4.0. + Return: The shape of the returned tensor depends on the ``average`` parameter @@ -144,6 +152,13 @@ def precision( tensor(0.2500) """ + if class_reduction: + rank_zero_warn( + "This `class_reduction` parameter was deprecated in v1.2.0 in favor of" + " `reduce`. It will be removed in v1.4.0", DeprecationWarning + ) + average = class_reduction + allowed_average = ["micro", "macro", "weighted", "samples", "none", None] if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") @@ -207,6 +222,7 @@ def recall( threshold: float = 0.5, top_k: Optional[int] = None, is_multiclass: Optional[bool] = None, + class_reduction: Optional[str] = None, ) -> torch.Tensor: r""" Computes `Recall `_: @@ -285,6 +301,9 @@ def recall( :ref:`documentation section ` for a more detailed explanation and examples. + class_reduction: + .. warning :: This parameter is deprecated, use ``average``. Will be removed in v1.4.0. + Return: The shape of the returned tensor depends on the ``average`` parameter @@ -303,6 +322,13 @@ def recall( tensor(0.2500) """ + if class_reduction: + rank_zero_warn( + "This `class_reduction` parameter was deprecated in v1.2.0 in favor of" + " `reduce`. It will be removed in v1.4.0", DeprecationWarning + ) + average = class_reduction + allowed_average = ["micro", "macro", "weighted", "samples", "none", None] if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") diff --git a/tests/deprecated_api/test_remove_1-4.py b/tests/deprecated_api/test_remove_1-4.py index 06da28f962e21..7b00a185f8e60 100644 --- a/tests/deprecated_api/test_remove_1-4.py +++ b/tests/deprecated_api/test_remove_1-4.py @@ -100,3 +100,15 @@ def test_v1_4_0_deprecated_metrics(): with pytest.deprecated_call(match='will be removed in v1.4'): precision_recall(torch.randint(0, 2, (10, 3, 3)), torch.randint(0, 2, (10, 3, 3))) + + from pytorch_lightning.metrics.functional import precision + with pytest.deprecated_call(match='will be removed in v1.4'): + precision_recall(torch.randint(0, 2, (10, 3, 3)), + torch.randint(0, 2, (10, 3, 3)), + class_reduction='micro') + + from pytorch_lightning.metrics.functional import recall + with pytest.deprecated_call(match='will be removed in v1.4'): + precision_recall(torch.randint(0, 2, (10, 3)), + torch.randint(0, 2, (10, 3)), + class_reduction='micro') From 6b018d9d1c76b8f872b2db487a16ebaee1f70918 Mon Sep 17 00:00:00 2001 From: Tadej Date: Fri, 8 Jan 2021 20:50:53 +0100 Subject: [PATCH 38/61] PEP8 --- pytorch_lightning/metrics/functional/precision_recall.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/metrics/functional/precision_recall.py b/pytorch_lightning/metrics/functional/precision_recall.py index 2aaafb5643530..daeb2a2ec97d5 100644 --- a/pytorch_lightning/metrics/functional/precision_recall.py +++ b/pytorch_lightning/metrics/functional/precision_recall.py @@ -155,7 +155,8 @@ def precision( if class_reduction: rank_zero_warn( "This `class_reduction` parameter was deprecated in v1.2.0 in favor of" - " `reduce`. It will be removed in v1.4.0", DeprecationWarning + " `reduce`. It will be removed in v1.4.0", + DeprecationWarning, ) average = class_reduction @@ -325,7 +326,8 @@ def recall( if class_reduction: rank_zero_warn( "This `class_reduction` parameter was deprecated in v1.2.0 in favor of" - " `reduce`. It will be removed in v1.4.0", DeprecationWarning + " `reduce`. It will be removed in v1.4.0", + DeprecationWarning, ) average = class_reduction From 9fdfcf6b5df7e91fc7c0c01f3f012afaec9a45d6 Mon Sep 17 00:00:00 2001 From: Tadej Date: Sat, 9 Jan 2021 13:27:50 +0100 Subject: [PATCH 39/61] Fix and simplify testing for older PT versions --- .../classification/test_precision_recall.py | 4 +- .../classification/test_stat_scores.py | 8 +- .../metrics/functional/test_classification.py | 76 ------------------- 3 files changed, 4 insertions(+), 84 deletions(-) diff --git a/tests/metrics/classification/test_precision_recall.py b/tests/metrics/classification/test_precision_recall.py index ce0588f2a347e..01bab06596615 100644 --- a/tests/metrics/classification/test_precision_recall.py +++ b/tests/metrics/classification/test_precision_recall.py @@ -54,8 +54,8 @@ def _sk_prec_recall_mdmc(preds, target, sk_fn, num_classes, average, is_multicla ) if mdmc_average == "global": - preds = torch.movedim(preds, 1, -1).reshape(-1, preds.shape[1]) - target = torch.movedim(target, 1, -1).reshape(-1, target.shape[1]) + preds = torch.transpose(preds, 1, 2).reshape(-1, preds.shape[1]) + target = torch.transpose(target, 1, 2).reshape(-1, target.shape[1]) return _sk_prec_recall(preds, target, sk_fn, num_classes, average, False, ignore_index) elif mdmc_average == "samplewise": diff --git a/tests/metrics/classification/test_stat_scores.py b/tests/metrics/classification/test_stat_scores.py index 59a95ae7b609f..857dd445d1dfe 100644 --- a/tests/metrics/classification/test_stat_scores.py +++ b/tests/metrics/classification/test_stat_scores.py @@ -67,12 +67,8 @@ def _sk_stat_scores_mdmc(preds, target, reduce, mdmc_reduce, num_classes, is_mul ) if mdmc_reduce == "global": - shape_permute = list(range(preds.ndim)) - shape_permute[1] = shape_permute[-1] - shape_permute[2:] = range(1, len(shape_permute) - 1) - - preds = preds.permute(*shape_permute).reshape(-1, preds.shape[1]) - target = target.permute(*shape_permute).reshape(-1, target.shape[1]) + preds = torch.transpose(preds, 1, 2).reshape(-1, preds.shape[1]) + target = torch.transpose(target, 1, 2).reshape(-1, target.shape[1]) return _sk_stat_scores(preds, target, reduce, None, False, ignore_index, top_k) elif mdmc_reduce == "samplewise": diff --git a/tests/metrics/functional/test_classification.py b/tests/metrics/functional/test_classification.py index caf520105545d..1d5460ce1493e 100644 --- a/tests/metrics/functional/test_classification.py +++ b/tests/metrics/functional/test_classification.py @@ -4,17 +4,11 @@ import torch from distutils.version import LooseVersion from sklearn.metrics import ( - precision_score as sk_precision, - recall_score as sk_recall, roc_auc_score as sk_roc_auc_score, ) from pytorch_lightning import seed_everything from pytorch_lightning.metrics.functional.classification import ( - stat_scores, - stat_scores_multiple_classes, - precision, - recall, dice_score, auroc, multiclass_auroc, @@ -25,8 +19,6 @@ @pytest.mark.parametrize(['sklearn_metric', 'torch_metric', 'only_binary'], [ - pytest.param(partial(sk_precision, average='micro'), precision, False, id='precision'), - pytest.param(partial(sk_recall, average='micro'), recall, False, id='recall'), pytest.param(sk_roc_auc_score, auroc, True, id='auroc') ]) def test_against_sklearn(sklearn_metric, torch_metric, only_binary): @@ -56,25 +48,6 @@ def test_against_sklearn(sklearn_metric, torch_metric, only_binary): assert torch.allclose(sk_score, pl_score) -@pytest.mark.parametrize('class_reduction', ['micro', 'macro', 'weighted']) -@pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [ - pytest.param(sk_precision, precision, id='precision'), - pytest.param(sk_recall, recall, id='recall'), -]) -def test_different_reduction_against_sklearn(class_reduction, sklearn_metric, torch_metric): - """ Test metrics where the class_reduction parameter have a correponding - value in sklearn """ - device = 'cuda' if torch.cuda.is_available() else 'cpu' - pred = torch.randint(10, (300,), device=device) - target = torch.randint(10, (300,), device=device) - sk_score = sklearn_metric(target.cpu().detach().numpy(), - pred.cpu().detach().numpy(), - average=class_reduction) - sk_score = torch.tensor(sk_score, dtype=torch.float, device=device) - pl_score = torch_metric(pred, target, class_reduction=class_reduction) - assert torch.allclose(sk_score, pl_score) - - def test_onehot(): test_tensor = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) expected = torch.stack([ @@ -122,55 +95,6 @@ def test_get_num_classes(pred, target, num_classes, expected_num_classes): assert get_num_classes(pred, target, num_classes) == expected_num_classes -@pytest.mark.parametrize(['pred', 'target', 'expected_tp', 'expected_fp', - 'expected_tn', 'expected_fn', 'expected_support'], [ - pytest.param(torch.tensor([0., 2., 4., 4.]), torch.tensor([0., 4., 3., 4.]), 1, 1, 1, 1, 2), - pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 1, 1, 1, 1, 2) -]) -def test_stat_scores(pred, target, expected_tp, expected_fp, expected_tn, expected_fn, expected_support): - tp, fp, tn, fn, sup = stat_scores(pred, target, class_index=4) - - assert tp.item() == expected_tp - assert fp.item() == expected_fp - assert tn.item() == expected_tn - assert fn.item() == expected_fn - assert sup.item() == expected_support - - -@pytest.mark.parametrize(['pred', 'target', 'reduction', 'expected_tp', 'expected_fp', - 'expected_tn', 'expected_fn', 'expected_support'], [ - pytest.param(torch.tensor([0., 2., 4., 4.]), torch.tensor([0., 4., 3., 4.]), 'none', - [1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1], [1, 0, 0, 1, 2]), - pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 'none', - [1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1], [1, 0, 0, 1, 2]), - pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 'sum', - torch.tensor(2), torch.tensor(2), torch.tensor(14), torch.tensor(2), torch.tensor(4)), - pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 'elementwise_mean', - torch.tensor(0.4), torch.tensor(0.4), torch.tensor(2.8), torch.tensor(0.4), torch.tensor(0.8)) -]) -def test_stat_scores_multiclass(pred, target, reduction, - expected_tp, expected_fp, expected_tn, expected_fn, expected_support): - tp, fp, tn, fn, sup = stat_scores_multiple_classes(pred, target, reduction=reduction) - - assert torch.allclose(torch.tensor(expected_tp).to(tp), tp) - assert torch.allclose(torch.tensor(expected_fp).to(fp), fp) - assert torch.allclose(torch.tensor(expected_tn).to(tn), tn) - assert torch.allclose(torch.tensor(expected_fn).to(fn), fn) - assert torch.allclose(torch.tensor(expected_support).to(sup), sup) - - -@pytest.mark.parametrize(['pred', 'target', 'expected_prec', 'expected_rec'], [ - pytest.param(torch.tensor([1., 0., 1., 0.]), torch.tensor([0., 1., 1., 0.]), [0.5, 0.5], [0.5, 0.5]), - pytest.param(to_onehot(torch.tensor([1., 0., 1., 0.])), torch.tensor([0., 1., 1., 0.]), [0.5, 0.5], [0.5, 0.5]) -]) -def test_precision_recall(pred, target, expected_prec, expected_rec): - prec = precision(pred, target, class_reduction='none') - rec = recall(pred, target, class_reduction='none') - - assert torch.allclose(torch.tensor(expected_prec).to(prec), prec) - assert torch.allclose(torch.tensor(expected_rec).to(rec), rec) - - @pytest.mark.parametrize(['sample_weight', 'pos_label', "exp_shape"], [ pytest.param(1, 1., 42), pytest.param(None, 1., 42), From 88fd8cc609bf9a149e8cc28f824a069205b6a546 Mon Sep 17 00:00:00 2001 From: Tadej Date: Sat, 9 Jan 2021 13:34:19 +0100 Subject: [PATCH 40/61] Update Changelog --- CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9c6a2c0ab289e..9bb0bdb2f9f47 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -44,6 +44,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `IoU` class interface ([#4704](https://github.com/PyTorchLightning/pytorch-lightning/pull/4704)) +- `Recall` and `Precision` metrics (and their functional counterparts `recall` and `precision`) can now be generalized to Recall@K and Precision@K with the use of `top_k` parameter ([#4842](https://github.com/PyTorchLightning/pytorch-lightning/pull/4842)) + + ### Changed - Changed `stat_scores` metric now calculates stat scores over all classes and gains new parameters, in line with the new `StatScores` metric ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839)) @@ -63,6 +66,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - `stat_scores_multiple_classes` is deprecated in favor of `stat_scores` ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839)) +- `precision_recall` metris is deprecated in favor of using `precision` and `recall` separately ([#4842](https://github.com/PyTorchLightning/pytorch-lightning/pull/4842)) + + ### Removed - Removed deprecated checkpoint argument `filepath` ([#5321](https://github.com/PyTorchLightning/pytorch-lightning/pull/5321)) From 6a7b86fe3ed665ff768c9972f4462ef8a5a6157d Mon Sep 17 00:00:00 2001 From: Tadej Date: Sat, 9 Jan 2021 13:36:37 +0100 Subject: [PATCH 41/61] Remove redundant import --- tests/metrics/functional/test_classification.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/metrics/functional/test_classification.py b/tests/metrics/functional/test_classification.py index 1d5460ce1493e..921e5ec8f6978 100644 --- a/tests/metrics/functional/test_classification.py +++ b/tests/metrics/functional/test_classification.py @@ -1,5 +1,3 @@ -from functools import partial - import pytest import torch from distutils.version import LooseVersion From df6365a4cab7d9116e04f68f331bc3b15e9d18a5 Mon Sep 17 00:00:00 2001 From: Tadej Date: Sun, 10 Jan 2021 11:23:28 +0100 Subject: [PATCH 42/61] Add tests to increase coverage --- .../classification/test_precision_recall.py | 24 +++++++++++++------ 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/tests/metrics/classification/test_precision_recall.py b/tests/metrics/classification/test_precision_recall.py index 01bab06596615..25106d4a7d70d 100644 --- a/tests/metrics/classification/test_precision_recall.py +++ b/tests/metrics/classification/test_precision_recall.py @@ -73,17 +73,25 @@ def _sk_prec_recall_mdmc(preds, target, sk_fn, num_classes, average, is_multicla @pytest.mark.parametrize("metric, fn_metric", [(Precision, precision), (Recall, recall)]) @pytest.mark.parametrize( - "zero_division, average, mdmc_average", + "zero_division, average, mdmc_average, num_classes, ignore_index", [ - (0.5, "micro", None), - (None, "micro", None), - (0, "wrong", None), - (0, "micro", "wrong"), + (0.5, "micro", None, None, None), + (None, "micro", None, None, None), + (0, "wrong", None, None, None), + (0, "micro", "wrong", None, None), + (0, "macro", None, None, None), + (0, "macro", None, 1, 0), ], ) -def test_wrong_params(metric, fn_metric, zero_division, average, mdmc_average): +def test_wrong_params(metric, fn_metric, zero_division, average, mdmc_average, num_classes, ignore_index): with pytest.raises(ValueError): - metric(zero_division=zero_division, average=average, mdmc_average=mdmc_average) + metric( + zero_division=zero_division, + average=average, + mdmc_average=mdmc_average, + num_classes=num_classes, + ignore_index=ignore_index, + ) with pytest.raises(ValueError): fn_metric( @@ -92,6 +100,8 @@ def test_wrong_params(metric, fn_metric, zero_division, average, mdmc_average): zero_division=zero_division, average=average, mdmc_average=mdmc_average, + num_classes=num_classes, + ignore_index=ignore_index, ) From 5e0dfbd520a6bd2d10cd4139d2769ed7ecfb6693 Mon Sep 17 00:00:00 2001 From: Tadej Date: Wed, 13 Jan 2021 13:55:17 +0100 Subject: [PATCH 43/61] Remove zero_division --- .../metrics/classification/helpers.py | 2 +- .../classification/precision_recall.py | 20 +--------------- .../metrics/functional/precision_recall.py | 24 ++----------------- .../classification/test_precision_recall.py | 16 +++++-------- 4 files changed, 10 insertions(+), 52 deletions(-) diff --git a/pytorch_lightning/metrics/classification/helpers.py b/pytorch_lightning/metrics/classification/helpers.py index 721254cdd6f8a..3c9a83523e7e7 100644 --- a/pytorch_lightning/metrics/classification/helpers.py +++ b/pytorch_lightning/metrics/classification/helpers.py @@ -446,7 +446,7 @@ def _reduce_stat_scores( weights: Optional[torch.Tensor], average: str, mdmc_average: Optional[str], - zero_division: int, + zero_division: int = 0, ) -> torch.Tensor: """ Reduces scores of type ``numerator/denominator`` or diff --git a/pytorch_lightning/metrics/classification/precision_recall.py b/pytorch_lightning/metrics/classification/precision_recall.py index 8ecbefe58d38c..b27f572b80cfc 100644 --- a/pytorch_lightning/metrics/classification/precision_recall.py +++ b/pytorch_lightning/metrics/classification/precision_recall.py @@ -66,10 +66,6 @@ class Precision(StatScores): are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. - zero_division: - Score to use in the case of a 0 in the denominator in the calculation. Should be either - 0 or 1. - ignore_index: Integer specifying a target class to ignore. If given, this class index does not contribute to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` @@ -124,7 +120,6 @@ def __init__( self, average: str = "micro", mdmc_average: Optional[str] = None, - zero_division: int = 0, ignore_index: Optional[int] = None, num_classes: Optional[int] = None, threshold: float = 0.5, @@ -139,9 +134,6 @@ def __init__( if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") - if zero_division not in [0, 1]: - raise ValueError("The `zero_division` has to be either 0 or 1.") - super().__init__( reduce="macro" if average in ["weighted", "none", None] else average, mdmc_reduce=mdmc_average, @@ -156,7 +148,6 @@ def __init__( dist_sync_fn=dist_sync_fn, ) - self.zero_division = zero_division self.average = average def compute(self) -> torch.Tensor: @@ -229,10 +220,6 @@ class Recall(StatScores): are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. - zero_division: - Score to use in the case of a 0 in the denominator in the calculation. Should be either - 0 or 1. - ignore_index: Integer specifying a target class to ignore. If given, this class index does not contribute to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` @@ -288,7 +275,6 @@ def __init__( self, average: str = "micro", mdmc_average: Optional[str] = None, - zero_division: int = 0, ignore_index: Optional[int] = None, num_classes: Optional[int] = None, threshold: float = 0.5, @@ -303,9 +289,6 @@ def __init__( if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") - if zero_division not in [0, 1]: - raise ValueError("The `zero_division` has to be either 0 or 1.") - super().__init__( reduce="macro" if average in ["weighted", "none", None] else average, mdmc_reduce=mdmc_average, @@ -320,7 +303,6 @@ def __init__( dist_sync_fn=dist_sync_fn, ) - self.zero_division = zero_division self.average = average def compute(self) -> torch.Tensor: @@ -343,4 +325,4 @@ def compute(self) -> torch.Tensor: else: tp, fp, tn, fn = self.tp, self.fp, self.tn, self.fn - return _recall_compute(tp, fp, tn, fn, self.average, self.mdmc_reduce, self.zero_division) + return _recall_compute(tp, fp, tn, fn, self.average, self.mdmc_reduce) diff --git a/pytorch_lightning/metrics/functional/precision_recall.py b/pytorch_lightning/metrics/functional/precision_recall.py index daeb2a2ec97d5..fd910f4fa9a69 100644 --- a/pytorch_lightning/metrics/functional/precision_recall.py +++ b/pytorch_lightning/metrics/functional/precision_recall.py @@ -26,7 +26,6 @@ def _precision_compute( fn: torch.Tensor, average: str, mdmc_average: Optional[str], - zero_division: int, ) -> torch.Tensor: return _reduce_stat_scores( numerator=tp, @@ -34,7 +33,6 @@ def _precision_compute( weights=None if average != "weighted" else tp + fn, average=average, mdmc_average=mdmc_average, - zero_division=zero_division, ) @@ -43,7 +41,6 @@ def precision( target: torch.Tensor, average: str = "micro", mdmc_average: Optional[str] = None, - zero_division: int = 0, ignore_index: Optional[int] = None, num_classes: Optional[int] = None, threshold: float = 0.5, @@ -103,10 +100,6 @@ def precision( are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. - zero_division: - Score to use in the case of a 0 in the denominator in the calculation. Should be either - 0 or 1. - ignore_index: Integer specifying a target class to ignore. If given, this class index does not contribute to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` @@ -164,9 +157,6 @@ def precision( if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") - if zero_division not in [0, 1]: - raise ValueError("The `zero_division` has to be either 0 or 1.") - allowed_mdmc_average = [None, "samplewise", "global"] if mdmc_average not in allowed_mdmc_average: raise ValueError(f"The `mdmc_average` has to be one of {allowed_mdmc_average}, got {mdmc_average}.") @@ -190,7 +180,7 @@ def precision( ignore_index=ignore_index, ) - return _precision_compute(tp, fp, tn, fn, average, mdmc_average, zero_division) + return _precision_compute(tp, fp, tn, fn, average, mdmc_average) def _recall_compute( @@ -200,7 +190,6 @@ def _recall_compute( fn: torch.Tensor, average: str, mdmc_average: Optional[str], - zero_division: int, ) -> torch.Tensor: return _reduce_stat_scores( numerator=tp, @@ -208,7 +197,6 @@ def _recall_compute( weights=None if average != "weighted" else tp + fn, average=average, mdmc_average=mdmc_average, - zero_division=zero_division, ) @@ -217,7 +205,6 @@ def recall( target: torch.Tensor, average: str = "micro", mdmc_average: Optional[str] = None, - zero_division: int = 0, ignore_index: Optional[int] = None, num_classes: Optional[int] = None, threshold: float = 0.5, @@ -274,10 +261,6 @@ def recall( are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. - zero_division: - Score to use in the case of a 0 in the denominator in the calculation. Should be either - 0 or 1. - ignore_index: Integer specifying a target class to ignore. If given, this class index does not contribute to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` @@ -335,9 +318,6 @@ def recall( if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") - if zero_division not in [0, 1]: - raise ValueError("The `zero_division` has to be either 0 or 1.") - allowed_mdmc_average = [None, "samplewise", "global"] if mdmc_average not in allowed_mdmc_average: raise ValueError("The `mdmc_average` has to be one of {allowed_mdmc_average}, got {mdmc_average}.") @@ -361,4 +341,4 @@ def recall( ignore_index=ignore_index, ) - return _recall_compute(tp, fp, tn, fn, average, mdmc_average, zero_division) + return _recall_compute(tp, fp, tn, fn, average, mdmc_average) diff --git a/tests/metrics/classification/test_precision_recall.py b/tests/metrics/classification/test_precision_recall.py index 25106d4a7d70d..b4b11473c4d79 100644 --- a/tests/metrics/classification/test_precision_recall.py +++ b/tests/metrics/classification/test_precision_recall.py @@ -73,20 +73,17 @@ def _sk_prec_recall_mdmc(preds, target, sk_fn, num_classes, average, is_multicla @pytest.mark.parametrize("metric, fn_metric", [(Precision, precision), (Recall, recall)]) @pytest.mark.parametrize( - "zero_division, average, mdmc_average, num_classes, ignore_index", + "average, mdmc_average, num_classes, ignore_index", [ - (0.5, "micro", None, None, None), - (None, "micro", None, None, None), - (0, "wrong", None, None, None), - (0, "micro", "wrong", None, None), - (0, "macro", None, None, None), - (0, "macro", None, 1, 0), + ("wrong", None, None, None), + ("micro", "wrong", None, None), + ("macro", None, None, None), + ("macro", None, 1, 0), ], ) -def test_wrong_params(metric, fn_metric, zero_division, average, mdmc_average, num_classes, ignore_index): +def test_wrong_params(metric, fn_metric, average, mdmc_average, num_classes, ignore_index): with pytest.raises(ValueError): metric( - zero_division=zero_division, average=average, mdmc_average=mdmc_average, num_classes=num_classes, @@ -97,7 +94,6 @@ def test_wrong_params(metric, fn_metric, zero_division, average, mdmc_average, n fn_metric( _binary_inputs.preds[0], _binary_inputs.target[0], - zero_division=zero_division, average=average, mdmc_average=mdmc_average, num_classes=num_classes, From 5658ee5aa65edebeeb45f9f1521f64455b54458f Mon Sep 17 00:00:00 2001 From: Tadej Date: Wed, 13 Jan 2021 14:16:52 +0100 Subject: [PATCH 44/61] fix zero_division --- pytorch_lightning/metrics/classification/precision_recall.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/metrics/classification/precision_recall.py b/pytorch_lightning/metrics/classification/precision_recall.py index b27f572b80cfc..1cc399ebd1289 100644 --- a/pytorch_lightning/metrics/classification/precision_recall.py +++ b/pytorch_lightning/metrics/classification/precision_recall.py @@ -169,7 +169,7 @@ def compute(self) -> torch.Tensor: else: tp, fp, tn, fn = self.tp, self.fp, self.tn, self.fn - return _precision_compute(tp, fp, tn, fn, self.average, self.mdmc_reduce, self.zero_division) + return _precision_compute(tp, fp, tn, fn, self.average, self.mdmc_reduce) class Recall(StatScores): From 6ab90022ef63b74192be5c802680ab2f29c01cd8 Mon Sep 17 00:00:00 2001 From: Tadej Date: Wed, 13 Jan 2021 14:37:36 +0100 Subject: [PATCH 45/61] Add zero_div + edge case tests --- .../classification/test_precision_recall.py | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/tests/metrics/classification/test_precision_recall.py b/tests/metrics/classification/test_precision_recall.py index b4b11473c4d79..a8419f61922bf 100644 --- a/tests/metrics/classification/test_precision_recall.py +++ b/tests/metrics/classification/test_precision_recall.py @@ -101,6 +101,46 @@ def test_wrong_params(metric, fn_metric, average, mdmc_average, num_classes, ign ) +@pytest.mark.parametrize("metric_class, metric_fn", [(Recall, recall), (Precision, precision)]) +def test_zero_division(metric_class, metric_fn): + """ Test that zero_division works correctly (currently should just set to 0). """ + + preds = torch.tensor([1, 2, 1, 1]) + target = torch.tensor([2, 1, 2, 1]) + + cl_metric = metric_class(average="none", num_classes=3) + cl_metric(preds, target) + + result_cl = cl_metric.compute() + result_fn = metric_fn(preds, target, average="none", num_classes=3) + + assert result_cl[0] == result_fn[0] == 0 + + +@pytest.mark.parametrize("metric_class, metric_fn", [(Recall, recall), (Precision, precision)]) +def test_no_support(metric_class, metric_fn): + """This tests a rare edge case, where there is only one class present + in target, and ignore_index is set to exactly that class - and the + average method is equal to 'weighted'. + + This would mean that the sum of weights equals zero, and would, without + taking care of this case, return NaN. However, the reduction function + should catch that and set the metric to equal the value of zero_division + in this case (zero_division is for now not configurable and equals 0). + """ + + preds = torch.tensor([1, 1, 0, 0]) + target = torch.tensor([0, 0, 0, 0]) + + cl_metric = metric_class(average="weighted", num_classes=2, ignore_index=0) + cl_metric(preds, target) + + result_cl = cl_metric.compute() + result_fn = metric_fn(preds, target, average="weighted", num_classes=2, ignore_index=0) + + assert result_cl == result_fn == 0 + + @pytest.mark.parametrize( "metric_class, metric_fn, sk_fn", [(Recall, recall, recall_score), (Precision, precision, precision_score)] ) From 571cdd8f38f42b595e5df5bade8a937c38ebee3d Mon Sep 17 00:00:00 2001 From: Tadej Date: Wed, 13 Jan 2021 14:54:14 +0100 Subject: [PATCH 46/61] Reorder cls metric args --- .../classification/precision_recall.py | 36 ++++++++++--------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/pytorch_lightning/metrics/classification/precision_recall.py b/pytorch_lightning/metrics/classification/precision_recall.py index 1cc399ebd1289..bf3c5c5339cf1 100644 --- a/pytorch_lightning/metrics/classification/precision_recall.py +++ b/pytorch_lightning/metrics/classification/precision_recall.py @@ -33,6 +33,11 @@ class Precision(StatScores): multi-dimensional multi-class case. Accepts all inputs listed in :ref:`metrics:Input types`. Args: + num_classes: + Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. + threshold: + Threshold probability value for transforming probability predictions to binary + (0,1) predictions, in the case of binary or multi-label inputs. average: Defines the reduction that is applied. Should be one of the following: @@ -48,6 +53,8 @@ class Precision(StatScores): Note that what is considered a sample in the multi-dimensional multi-class case depends on the value of ``mdmc_average``. + multilabel: + .. warning :: This parameter is deprecated and has no effect. Will be removed in v1.4.0. mdmc_average: Defines how averaging is done for multi-dimensional multi-class inputs (on top of the @@ -71,12 +78,6 @@ class Precision(StatScores): to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` or ``'none'``, the score for the ignored class will be returned as ``nan``. - num_classes: - Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. - - threshold: - Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs. top_k: Number of highest probability entries for each sample to convert to 1s - relevant only for inputs with probability predictions. If this parameter is set for multi-label @@ -118,11 +119,12 @@ class Precision(StatScores): def __init__( self, + num_classes: Optional[int] = None, + threshold: float = 0.5, average: str = "micro", + multilabel: bool = False, mdmc_average: Optional[str] = None, ignore_index: Optional[int] = None, - num_classes: Optional[int] = None, - threshold: float = 0.5, top_k: Optional[int] = None, is_multiclass: Optional[bool] = None, compute_on_step: bool = True, @@ -187,6 +189,11 @@ class Recall(StatScores): multi-dimensional multi-class case. Accepts all inputs listed in :ref:`metrics:Input types`. Args: + num_classes: + Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. + threshold: + Threshold probability value for transforming probability predictions to binary + (0,1) predictions, in the case of binary or multi-label inputs. average: Defines the reduction that is applied. Should be one of the following: @@ -202,6 +209,8 @@ class Recall(StatScores): Note that what is considered a sample in the multi-dimensional multi-class case depends on the value of ``mdmc_average``. + multilabel: + .. warning :: This parameter is deprecated and has no effect. Will be removed in v1.4.0. mdmc_average: Defines how averaging is done for multi-dimensional multi-class inputs (on top of the @@ -225,12 +234,6 @@ class Recall(StatScores): to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` or ``'none'``, the score for the ignored class will be returned as ``nan``. - num_classes: - Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. - - threshold: - Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs. top_k: Number of highest probability entries for each sample to convert to 1s - relevant only for inputs with probability predictions. If this parameter is set for multi-label @@ -273,11 +276,12 @@ class Recall(StatScores): def __init__( self, + num_classes: Optional[int] = None, + threshold: float = 0.5, average: str = "micro", + multilabel: bool = False, mdmc_average: Optional[str] = None, ignore_index: Optional[int] = None, - num_classes: Optional[int] = None, - threshold: float = 0.5, top_k: Optional[int] = None, is_multiclass: Optional[bool] = None, compute_on_step: bool = True, From fff6b8bb4d38e77f95ad75bca20586b80dbf911f Mon Sep 17 00:00:00 2001 From: Tadej Date: Wed, 13 Jan 2021 15:03:17 +0100 Subject: [PATCH 47/61] Add back quotes for is_multiclass --- docs/source/metrics.rst | 4 ++-- pytorch_lightning/metrics/classification/helpers.py | 4 ++-- pytorch_lightning/metrics/classification/precision_recall.py | 4 ++-- pytorch_lightning/metrics/classification/stat_scores.py | 2 +- pytorch_lightning/metrics/functional/precision_recall.py | 4 ++-- pytorch_lightning/metrics/functional/stat_scores.py | 2 +- 6 files changed, 10 insertions(+), 10 deletions(-) diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index bb607cce52ee1..b29dee849e927 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -382,8 +382,8 @@ the possible class labels are 0, 1, 2, 3, etc. Below are some examples of differ ml_target = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 0]]) -Using the is_multiclass parameter -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Using the ``is_multiclass`` parameter +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ In some cases, you might have inputs which appear to be (multi-dimensional) multi-class but are actually binary/multi-label - for example, if both predictions and targets are diff --git a/pytorch_lightning/metrics/classification/helpers.py b/pytorch_lightning/metrics/classification/helpers.py index 3c9a83523e7e7..e414d82a5685d 100644 --- a/pytorch_lightning/metrics/classification/helpers.py +++ b/pytorch_lightning/metrics/classification/helpers.py @@ -250,7 +250,7 @@ def _check_classification_inputs( is_multiclass: Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter's - :ref:`documentation section ` + :ref:`documentation section ` for a more detailed explanation and examples. @@ -376,7 +376,7 @@ def _input_format_classification( is_multiclass: Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter's - :ref:`documentation section ` + :ref:`documentation section ` for a more detailed explanation and examples. diff --git a/pytorch_lightning/metrics/classification/precision_recall.py b/pytorch_lightning/metrics/classification/precision_recall.py index bf3c5c5339cf1..289ab8e9b8d63 100644 --- a/pytorch_lightning/metrics/classification/precision_recall.py +++ b/pytorch_lightning/metrics/classification/precision_recall.py @@ -88,7 +88,7 @@ class Precision(StatScores): is_multiclass: Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter's - :ref:`documentation section ` + :ref:`documentation section ` for a more detailed explanation and examples. compute_on_step: @@ -245,7 +245,7 @@ class Recall(StatScores): is_multiclass: Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter's - :ref:`documentation section ` + :ref:`documentation section ` for a more detailed explanation and examples. compute_on_step: diff --git a/pytorch_lightning/metrics/classification/stat_scores.py b/pytorch_lightning/metrics/classification/stat_scores.py index dbc9ab2bd714b..a7ac0af903371 100644 --- a/pytorch_lightning/metrics/classification/stat_scores.py +++ b/pytorch_lightning/metrics/classification/stat_scores.py @@ -85,7 +85,7 @@ class StatScores(Metric): is_multiclass: Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter's - :ref:`documentation section ` + :ref:`documentation section ` for a more detailed explanation and examples. compute_on_step: diff --git a/pytorch_lightning/metrics/functional/precision_recall.py b/pytorch_lightning/metrics/functional/precision_recall.py index fd910f4fa9a69..0f50ee0676876 100644 --- a/pytorch_lightning/metrics/functional/precision_recall.py +++ b/pytorch_lightning/metrics/functional/precision_recall.py @@ -121,7 +121,7 @@ def precision( is_multiclass: Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter's - :ref:`documentation section ` + :ref:`documentation section ` for a more detailed explanation and examples. class_reduction: @@ -282,7 +282,7 @@ def recall( is_multiclass: Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter's - :ref:`documentation section ` + :ref:`documentation section ` for a more detailed explanation and examples. class_reduction: diff --git a/pytorch_lightning/metrics/functional/stat_scores.py b/pytorch_lightning/metrics/functional/stat_scores.py index a9570cabb3966..c756707037ffa 100644 --- a/pytorch_lightning/metrics/functional/stat_scores.py +++ b/pytorch_lightning/metrics/functional/stat_scores.py @@ -210,7 +210,7 @@ def stat_scores( is_multiclass: Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter's - :ref:`documentation section ` + :ref:`documentation section ` for a more detailed explanation and examples. Return: From 46d73634944e6cb682ef53de0736fd7bec6e1091 Mon Sep 17 00:00:00 2001 From: Tadej Date: Wed, 13 Jan 2021 18:03:21 +0100 Subject: [PATCH 48/61] Add precision_recall and tests --- CHANGELOG.md | 3 - docs/source/metrics.rst | 6 +- .../metrics/classification/helpers.py | 4 +- .../classification/precision_recall.py | 4 +- .../metrics/classification/stat_scores.py | 2 +- .../metrics/functional/__init__.py | 2 +- .../metrics/functional/precision_recall.py | 155 +++++++++++++++++- .../metrics/functional/stat_scores.py | 2 +- .../classification/test_precision_recall.py | 79 ++++++++- 9 files changed, 241 insertions(+), 16 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 41eedfc083d96..151ae30e42698 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -86,9 +86,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - `stat_scores_multiple_classes` is deprecated in favor of `stat_scores` ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839)) -- `precision_recall` metris is deprecated in favor of using `precision` and `recall` separately ([#4842](https://github.com/PyTorchLightning/pytorch-lightning/pull/4842)) - - ### Removed - Removed deprecated checkpoint argument `filepath` ([#5321](https://github.com/PyTorchLightning/pytorch-lightning/pull/5321)) diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index b29dee849e927..ad9604e1ce933 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -382,8 +382,8 @@ the possible class labels are 0, 1, 2, 3, etc. Below are some examples of differ ml_target = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 0]]) -Using the ``is_multiclass`` parameter -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Using the is_multiclass parameter +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ In some cases, you might have inputs which appear to be (multi-dimensional) multi-class but are actually binary/multi-label - for example, if both predictions and targets are @@ -609,7 +609,7 @@ precision [func] precision_recall [func] ~~~~~~~~~~~~~~~~~~~~~~~ -.. autofunction:: pytorch_lightning.metrics.functional.classification.precision_recall +.. autofunction:: pytorch_lightning.metrics.functional.precision_recall :noindex: diff --git a/pytorch_lightning/metrics/classification/helpers.py b/pytorch_lightning/metrics/classification/helpers.py index e414d82a5685d..3c9a83523e7e7 100644 --- a/pytorch_lightning/metrics/classification/helpers.py +++ b/pytorch_lightning/metrics/classification/helpers.py @@ -250,7 +250,7 @@ def _check_classification_inputs( is_multiclass: Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter's - :ref:`documentation section ` + :ref:`documentation section ` for a more detailed explanation and examples. @@ -376,7 +376,7 @@ def _input_format_classification( is_multiclass: Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter's - :ref:`documentation section ` + :ref:`documentation section ` for a more detailed explanation and examples. diff --git a/pytorch_lightning/metrics/classification/precision_recall.py b/pytorch_lightning/metrics/classification/precision_recall.py index 289ab8e9b8d63..bf3c5c5339cf1 100644 --- a/pytorch_lightning/metrics/classification/precision_recall.py +++ b/pytorch_lightning/metrics/classification/precision_recall.py @@ -88,7 +88,7 @@ class Precision(StatScores): is_multiclass: Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter's - :ref:`documentation section ` + :ref:`documentation section ` for a more detailed explanation and examples. compute_on_step: @@ -245,7 +245,7 @@ class Recall(StatScores): is_multiclass: Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter's - :ref:`documentation section ` + :ref:`documentation section ` for a more detailed explanation and examples. compute_on_step: diff --git a/pytorch_lightning/metrics/classification/stat_scores.py b/pytorch_lightning/metrics/classification/stat_scores.py index a7ac0af903371..dbc9ab2bd714b 100644 --- a/pytorch_lightning/metrics/classification/stat_scores.py +++ b/pytorch_lightning/metrics/classification/stat_scores.py @@ -85,7 +85,7 @@ class StatScores(Metric): is_multiclass: Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter's - :ref:`documentation section ` + :ref:`documentation section ` for a more detailed explanation and examples. compute_on_step: diff --git a/pytorch_lightning/metrics/functional/__init__.py b/pytorch_lightning/metrics/functional/__init__.py index 909d01b1ca201..47a82b20334e4 100644 --- a/pytorch_lightning/metrics/functional/__init__.py +++ b/pytorch_lightning/metrics/functional/__init__.py @@ -35,7 +35,7 @@ from pytorch_lightning.metrics.functional.mean_squared_error import mean_squared_error # noqa: F401 from pytorch_lightning.metrics.functional.mean_squared_log_error import mean_squared_log_error # noqa: F401 from pytorch_lightning.metrics.functional.nlp import bleu_score # noqa: F401 -from pytorch_lightning.metrics.functional.precision_recall import precision, recall # noqa: F401 +from pytorch_lightning.metrics.functional.precision_recall import precision, precision_recall, recall # noqa: F401 from pytorch_lightning.metrics.functional.precision_recall_curve import precision_recall_curve # noqa: F401 from pytorch_lightning.metrics.functional.psnr import psnr # noqa: F401 from pytorch_lightning.metrics.functional.r2score import r2score # noqa: F401 diff --git a/pytorch_lightning/metrics/functional/precision_recall.py b/pytorch_lightning/metrics/functional/precision_recall.py index 0f50ee0676876..0f1e41a0258a6 100644 --- a/pytorch_lightning/metrics/functional/precision_recall.py +++ b/pytorch_lightning/metrics/functional/precision_recall.py @@ -121,7 +121,7 @@ def precision( is_multiclass: Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter's - :ref:`documentation section ` + :ref:`documentation section ` for a more detailed explanation and examples. class_reduction: @@ -282,7 +282,7 @@ def recall( is_multiclass: Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter's - :ref:`documentation section ` + :ref:`documentation section ` for a more detailed explanation and examples. class_reduction: @@ -342,3 +342,154 @@ def recall( ) return _recall_compute(tp, fp, tn, fn, average, mdmc_average) + + +def precision_recall( + preds: torch.Tensor, + target: torch.Tensor, + average: str = "micro", + mdmc_average: Optional[str] = None, + ignore_index: Optional[int] = None, + num_classes: Optional[int] = None, + threshold: float = 0.5, + top_k: Optional[int] = None, + is_multiclass: Optional[bool] = None, + class_reduction: Optional[str] = None, +) -> torch.Tensor: + r""" + Computes `Precision and Recall `_: + + .. math:: + \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}} \qquad + \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN} + + Where :math:`\text{TP}`m :math:`\text{FN}` and :math:`\text{FP}` represent the number + of true positives, false negatives and false positives respecitively. With the use of + ``top_k`` parameter, this metric can generalize to Recall@K and Precision@K. + + The reduction method (how the recall scores are aggregated) is controlled by the + ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the + multi-dimensional multi-class case. Accepts all inputs listed in :ref:`metrics:Input types`. + + Args: + preds: Predictions from model (probabilities, or labels) + target: Ground truth values + average: + Defines the reduction that is applied. Should be one of the following: + + - ``'micro'`` [default]: Calculate the metric globally, accross all samples and classes. + - ``'macro'``: Calculate the metric for each class separately, and average the + metrics accross classes (with equal weights for each class). + - ``'weighted'``: Calculate the metric for each class separately, and average the + metrics accross classes, weighting each class by its support (``tp + fn``). + - ``'none'`` or ``None``: Calculate the metric for each class separately, and return + the metric for every class. + - ``'samples'``: Calculate the metric for each sample, and average the metrics + across samples (with equal weights for each sample). + + Note that what is considered a sample in the multi-dimensional multi-class case + depends on the value of ``mdmc_average``. + + mdmc_average: + Defines how averaging is done for multi-dimensional multi-class inputs (on top of the + ``average`` parameter). Should be one of the following: + + - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional + multi-class. + + - ``'samplewise'``: In this case, the statistics are computed separately for each + sample on the ``N`` axis, and then averaged over samples. + The computation for each sample is done by treating the flattened extra axes ``...`` + (see :ref:`metrics:Input types`) as the ``N`` dimension within the sample, + and computing the metric for the sample based on that. + + - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs (see :ref:`metrics:Input types`) + are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they + were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. + + ignore_index: + Integer specifying a target class to ignore. If given, this class index does not contribute + to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` + or ``'none'``, the score for the ignored class will be returned as ``nan``. + + num_classes: + Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. + + threshold: + Threshold probability value for transforming probability predictions to binary + (0,1) predictions, in the case of binary or multi-label inputs + top_k: + Number of highest probability entries for each sample to convert to 1s - relevant + only for inputs with probability predictions. If this parameter is set for multi-label + inputs, it will take precedence over ``threshold``. For (multi-dim) multi-class inputs, + this parameter defaults to 1. + + Should be left unset (``None``) for inputs with label predictions. + is_multiclass: + Used only in certain special cases, where you want to treat inputs as a different type + than what they appear to be. See the parameter's + :ref:`documentation section ` + for a more detailed explanation and examples. + + class_reduction: + .. warning :: This parameter is deprecated, use ``average``. Will be removed in v1.4.0. + + Return: + The shape of the returned tensor depends on the ``average`` parameter + + - If ``average in ['micro', 'macro', 'weighted', 'samples']``, a ``(2, )`` tensor will be returned + - If ``average in ['none', None]``, the shape will be ``(2, C)``, where ``C`` stands for the number + of classes + + The first element (in the first dimension) corresponds to precision, the second one to recall. + + Example: + + >>> from pytorch_lightning.metrics.functional import precision_recall + >>> preds = torch.tensor([2, 0, 2, 1]) + >>> target = torch.tensor([1, 1, 2, 0]) + >>> precision_recall(preds, target, average='macro', num_classes=3) + tensor([0.1667, 0.3333]) + >>> precision_recall(preds, target, average='micro') + tensor([0.2500, 0.2500]) + + """ + if class_reduction: + rank_zero_warn( + "This `class_reduction` parameter was deprecated in v1.2.0 in favor of" + " `reduce`. It will be removed in v1.4.0", + DeprecationWarning, + ) + average = class_reduction + + allowed_average = ["micro", "macro", "weighted", "samples", "none", None] + if average not in allowed_average: + raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") + + allowed_mdmc_average = [None, "samplewise", "global"] + if mdmc_average not in allowed_mdmc_average: + raise ValueError("The `mdmc_average` has to be one of {allowed_mdmc_average}, got {mdmc_average}.") + + if average in ["macro", "weighted", "none", None] and (not num_classes or num_classes < 1): + raise ValueError(f"When you set `average` as {average}, you have to provide the number of classes.") + + if num_classes and ignore_index is not None and (not 0 <= ignore_index < num_classes or num_classes == 1): + raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes") + + reduce = "macro" if average in ["weighted", "none", None] else average + tp, fp, tn, fn = _stat_scores_update( + preds, + target, + reduce=reduce, + mdmc_reduce=mdmc_average, + threshold=threshold, + num_classes=num_classes, + top_k=top_k, + is_multiclass=is_multiclass, + ignore_index=ignore_index, + ) + + precision = _precision_compute(tp, fp, tn, fn, average, mdmc_average) + recall = _recall_compute(tp, fp, tn, fn, average, mdmc_average) + + return torch.stack([precision, recall]) diff --git a/pytorch_lightning/metrics/functional/stat_scores.py b/pytorch_lightning/metrics/functional/stat_scores.py index c756707037ffa..a9570cabb3966 100644 --- a/pytorch_lightning/metrics/functional/stat_scores.py +++ b/pytorch_lightning/metrics/functional/stat_scores.py @@ -210,7 +210,7 @@ def stat_scores( is_multiclass: Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter's - :ref:`documentation section ` + :ref:`documentation section ` for a more detailed explanation and examples. Return: diff --git a/tests/metrics/classification/test_precision_recall.py b/tests/metrics/classification/test_precision_recall.py index a8419f61922bf..b36e6b637163a 100644 --- a/tests/metrics/classification/test_precision_recall.py +++ b/tests/metrics/classification/test_precision_recall.py @@ -7,7 +7,7 @@ from pytorch_lightning.metrics.classification.helpers import _input_format_classification from pytorch_lightning.metrics import Precision, Recall -from pytorch_lightning.metrics.functional import precision, recall +from pytorch_lightning.metrics.functional import precision, recall, precision_recall from tests.metrics.classification.inputs import ( _binary_inputs, _binary_prob_inputs, @@ -100,6 +100,16 @@ def test_wrong_params(metric, fn_metric, average, mdmc_average, num_classes, ign ignore_index=ignore_index, ) + with pytest.raises(ValueError): + precision_recall( + _binary_inputs.preds[0], + _binary_inputs.target[0], + average=average, + mdmc_average=mdmc_average, + num_classes=num_classes, + ignore_index=ignore_index, + ) + @pytest.mark.parametrize("metric_class, metric_fn", [(Recall, recall), (Precision, precision)]) def test_zero_division(metric_class, metric_fn): @@ -261,3 +271,70 @@ def test_precision_recall_fn( "mdmc_average": mdmc_average, }, ) + + +@pytest.mark.parametrize("average", ["micro", "macro", None, "weighted", "samples"]) +def test_precision_recall_joint(average): + """A simple test of the joint precision_recall metric. + + No need to test this thorougly, as it is just a combination of precision and recall, + which are already tested thoroughly. + """ + + precision_result = precision(_mc_prob.preds[0], _mc_prob.target[0], average=average, num_classes=NUM_CLASSES) + recall_result = recall(_mc_prob.preds[0], _mc_prob.target[0], average=average, num_classes=NUM_CLASSES) + + prec_recall_result = precision_recall( + _mc_prob.preds[0], _mc_prob.target[0], average=average, num_classes=NUM_CLASSES + ) + + if average is None: + assert prec_recall_result.size() == torch.Size([2, NUM_CLASSES]) + else: + assert prec_recall_result.size() == torch.Size([2]) + + assert torch.equal(torch.stack([precision_result, recall_result]), prec_recall_result) + + +_mc_k_target = torch.tensor([0, 1, 2]) +_mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]]) +_ml_k_target = torch.tensor([[0, 1, 0], [1, 1, 0], [0, 0, 0]]) +_ml_k_preds = torch.tensor([[0.9, 0.2, 0.75], [0.1, 0.7, 0.8], [0.6, 0.1, 0.7]]) + + +@pytest.mark.parametrize("metric_class, metric_fn", [(Recall, recall), (Precision, precision)]) +@pytest.mark.parametrize( + "k, preds, target, average, expected_prec, expected_recall", + [ + (1, _mc_k_preds, _mc_k_target, "micro", torch.tensor(2 / 3), torch.tensor(2 / 3)), + (2, _mc_k_preds, _mc_k_target, "micro", torch.tensor(1 / 2), torch.tensor(1.0)), + (1, _ml_k_preds, _ml_k_target, "micro", torch.tensor(0.0), torch.tensor(0.0)), + (2, _ml_k_preds, _ml_k_target, "micro", torch.tensor(1 / 6), torch.tensor(1 / 3)), + ], +) +def test_top_k( + metric_class, + metric_fn, + k: int, + preds: torch.Tensor, + target: torch.Tensor, + average: str, + expected_prec: torch.Tensor, + expected_recall: torch.Tensor, +): + """A simple test to check that top_k works as expected. + + Just a sanity check, the tests in StatScores should already guarantee + the corectness of results. + """ + + class_metric = metric_class(top_k=k, average=average, num_classes=3) + class_metric.update(preds, target) + + if metric_class.__name__ == "Precision": + result = expected_prec + else: + result = expected_recall + + assert torch.equal(class_metric.compute(), result) + assert torch.equal(metric_fn(preds, target, top_k=k, average=average, num_classes=3), result) From b6e375d19cbe24e57306485bce127d1a2eaa083c Mon Sep 17 00:00:00 2001 From: Tadej Date: Wed, 13 Jan 2021 18:07:02 +0100 Subject: [PATCH 49/61] PEP8 --- pytorch_lightning/metrics/functional/__init__.py | 1 - pytorch_lightning/metrics/functional/precision_recall.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/metrics/functional/__init__.py b/pytorch_lightning/metrics/functional/__init__.py index 47a82b20334e4..e5bdb2b643cc7 100644 --- a/pytorch_lightning/metrics/functional/__init__.py +++ b/pytorch_lightning/metrics/functional/__init__.py @@ -18,7 +18,6 @@ dice_score, get_num_classes, multiclass_auroc, - precision_recall, stat_scores_multiple_classes, to_categorical, to_onehot, diff --git a/pytorch_lightning/metrics/functional/precision_recall.py b/pytorch_lightning/metrics/functional/precision_recall.py index 0f1e41a0258a6..b362f18e76614 100644 --- a/pytorch_lightning/metrics/functional/precision_recall.py +++ b/pytorch_lightning/metrics/functional/precision_recall.py @@ -359,8 +359,8 @@ def precision_recall( r""" Computes `Precision and Recall `_: - .. math:: - \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}} \qquad + .. math:: + \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}} \qquad \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN} Where :math:`\text{TP}`m :math:`\text{FN}` and :math:`\text{FP}` represent the number From 0ef081bad2c6e4fd9990c9f1ab775fd0defe8a71 Mon Sep 17 00:00:00 2001 From: Tadej Date: Wed, 13 Jan 2021 18:13:28 +0100 Subject: [PATCH 50/61] Fix docs --- pytorch_lightning/metrics/functional/precision_recall.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/metrics/functional/precision_recall.py b/pytorch_lightning/metrics/functional/precision_recall.py index b362f18e76614..0f2ebbd178e05 100644 --- a/pytorch_lightning/metrics/functional/precision_recall.py +++ b/pytorch_lightning/metrics/functional/precision_recall.py @@ -359,9 +359,10 @@ def precision_recall( r""" Computes `Precision and Recall `_: - .. math:: - \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}} \qquad - \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN} + .. math:: \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}} + + + .. math:: \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN} Where :math:`\text{TP}`m :math:`\text{FN}` and :math:`\text{FP}` represent the number of true positives, false negatives and false positives respecitively. With the use of From 3d0c985bcd05b167e61d5f7b26c2835910193f50 Mon Sep 17 00:00:00 2001 From: Tadej Date: Wed, 13 Jan 2021 18:18:05 +0100 Subject: [PATCH 51/61] Fix docs --- pytorch_lightning/metrics/functional/precision_recall.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/metrics/functional/precision_recall.py b/pytorch_lightning/metrics/functional/precision_recall.py index 0f2ebbd178e05..cc4c1abbe99c6 100644 --- a/pytorch_lightning/metrics/functional/precision_recall.py +++ b/pytorch_lightning/metrics/functional/precision_recall.py @@ -362,7 +362,7 @@ def precision_recall( .. math:: \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}} - .. math:: \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN} + .. math:: \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}} Where :math:`\text{TP}`m :math:`\text{FN}` and :math:`\text{FP}` represent the number of true positives, false negatives and false positives respecitively. With the use of From e69a71afdb7e003e8779eb0afab56ab67ba701cb Mon Sep 17 00:00:00 2001 From: Tadej Date: Thu, 14 Jan 2021 15:49:36 +0100 Subject: [PATCH 52/61] Update --- pytorch_lightning/metrics/functional/classification.py | 7 +++---- tests/deprecated_api/test_remove_1-4.py | 4 ++-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py index f42af42818e69..2b35bc4b96dab 100644 --- a/pytorch_lightning/metrics/functional/classification.py +++ b/pytorch_lightning/metrics/functional/classification.py @@ -226,9 +226,8 @@ def precision_recall( """ Computes precision and recall for different thresholds - .. warning :: Deprecated in favor of using - :func:`~pytorch_lightning.metrics.functional.recall` and - :func:`~pytorch_lightning.metrics.functional.precision` separately. + .. warning :: Deprecated in favor of + :func:`~pytorch_lightning.metrics.functional.precision_recall`. Will be removed in v1.4.0. Args: @@ -259,7 +258,7 @@ def precision_recall( """ rank_zero_warn( "This `precision_recall` was deprecated in v1.2.0 in favor of" - " `using `precision` and `recall` separately." + " `from pytorch_lightning.metrcs.functional import precision_recall`." " It will be removed in v1.4.0", DeprecationWarning ) diff --git a/tests/deprecated_api/test_remove_1-4.py b/tests/deprecated_api/test_remove_1-4.py index b943a6902c217..198b6fe5bfd88 100644 --- a/tests/deprecated_api/test_remove_1-4.py +++ b/tests/deprecated_api/test_remove_1-4.py @@ -127,12 +127,12 @@ def test_v1_4_0_deprecated_metrics(): from pytorch_lightning.metrics.functional import precision with pytest.deprecated_call(match='will be removed in v1.4'): - precision_recall(torch.randint(0, 2, (10, 3, 3)), + precision(torch.randint(0, 2, (10, 3, 3)), torch.randint(0, 2, (10, 3, 3)), class_reduction='micro') from pytorch_lightning.metrics.functional import recall with pytest.deprecated_call(match='will be removed in v1.4'): - precision_recall(torch.randint(0, 2, (10, 3)), + recall(torch.randint(0, 2, (10, 3)), torch.randint(0, 2, (10, 3)), class_reduction='micro') From b2bd166eb9261cec17ad58de5c6f3bc7e5a0dc92 Mon Sep 17 00:00:00 2001 From: Tadej Date: Thu, 14 Jan 2021 16:00:21 +0100 Subject: [PATCH 53/61] Change precision_recall output --- .../metrics/classification/precision_recall.py | 4 ++-- .../metrics/functional/precision_recall.py | 17 ++++++++--------- .../classification/test_precision_recall.py | 8 ++------ 3 files changed, 12 insertions(+), 17 deletions(-) diff --git a/pytorch_lightning/metrics/classification/precision_recall.py b/pytorch_lightning/metrics/classification/precision_recall.py index bf3c5c5339cf1..fcf34253eefdd 100644 --- a/pytorch_lightning/metrics/classification/precision_recall.py +++ b/pytorch_lightning/metrics/classification/precision_recall.py @@ -105,7 +105,7 @@ class Precision(StatScores): Example: - >>> from pytorch_lightning.metrics.classification import Precision + >>> from pytorch_lightning.metrics import Precision >>> preds = torch.tensor([2, 0, 2, 1]) >>> target = torch.tensor([1, 1, 2, 0]) >>> precision = Precision(average='macro', num_classes=3) @@ -262,7 +262,7 @@ class Recall(StatScores): Example: - >>> from pytorch_lightning.metrics.classification import Recall + >>> from pytorch_lightning.metrics import Recall >>> preds = torch.tensor([2, 0, 2, 1]) >>> target = torch.tensor([1, 1, 2, 0]) >>> recall = Recall(average='macro', num_classes=3) diff --git a/pytorch_lightning/metrics/functional/precision_recall.py b/pytorch_lightning/metrics/functional/precision_recall.py index cc4c1abbe99c6..b49e5d1a62d75 100644 --- a/pytorch_lightning/metrics/functional/precision_recall.py +++ b/pytorch_lightning/metrics/functional/precision_recall.py @@ -436,13 +436,12 @@ def precision_recall( .. warning :: This parameter is deprecated, use ``average``. Will be removed in v1.4.0. Return: - The shape of the returned tensor depends on the ``average`` parameter - - - If ``average in ['micro', 'macro', 'weighted', 'samples']``, a ``(2, )`` tensor will be returned - - If ``average in ['none', None]``, the shape will be ``(2, C)``, where ``C`` stands for the number - of classes + The function returns a tuple with two elements: precision and recall. Their shape + depends on the ``average`` parameter - The first element (in the first dimension) corresponds to precision, the second one to recall. + - If ``average in ['micro', 'macro', 'weighted', 'samples']``, they are a single element tensor + - If ``average in ['none', None]``, they are a tensor of shape ``(C, )``, where ``C`` stands for + the number of classes Example: @@ -450,9 +449,9 @@ def precision_recall( >>> preds = torch.tensor([2, 0, 2, 1]) >>> target = torch.tensor([1, 1, 2, 0]) >>> precision_recall(preds, target, average='macro', num_classes=3) - tensor([0.1667, 0.3333]) + (tensor(0.1667), tensor(0.3333)) >>> precision_recall(preds, target, average='micro') - tensor([0.2500, 0.2500]) + (tensor(0.2500), tensor(0.2500)) """ if class_reduction: @@ -493,4 +492,4 @@ def precision_recall( precision = _precision_compute(tp, fp, tn, fn, average, mdmc_average) recall = _recall_compute(tp, fp, tn, fn, average, mdmc_average) - return torch.stack([precision, recall]) + return precision, recall diff --git a/tests/metrics/classification/test_precision_recall.py b/tests/metrics/classification/test_precision_recall.py index 30859078cd1fa..2e34817806679 100644 --- a/tests/metrics/classification/test_precision_recall.py +++ b/tests/metrics/classification/test_precision_recall.py @@ -288,12 +288,8 @@ def test_precision_recall_joint(average): _mc_prob.preds[0], _mc_prob.target[0], average=average, num_classes=NUM_CLASSES ) - if average is None: - assert prec_recall_result.size() == torch.Size([2, NUM_CLASSES]) - else: - assert prec_recall_result.size() == torch.Size([2]) - - assert torch.equal(torch.stack([precision_result, recall_result]), prec_recall_result) + assert torch.equal(precision_result, prec_recall_result[0]) + assert torch.equal(recall_result, prec_recall_result[1]) _mc_k_target = torch.tensor([0, 1, 2]) From 5eac1f400af3f32fb48011edeadb792b1c2e1b32 Mon Sep 17 00:00:00 2001 From: Tadej Date: Thu, 14 Jan 2021 16:11:04 +0100 Subject: [PATCH 54/61] PEP8/isort --- tests/deprecated_api/test_remove_1-4.py | 1 + .../classification/test_precision_recall.py | 20 ++++++++----------- .../metrics/functional/test_classification.py | 7 ------- 3 files changed, 9 insertions(+), 19 deletions(-) diff --git a/tests/deprecated_api/test_remove_1-4.py b/tests/deprecated_api/test_remove_1-4.py index d9868a3b99c10..249b2cefda3d2 100644 --- a/tests/deprecated_api/test_remove_1-4.py +++ b/tests/deprecated_api/test_remove_1-4.py @@ -142,6 +142,7 @@ def test_v1_4_0_deprecated_metrics(): torch.randint(0, 2, (10, 3)), class_reduction='micro') + class CustomDDPPlugin(DDPPlugin): def configure_ddp(self, model, device_ids): diff --git a/tests/metrics/classification/test_precision_recall.py b/tests/metrics/classification/test_precision_recall.py index 2e34817806679..d0d93b1a8a5b6 100644 --- a/tests/metrics/classification/test_precision_recall.py +++ b/tests/metrics/classification/test_precision_recall.py @@ -5,19 +5,15 @@ import torch from sklearn.metrics import precision_score, recall_score -from pytorch_lightning.metrics.classification.helpers import _input_format_classification from pytorch_lightning.metrics import Precision, Recall -from pytorch_lightning.metrics.functional import precision, recall, precision_recall -from tests.metrics.classification.inputs import ( - _binary_inputs, - _binary_prob_inputs, - _multiclass_inputs, - _multiclass_prob_inputs as _mc_prob, - _multidim_multiclass_inputs as _mdmc, - _multidim_multiclass_prob_inputs as _mdmc_prob, - _multilabel_inputs as _ml, - _multilabel_prob_inputs as _ml_prob, -) +from pytorch_lightning.metrics.classification.helpers import _input_format_classification +from pytorch_lightning.metrics.functional import precision, precision_recall, recall +from tests.metrics.classification.inputs import _binary_inputs, _binary_prob_inputs, _multiclass_inputs +from tests.metrics.classification.inputs import _multiclass_prob_inputs as _mc_prob +from tests.metrics.classification.inputs import _multidim_multiclass_inputs as _mdmc +from tests.metrics.classification.inputs import _multidim_multiclass_prob_inputs as _mdmc_prob +from tests.metrics.classification.inputs import _multilabel_inputs as _ml +from tests.metrics.classification.inputs import _multilabel_prob_inputs as _ml_prob from tests.metrics.utils import MetricTester, NUM_CLASSES, THRESHOLD torch.manual_seed(42) diff --git a/tests/metrics/functional/test_classification.py b/tests/metrics/functional/test_classification.py index 1f8259569df17..fc0f562f2b469 100644 --- a/tests/metrics/functional/test_classification.py +++ b/tests/metrics/functional/test_classification.py @@ -1,10 +1,7 @@ from distutils.version import LooseVersion -from functools import partial import pytest import torch -from sklearn.metrics import precision_score as sk_precision -from sklearn.metrics import recall_score as sk_recall from sklearn.metrics import roc_auc_score as sk_roc_auc_score from pytorch_lightning import seed_everything @@ -13,10 +10,6 @@ auroc, dice_score, multiclass_auroc, - precision, - recall, - stat_scores, - stat_scores_multiple_classes, ) from pytorch_lightning.metrics.functional.precision_recall_curve import _binary_clf_curve from pytorch_lightning.metrics.utils import get_num_classes, to_categorical, to_onehot From de4fb12ae2ea9ee7677ca3ba08f333aa84140fa9 Mon Sep 17 00:00:00 2001 From: Tadej Date: Thu, 14 Jan 2021 18:18:18 +0100 Subject: [PATCH 55/61] Add method _get_final_stats --- .../classification/precision_recall.py | 22 +++----------- .../metrics/classification/stat_scores.py | 29 ++++++++++++------- tests/deprecated_api/test_remove_1-4.py | 4 +-- .../metrics/functional/test_classification.py | 7 +---- 4 files changed, 26 insertions(+), 36 deletions(-) diff --git a/pytorch_lightning/metrics/classification/precision_recall.py b/pytorch_lightning/metrics/classification/precision_recall.py index fcf34253eefdd..dedc434a73d02 100644 --- a/pytorch_lightning/metrics/classification/precision_recall.py +++ b/pytorch_lightning/metrics/classification/precision_recall.py @@ -11,9 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Any, Callable +from typing import Any, Callable, Optional import torch + from pytorch_lightning.metrics.classification.stat_scores import StatScores from pytorch_lightning.metrics.functional.precision_recall import _precision_compute, _recall_compute @@ -163,14 +164,7 @@ def compute(self) -> torch.Tensor: - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number of classes """ - if isinstance(self.tp, list): - tp = torch.cat(self.tp) - fp = torch.cat(self.fp) - tn = torch.cat(self.tn) - fn = torch.cat(self.fn) - else: - tp, fp, tn, fn = self.tp, self.fp, self.tn, self.fn - + tp, fp, tn, fn = self._get_final_stats() return _precision_compute(tp, fp, tn, fn, self.average, self.mdmc_reduce) @@ -320,13 +314,5 @@ def compute(self) -> torch.Tensor: - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number of classes """ - - if isinstance(self.tp, list): - tp = torch.cat(self.tp) - fp = torch.cat(self.fp) - tn = torch.cat(self.tn) - fn = torch.cat(self.fn) - else: - tp, fp, tn, fn = self.tp, self.fp, self.tn, self.fn - + tp, fp, tn, fn = self._get_final_stats() return _recall_compute(tp, fp, tn, fn, self.average, self.mdmc_reduce) diff --git a/pytorch_lightning/metrics/classification/stat_scores.py b/pytorch_lightning/metrics/classification/stat_scores.py index dbc9ab2bd714b..8eaa388dc3547 100644 --- a/pytorch_lightning/metrics/classification/stat_scores.py +++ b/pytorch_lightning/metrics/classification/stat_scores.py @@ -11,11 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Any, Callable +from typing import Any, Callable, Optional, Tuple import torch + from pytorch_lightning.metrics import Metric -from pytorch_lightning.metrics.functional.stat_scores import _stat_scores_update, _stat_scores_compute +from pytorch_lightning.metrics.functional.stat_scores import _stat_scores_compute, _stat_scores_update class StatScores(Metric): @@ -206,6 +207,21 @@ def update(self, preds: torch.Tensor, target: torch.Tensor): self.tn.append(tn) self.fn.append(fn) + def _get_final_stats(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Performs concatenation on the stat scores if neccesary, + before passing them to a compute function. + """ + + if isinstance(self.tp, list): + tp = torch.cat(self.tp) + fp = torch.cat(self.fp) + tn = torch.cat(self.tn) + fn = torch.cat(self.fn) + else: + tp, fp, tn, fn = self.tp, self.fp, self.tn, self.fn + + return tp, fp, tn, fn + def compute(self) -> torch.Tensor: """ Computes the stat scores based on inputs passed in to ``update`` previously. @@ -239,12 +255,5 @@ def compute(self) -> torch.Tensor: - If ``reduce='samples'``, the shape will be ``(N, X, 5)`` """ - if isinstance(self.tp, list): - tp = torch.cat(self.tp) - fp = torch.cat(self.fp) - tn = torch.cat(self.tn) - fn = torch.cat(self.fn) - else: - tp, fp, tn, fn = self.tp, self.fp, self.tn, self.fn - + tp, fp, tn, fn = self._get_final_stats() return _stat_scores_compute(tp, fp, tn, fn) diff --git a/tests/deprecated_api/test_remove_1-4.py b/tests/deprecated_api/test_remove_1-4.py index 249b2cefda3d2..ac7e35899fd16 100644 --- a/tests/deprecated_api/test_remove_1-4.py +++ b/tests/deprecated_api/test_remove_1-4.py @@ -132,8 +132,8 @@ def test_v1_4_0_deprecated_metrics(): from pytorch_lightning.metrics.functional import precision with pytest.deprecated_call(match='will be removed in v1.4'): - precision(torch.randint(0, 2, (10, 3, 3)), - torch.randint(0, 2, (10, 3, 3)), + precision(torch.randint(0, 2, (10, 3)), + torch.randint(0, 2, (10, 3)), class_reduction='micro') from pytorch_lightning.metrics.functional import recall diff --git a/tests/metrics/functional/test_classification.py b/tests/metrics/functional/test_classification.py index fc0f562f2b469..09bb298970297 100644 --- a/tests/metrics/functional/test_classification.py +++ b/tests/metrics/functional/test_classification.py @@ -5,12 +5,7 @@ from sklearn.metrics import roc_auc_score as sk_roc_auc_score from pytorch_lightning import seed_everything -from pytorch_lightning.metrics.functional.classification import ( - auc, - auroc, - dice_score, - multiclass_auroc, -) +from pytorch_lightning.metrics.functional.classification import auc, auroc, dice_score, multiclass_auroc from pytorch_lightning.metrics.functional.precision_recall_curve import _binary_clf_curve from pytorch_lightning.metrics.utils import get_num_classes, to_categorical, to_onehot From 314255702abe5e950b32f91368d854da93b9cc2f Mon Sep 17 00:00:00 2001 From: Tadej Date: Thu, 14 Jan 2021 18:39:19 +0100 Subject: [PATCH 56/61] Fix depr test --- tests/deprecated_api/test_remove_1-4.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/deprecated_api/test_remove_1-4.py b/tests/deprecated_api/test_remove_1-4.py index ac7e35899fd16..b614e8b87be20 100644 --- a/tests/deprecated_api/test_remove_1-4.py +++ b/tests/deprecated_api/test_remove_1-4.py @@ -132,14 +132,14 @@ def test_v1_4_0_deprecated_metrics(): from pytorch_lightning.metrics.functional import precision with pytest.deprecated_call(match='will be removed in v1.4'): - precision(torch.randint(0, 2, (10, 3)), - torch.randint(0, 2, (10, 3)), + precision(torch.randint(0, 2, (10,)), + torch.randint(0, 2, (10,)), class_reduction='micro') from pytorch_lightning.metrics.functional import recall with pytest.deprecated_call(match='will be removed in v1.4'): - recall(torch.randint(0, 2, (10, 3)), - torch.randint(0, 2, (10, 3)), + recall(torch.randint(0, 2, (10,)), + torch.randint(0, 2, (10,)), class_reduction='micro') From 371f5bb690ee8a62ba3f0de4f576dbb4e7de4640 Mon Sep 17 00:00:00 2001 From: Tadej Date: Fri, 15 Jan 2021 12:02:32 +0100 Subject: [PATCH 57/61] Add comment to deprecation tests --- tests/deprecated_api/test_remove_1-4.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/deprecated_api/test_remove_1-4.py b/tests/deprecated_api/test_remove_1-4.py index b614e8b87be20..6300406088916 100644 --- a/tests/deprecated_api/test_remove_1-4.py +++ b/tests/deprecated_api/test_remove_1-4.py @@ -131,12 +131,14 @@ def test_v1_4_0_deprecated_metrics(): torch.randint(0, 2, (10, 3, 3))) from pytorch_lightning.metrics.functional import precision + # Testing deprecation of class_reduction arg in the *new* precision with pytest.deprecated_call(match='will be removed in v1.4'): precision(torch.randint(0, 2, (10,)), torch.randint(0, 2, (10,)), class_reduction='micro') from pytorch_lightning.metrics.functional import recall + # Testing deprecation of class_reduction arg in the *new* recall with pytest.deprecated_call(match='will be removed in v1.4'): recall(torch.randint(0, 2, (10,)), torch.randint(0, 2, (10,)), From 937ae12abc58ea3487edebfc9a9b662f5a1defe7 Mon Sep 17 00:00:00 2001 From: Tadej Date: Fri, 15 Jan 2021 12:07:05 +0100 Subject: [PATCH 58/61] isort --- tests/deprecated_api/test_remove_1-4.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/deprecated_api/test_remove_1-4.py b/tests/deprecated_api/test_remove_1-4.py index 6300406088916..bb3590741761e 100644 --- a/tests/deprecated_api/test_remove_1-4.py +++ b/tests/deprecated_api/test_remove_1-4.py @@ -131,6 +131,7 @@ def test_v1_4_0_deprecated_metrics(): torch.randint(0, 2, (10, 3, 3))) from pytorch_lightning.metrics.functional import precision + # Testing deprecation of class_reduction arg in the *new* precision with pytest.deprecated_call(match='will be removed in v1.4'): precision(torch.randint(0, 2, (10,)), @@ -138,6 +139,7 @@ def test_v1_4_0_deprecated_metrics(): class_reduction='micro') from pytorch_lightning.metrics.functional import recall + # Testing deprecation of class_reduction arg in the *new* recall with pytest.deprecated_call(match='will be removed in v1.4'): recall(torch.randint(0, 2, (10,)), From 728980e01e1e6a5c66a9d56495bd0b522d186d7a Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Fri, 15 Jan 2021 18:53:57 +0100 Subject: [PATCH 59/61] Apply suggestions from code review Co-authored-by: Jirka Borovec --- pytorch_lightning/metrics/classification/helpers.py | 3 +-- pytorch_lightning/metrics/functional/classification.py | 6 ++---- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/metrics/classification/helpers.py b/pytorch_lightning/metrics/classification/helpers.py index 3c9a83523e7e7..f54c487a153ca 100644 --- a/pytorch_lightning/metrics/classification/helpers.py +++ b/pytorch_lightning/metrics/classification/helpers.py @@ -467,7 +467,7 @@ def _reduce_stat_scores( corresponds to `sklearn averaging methods `__. mdmc_average: - The method to average the scores if inputs were multi-dimensional multi-class. + The method to average the scores if inputs were multi-dimensional multi-class (MDMC). Should be either ``'global'`` or ``'samplewise'``. If inputs were not multi-dimensional multi-class, it should be ``None`` (default). zero_division: @@ -503,5 +503,4 @@ def _reduce_stat_scores( else: scores = scores.sum() - # raise ValueError return scores diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py index 2b35bc4b96dab..71ba628f568a5 100644 --- a/pytorch_lightning/metrics/functional/classification.py +++ b/pytorch_lightning/metrics/functional/classification.py @@ -283,8 +283,7 @@ def precision( Computes precision score. .. warning :: Deprecated in favor of - :func:`~pytorch_lightning.metrics.functional.recall`. Will be removed in - v1.4.0. + :func:`~pytorch_lightning.metrics.functional.recall`. Will be removed in v1.4.0. Args: pred: estimated probabilities @@ -328,8 +327,7 @@ def recall( Computes recall score. .. warning :: Deprecated in favor of - :func:`~pytorch_lightning.metrics.functional.recall`. Will be removed in - v1.4.0. + :func:`~pytorch_lightning.metrics.functional.recall`. Will be removed in v1.4.0. Args: pred: estimated probabilities From 751144395c0ead516b98eb890531ba0be6903042 Mon Sep 17 00:00:00 2001 From: Tadej Date: Sun, 17 Jan 2021 18:56:29 +0100 Subject: [PATCH 60/61] Add typing to test --- .../classification/test_precision_recall.py | 53 ++++++++++--------- 1 file changed, 27 insertions(+), 26 deletions(-) diff --git a/tests/metrics/classification/test_precision_recall.py b/tests/metrics/classification/test_precision_recall.py index d0d93b1a8a5b6..a0226382552f7 100644 --- a/tests/metrics/classification/test_precision_recall.py +++ b/tests/metrics/classification/test_precision_recall.py @@ -1,11 +1,12 @@ from functools import partial +from typing import Callable, Optional import numpy as np import pytest import torch from sklearn.metrics import precision_score, recall_score -from pytorch_lightning.metrics import Precision, Recall +from pytorch_lightning.metrics import Metric, Precision, Recall from pytorch_lightning.metrics.classification.helpers import _input_format_classification from pytorch_lightning.metrics.functional import precision, precision_recall, recall from tests.metrics.classification.inputs import _binary_inputs, _binary_prob_inputs, _multiclass_inputs @@ -172,19 +173,19 @@ class TestPrecisionRecall(MetricTester): @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_precision_recall_class( self, - ddp, - dist_sync_on_step, - preds, - target, - sk_wrapper, - metric_class, - metric_fn, - sk_fn, - is_multiclass, - num_classes, - average, - mdmc_average, - ignore_index, + ddp: bool, + dist_sync_on_step: bool, + preds: torch.Tensor, + target: torch.Tensor, + sk_wrapper: Callable, + metric_class: Metric, + metric_fn: Callable, + sk_fn: Callable, + is_multiclass: Optional[bool], + num_classes: Optional[int], + average: str, + mdmc_average: Optional[str], + ignore_index: Optional[int], ): if num_classes == 1 and average != "micro": pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)") @@ -224,18 +225,18 @@ def test_precision_recall_class( def test_precision_recall_fn( self, - preds, - target, - sk_wrapper, - metric_class, - metric_fn, - sk_fn, - is_multiclass, - num_classes, - average, - mdmc_average, - ignore_index, - ): + preds: torch.Tensor, + target: torch.Tensor, + sk_wrapper: Callable, + metric_class: Metric, + metric_fn: Callable, + sk_fn: Callable, + is_multiclass: Optional[bool], + num_classes: Optional[int], + average: str, + mdmc_average: Optional[str], + ignore_index: Optional[int], + ):bool if num_classes == 1 and average != "micro": pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)") From 67d970287b6120add224b1185fd9924111e6b0c7 Mon Sep 17 00:00:00 2001 From: Tadej Date: Sun, 17 Jan 2021 19:02:01 +0100 Subject: [PATCH 61/61] Add matc str to pytest.raises --- .../classification/test_precision_recall.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/metrics/classification/test_precision_recall.py b/tests/metrics/classification/test_precision_recall.py index a0226382552f7..17fdd8befc9d5 100644 --- a/tests/metrics/classification/test_precision_recall.py +++ b/tests/metrics/classification/test_precision_recall.py @@ -70,16 +70,16 @@ def _sk_prec_recall_mdmc(preds, target, sk_fn, num_classes, average, is_multicla @pytest.mark.parametrize("metric, fn_metric", [(Precision, precision), (Recall, recall)]) @pytest.mark.parametrize( - "average, mdmc_average, num_classes, ignore_index", + "average, mdmc_average, num_classes, ignore_index, match_str", [ - ("wrong", None, None, None), - ("micro", "wrong", None, None), - ("macro", None, None, None), - ("macro", None, 1, 0), + ("wrong", None, None, None, "`average`"), + ("micro", "wrong", None, None, "`mdmc"), + ("macro", None, None, None, "number of classes"), + ("macro", None, 1, 0, "ignore_index"), ], ) -def test_wrong_params(metric, fn_metric, average, mdmc_average, num_classes, ignore_index): - with pytest.raises(ValueError): +def test_wrong_params(metric, fn_metric, average, mdmc_average, num_classes, ignore_index, match_str): + with pytest.raises(ValueError, match=match_str): metric( average=average, mdmc_average=mdmc_average, @@ -87,7 +87,7 @@ def test_wrong_params(metric, fn_metric, average, mdmc_average, num_classes, ign ignore_index=ignore_index, ) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=match_str): fn_metric( _binary_inputs.preds[0], _binary_inputs.target[0], @@ -97,7 +97,7 @@ def test_wrong_params(metric, fn_metric, average, mdmc_average, num_classes, ign ignore_index=ignore_index, ) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=match_str): precision_recall( _binary_inputs.preds[0], _binary_inputs.target[0], @@ -236,7 +236,7 @@ def test_precision_recall_fn( average: str, mdmc_average: Optional[str], ignore_index: Optional[int], - ):bool + ): if num_classes == 1 and average != "micro": pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)")