From 6959ea03c553a6753cc973abb58f240a096bc1c1 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 24 Nov 2020 20:48:57 +0100 Subject: [PATCH 01/38] Add stuff --- .../metrics/classification/utils.py | 385 ++++++++++++++++++ pytorch_lightning/metrics/utils.py | 57 ++- tests/metrics/classification/inputs.py | 18 +- tests/metrics/classification/test_inputs.py | 301 ++++++++++++++ 4 files changed, 738 insertions(+), 23 deletions(-) create mode 100644 pytorch_lightning/metrics/classification/utils.py create mode 100644 tests/metrics/classification/test_inputs.py diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py new file mode 100644 index 0000000000000..b8f5af2e988d8 --- /dev/null +++ b/pytorch_lightning/metrics/classification/utils.py @@ -0,0 +1,385 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Tuple, Optional + +import numpy as np +import torch + +from pytorch_lightning.metrics.utils import to_onehot, select_topk + + +def _check_classification_inputs( + preds: torch.Tensor, + target: torch.Tensor, + threshold: float, + num_classes: Optional[int] = None, + is_multiclass: bool = False, + top_k: int = 1, +) -> None: + """Performs error checking on inputs for classification. + + This ensures that preds and target take one of the shape/type combinations that are + specified in ``_input_format_classification`` docstring. It also checks the cases of + over-rides with ``is_multiclass`` by checking (for multi-class and multi-dim multi-class + cases) that there are only up to 2 distinct labels. + + In case where preds are floats (probabilities), it is checked whether they are in [0,1] interval. + + When ``num_classes`` is given, it is checked that it is consitent with input cases (binary, + multi-label, ...), and that, if availible, the implied number of classes in the ``C`` + dimension is consistent with it (as well as that max label in target is smaller than it). + + When ``num_classes`` is not specified in these cases, consistency of the highest target + value against ``C`` dimension is checked for (multi-dimensional) multi-class cases. + + If ``top_k`` is larger than one, then an error is raised if the inputs are not (multi-dim) + multi-class with probability predictions. + + Preds and target tensors are expected to be squeezed already - all dimensions should be + greater than 1, except perhaps the first one (N). + + Args: + preds: tensor with predictions + target: tensor with ground truth labels, always integers + threshold: + Threshold probability value for transforming probability predictions to binary + (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + num_classes: number of classes + is_multiclass: if True, treat binary and multi-label inputs as multi-class or multi-dim + multi-class with 2 classes, respectively. If False, treat multi-class and multi-dim + multi-class inputs with 1 or 2 classes as binary and multi-label, respectively. + Defaults to None, which treats inputs as they appear. + """ + + if target.is_floating_point(): + raise ValueError("target has to be an integer tensor") + elif target.min() < 0: + raise ValueError("target has to be a non-negative tensor") + + preds_float = preds.is_floating_point() + if not preds_float and preds.min() < 0: + raise ValueError("if preds are integers, they have to be non-negative") + + if not preds.shape[0] == target.shape[0]: + raise ValueError("preds and target should have the same first dimension.") + + if preds_float: + if preds.min() < 0 or preds.max() > 1: + raise ValueError( + "preds should be probabilities, but values were detected outside of [0,1] range" + ) + + if threshold > 1 or threshold < 0: + raise ValueError("Threshold should be a probability in [0,1]") + + if is_multiclass is False and target.max() > 1: + raise ValueError("If you set is_multiclass=False, then target should not exceed 1.") + + if is_multiclass is False and not preds_float and preds.max() > 1: + raise ValueError("If you set is_multiclass=False and preds are integers, then preds should not exceed 1.") + + # Check that shape/types fall into one of the cases + if len(preds.shape) == len(target.shape): + if preds.shape != target.shape: + raise ValueError("if preds and target have the same number of dimensions, they should have the same shape") + if preds_float and target.max() > 1: + raise ValueError("if preds and target are of shape (N, ...) and preds are floats, target should be binary") + + # Get the case + if len(preds.shape) == 1 and preds_float: + case = "binary" + elif len(preds.shape) == 1 and not preds_float: + case = "multi-class" + elif len(preds.shape) > 1 and preds_float: + case = "multi-label" + else: + case = "multi-dim multi-class" + + implied_classes = torch.prod(torch.Tensor(list(preds.shape[1:]))) + + elif len(preds.shape) == len(target.shape) + 1: + if not preds_float: + raise ValueError("if preds have one dimension more than target, preds should be a float tensor") + if not preds.shape[:-1] == target.shape: + if preds.shape[2:] != target.shape[1:]: + raise ValueError( + "if preds if preds have one dimension more than target, the shape of preds should be" + "either of shape (N, C, ...) or (N, ..., C), and of targets of shape (N, ...)" + ) + + extra_dim_size = preds.shape[-1 if preds.shape[:-1] == target.shape else 1] + + if len(preds.shape) == 2: + case = "multi-class" + else: + case = "multi-dim multi-class" + else: + raise ValueError( + "preds and target should both have the (same) shape (N, ...), or target (N, ...)" + " and preds (N, C, ...) or (N, ..., C)" + ) + + if preds.shape != target.shape and is_multiclass is False and extra_dim_size != 2: + raise ValueError( + "You have set is_multiclass=False, but have more than 2 classes in your data," + " based on the C dimension of preds." + ) + + # Check that num_classes is consistent + if not num_classes: + if preds.shape != target.shape and target.max() >= extra_dim_size: + raise ValueError("The highest label in targets should be smaller than the size of C dimension") + else: + if case == "binary": + if num_classes > 2: + raise ValueError("Your data is binary, but num_classes is larger than 2.") + elif num_classes == 2 and not is_multiclass: + raise ValueError( + "Your data is binary and num_classes=2, but is_multiclass is not True." + "Set it to True if you want to transform binary data to multi-class format." + ) + elif num_classes == 1 and is_multiclass: + raise ValueError( + "You have binary data and have set is_multiclass=True, but num_classes is 1." + "Either leave is_multiclass unset or set it to 2 to transform binary data to multi-class format." + ) + elif "multi-class" in case: + if num_classes == 1 and is_multiclass is not False: + raise ValueError( + "You have set num_classes=1, but predictions are integers." + "If you want to convert (multi-dimensional) multi-class data with 2 classes" + "to binary/multi-label, set is_multiclass=False." + ) + elif num_classes > 1: + if is_multiclass is False: + if implied_classes != num_classes: + raise ValueError( + "You have set is_multiclass=False, but the implied number of classes " + "(from shape of inputs) does not match num_classes. If you are trying to" + "transform multi-dim multi-class data with 2 classes to multi-label, num_classes" + "should be either None or the product of the size of extra dimensions (...)." + "See Input Types in Metrics documentation." + ) + if num_classes <= target.max(): + raise ValueError("The highest label in targets should be smaller than num_classes") + if num_classes <= preds.max(): + raise ValueError("The highest label in preds should be smaller than num_classes") + if preds.shape != target.shape and num_classes != extra_dim_size: + raise ValueError("The size of C dimension of preds does not match num_classes") + + elif case == "multi-label": + if is_multiclass and num_classes != 2: + raise ValueError( + "Your have set is_multiclass=True, but num_classes is not equal to 2." + "If you are trying to transform multi-label data to 2 class multi-dimensional" + "multi-class, you should set num_classes to either 2 or None." + ) + if not is_multiclass and num_classes != implied_classes: + raise ValueError("The implied number of classes (from shape of inputs) does not match num_classes.") + + # Check that if top_k > 1, we have (multi-class) multi-dim with probabilities + if top_k > 1: + if preds.shape == target.shape: + raise ValueError( + "You have set top_k above 1, but your data is not (multi-dimensional) multi-class" + "with probability predictions." + ) + + +def _input_format_classification( + preds: torch.Tensor, + target: torch.Tensor, + threshold: float = 0.5, + top_k: int = 1, + num_classes: Optional[int] = None, + is_multiclass: Optional[bool] = None, +) -> Tuple[torch.Tensor, torch.Tensor, str]: + """Convert preds and target tensors into common format. + + Preds and targets are supposed to fall into one of these categories (and are + validated to make sure this is the case): + + * Both preds and target are of shape ``(N,)``, and both are integers (multi-class) + * Both preds and target are of shape ``(N,)``, and target is binary, while preds + are a float (binary) + * preds are of shape ``(N, C)`` and are floats, and target is of shape ``(N,)`` and + is integer (multi-class) + * preds and target are of shape ``(N, ...)``, target is binary and preds is a float + (multi-label) + * preds are of shape ``(N, ..., C)`` or ``(N, C, ...)`` and are floats, target is of + shape ``(N, ...)`` and is integer (multi-dimensional multi-class) + * preds and target are of shape ``(N, ...)`` both are integers (multi-dimensional + multi-class) + + To avoid ambiguities, all dimensions of size 1, except the first one, are squeezed out. + + The returned output tensors will be binary tensors of the same shape, either ``(N, C)`` + of ``(N, C, X)``, the details for each case are described below. The function also returns + a ``mode`` string, which describes which of the above cases the inputs belonged to - regardless + of whether this was "overridden" by other settings (like ``is_multiclass``). + + In binary case, targets are normally returned as ``(N,1)`` tensor, while preds are transformed + into a binary tensor (elements become 1 if the probability is greater than or equal to + ``threshold`` or 0 otherwise). If ``is_multiclass=True``, then then both targets are preds + become ``(N, 2)`` tensors by a one-hot transformation; with the thresholding being applied to + preds first. + + In multi-class case, normally both preds and targets become ``(N, C)`` binary tensors; targets + by a one-hot transformation and preds by selecting ``top_k`` largest entries (if their original + shape was ``(N,C)``). However, if ``is_multiclass=False``, then targets and preds will be + returned as ``(N,1)`` tensor. + + In multi-label case, normally targets and preds are returned as ``(N, C)`` binary tensors, with + preds being binarized as in the binary case. Here the ``C`` dimension is obtained by flattening + all dimensions after the first one. However if ``is_multiclass=True``, then both are returned as + ``(N, 2, C)``, by an equivalent transformation as in the binary case. + + In multi-dimensional multi-class case, normally both target and preds are returned as + ``(N, C, X)`` tensors, with ``X`` resulting from flattening of all dimensions except ``N`` and + ``C``. The transformations performed here are equivalent to the multi-class case. However, if + ``is_multiclass=False`` (and there are up to two classes), then the data is returned as + ``(N, X)`` binary tensors (multi-label). + + Also, in multi-dimensional multi-class case, if the position of the ``C`` + dimension is ambiguous (e.g. if targets are a ``(7, 3)`` tensor, while predictions are a + ``(7, 3, 3)`` tensor), it will be assumed that the ``C`` dimension is the second dimension. + If this is not the case, you should move it from the last to second place using + ``torch.movedim(preds, -1, 1)``. + + Note that where a one-hot transformation needs to be performed and the number of classes + is not implicitly given by a ``C`` dimension, the new ``C`` dimension will either be + equal to ``num_classes``, if it is given, or the maximum label value in preds and + target. + + Args: + preds: tensor with predictions + target: tensor with ground truth labels, always integers + threshold: + Threshold probability value for transforming probability predictions to binary + (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + num_classes: number of classes + top_k: number of highest probability entries for each sample to convert to 1s, relevant + only for (multi-dimensional) multi-class cases. + is_multiclass: if True, treat binary and multi-label inputs as multi-class or multi-dim + multi-class with 2 classes, respectively. If False, treat multi-class and multi-dim + multi-class inputs with 1 or 2 classes as binary and multi-label, respectively. + Defaults to None, which treats inputs as they appear. + + Returns: + preds: binary tensor of shape (N, C) or (N, C, X) + target: binary tensor of shape (N, C) or (N, C, X) + """ + preds, target = preds.clone().detach(), target.clone().detach() + + # Remove excess dimensions + if preds.shape[0] == 1: + preds, target = preds.squeeze().unsqueeze(0), target.squeeze().unsqueeze(0) + else: + preds, target = preds.squeeze(), target.squeeze() + + _check_classification_inputs( + preds, + target, + threshold=threshold, + num_classes=num_classes, + is_multiclass=is_multiclass, + top_k=top_k, + ) + + preds_float = preds.is_floating_point() + + if len(preds.shape) == len(target.shape) == 1 and preds_float: + mode = "binary" + preds = (preds >= threshold).int() + + if is_multiclass: + target = to_onehot(target, 2) + preds = to_onehot(preds, 2) + else: + preds = preds.unsqueeze(-1) + target = target.unsqueeze(-1) + + elif len(preds.shape) == len(target.shape) and preds_float: + mode = "multi-label" + preds = (preds >= threshold).int() + + if is_multiclass: + preds = to_onehot(preds, 2).reshape(preds.shape[0], 2, -1) + target = to_onehot(target, 2).reshape(target.shape[0], 2, -1) + else: + preds = preds.reshape(preds.shape[0], -1) + target = target.reshape(target.shape[0], -1) + + elif len(preds.shape) == len(target.shape) + 1 == 2: + mode = "multi-class" + if not num_classes: + num_classes = preds.shape[1] + + target = to_onehot(target, num_classes) + preds = select_topk(preds, top_k) + + # If is_multiclass=False, force to binary + if is_multiclass is False: + target = target[:, [1]] + preds = preds[:, [1]] + + elif len(preds.shape) == len(target.shape) == 1 and not preds_float: + mode = "multi-class" + + if not num_classes: + num_classes = max(preds.max(), target.max()) + 1 + + # If is_multiclass=False, force to binary + if is_multiclass is False: + preds = preds.unsqueeze(1) + target = target.unsqueeze(1) + else: + preds = to_onehot(preds, num_classes) + target = to_onehot(target, num_classes) + + # Multi-dim multi-class (N, ...) with integers + elif preds.shape == target.shape and not preds_float: + mode = "multi-dim multi-class" + + if not num_classes: + num_classes = max(preds.max(), target.max()) + 1 + + # If is_multiclass=False, force to multi-label + if is_multiclass is False: + preds = preds.reshape(preds.shape[0], -1) + target = target.reshape(target.shape[0], -1) + else: + target = to_onehot(target, num_classes) + target = target.reshape(target.shape[0], target.shape[1], -1) + preds = to_onehot(preds, num_classes) + preds = preds.reshape(preds.shape[0], preds.shape[1], -1) + + # Multi-dim multi-class (N, C, ...) and (N, ..., C) + else: + mode = "multi-dim multi-class" + if preds.shape[:-1] == target.shape: + preds = torch.movedim(preds, -1, 1) + + num_classes = preds.shape[1] + + if is_multiclass is False: + target = target.reshape(target.shape[0], -1) + preds = select_topk(preds, 1)[:, 1, ...] + preds = preds.reshape(preds.shape[0], -1) + else: + target = to_onehot(target, num_classes) + target = target.reshape(target.shape[0], target.shape[1], -1) + preds = select_topk(preds, top_k).reshape(preds.shape[0], preds.shape[1], -1) + + return preds, target, mode diff --git a/pytorch_lightning/metrics/utils.py b/pytorch_lightning/metrics/utils.py index e1ff95b94f471..1ce56b30cf9e5 100644 --- a/pytorch_lightning/metrics/utils.py +++ b/pytorch_lightning/metrics/utils.py @@ -20,6 +20,7 @@ def dim_zero_cat(x): + x = x if isinstance(x, (list, tuple)) else [x] return torch.cat(x, dim=0) @@ -36,8 +37,8 @@ def _flatten(x): def to_onehot( - tensor: torch.Tensor, - num_classes: int, + tensor: torch.Tensor, + num_classes: int, ) -> torch.Tensor: """ Converts a dense label tensor to one-hot format @@ -57,24 +58,46 @@ def to_onehot( [0, 0, 0, 1]]) """ dtype, device, shape = tensor.dtype, tensor.device, tensor.shape - tensor_onehot = torch.zeros(shape[0], num_classes, *shape[1:], - dtype=dtype, device=device) + tensor_onehot = torch.zeros(shape[0], num_classes, *shape[1:], dtype=dtype, device=device) index = tensor.long().unsqueeze(1).expand_as(tensor_onehot) return tensor_onehot.scatter_(1, index, 1.0) +def select_topk(tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch.Tensor: + """ + Convert a probability tensor to binary by selecting top-k highest entries. + + Args: + tensor: dense tensor of shape ``[..., C, ...]``, where ``C`` is in the + position defined by the ``dim`` argument + topk: number of highest entries to turn into 1s + dim: dimension on which to compare entries + + Output: + A binary tensor of the same shape as the input tensor of type torch.int32 + + Example: + >>> x = torch.tensor([[1.1, 2.0, 3.0], [2.0, 1.0, 0.5]]) + >>> select_topk(x, topk=2) + tensor([[0, 1, 1], + [1, 1, 0]], dtype=torch.int32) + """ + zeros = torch.zeros_like(tensor, device=tensor.device) + topk_tensor = zeros.scatter(1, tensor.topk(k=topk, dim=dim).indices, 1.0) + + return topk_tensor.int() + + def _check_same_shape(pred: torch.Tensor, target: torch.Tensor): """ Check that predictions and target have the same shape, else raise error """ if pred.shape != target.shape: - raise RuntimeError('Predictions and targets are expected to have the same shape') + raise RuntimeError("Predictions and targets are expected to have the same shape") def _input_format_classification( - preds: torch.Tensor, - target: torch.Tensor, - threshold: float = 0.5 + preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5 ) -> Tuple[torch.Tensor, torch.Tensor]: - """ Convert preds and target tensors into label tensors + """Convert preds and target tensors into label tensors Args: preds: either tensor with labels, tensor with probabilities/logits or @@ -87,9 +110,7 @@ def _input_format_classification( target: tensor with labels """ if not (len(preds.shape) == len(target.shape) or len(preds.shape) == len(target.shape) + 1): - raise ValueError( - "preds and target must have same number of dimensions, or one additional dimension for preds" - ) + raise ValueError("preds and target must have same number of dimensions, or one additional dimension for preds") if len(preds.shape) == len(target.shape) + 1: # multi class probabilites @@ -102,13 +123,9 @@ def _input_format_classification( def _input_format_classification_one_hot( - num_classes: int, - preds: torch.Tensor, - target: torch.Tensor, - threshold: float = 0.5, - multilabel: bool = False + num_classes: int, preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5, multilabel: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: - """ Convert preds and target tensors into one hot spare label tensors + """Convert preds and target tensors into one hot spare label tensors Args: num_classes: number of classes @@ -123,9 +140,7 @@ def _input_format_classification_one_hot( target: one hot tensors of shape [num_classes, -1] with true labels """ if not (len(preds.shape) == len(target.shape) or len(preds.shape) == len(target.shape) + 1): - raise ValueError( - "preds and target must have same number of dimensions, or one additional dimension for preds" - ) + raise ValueError("preds and target must have same number of dimensions, or one additional dimension for preds") if len(preds.shape) == len(target.shape) + 1: # multi class probabilites diff --git a/tests/metrics/classification/inputs.py b/tests/metrics/classification/inputs.py index 9613df3b6f8ca..e648aaf10093e 100644 --- a/tests/metrics/classification/inputs.py +++ b/tests/metrics/classification/inputs.py @@ -29,12 +29,21 @@ target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)) ) +_multilabel_multidim_prob_inputs = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)) +) _multilabel_inputs = Input( preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)) ) +_multilabel_multidim_inputs = Input( + preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)) +) + # Generate edge multilabel edge case, where nothing matches (scores are undefined) __temp_preds = torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)) __temp_target = abs(__temp_preds - 1) @@ -61,8 +70,13 @@ target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)) ) +# Class dimension last +_multidim_multiclass_prob_inputs1 = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM, NUM_CLASSES), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)) +) _multidim_multiclass_inputs = Input( - preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, EXTRA_DIM, BATCH_SIZE)), - target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, EXTRA_DIM, BATCH_SIZE)) + preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)) ) diff --git a/tests/metrics/classification/test_inputs.py b/tests/metrics/classification/test_inputs.py new file mode 100644 index 0000000000000..8d17d5624fac0 --- /dev/null +++ b/tests/metrics/classification/test_inputs.py @@ -0,0 +1,301 @@ +import pytest +import torch +from torch import randint, rand + +from pytorch_lightning.metrics.utils import to_onehot, select_topk +from pytorch_lightning.metrics.classification.utils import _input_format_classification +from tests.metrics.classification.inputs import ( + Input, + _binary_inputs as _bin, + _binary_prob_inputs as _bin_prob, + _multiclass_inputs as _mc, + _multiclass_prob_inputs as _mc_prob, + _multidim_multiclass_inputs as _mdmc, + _multidim_multiclass_prob_inputs as _mdmc_prob, + _multidim_multiclass_prob_inputs1 as _mdmc_prob1, + _multilabel_inputs as _ml, + _multilabel_prob_inputs as _ml_prob, + _multilabel_multidim_inputs as _mlmd, + _multilabel_multidim_prob_inputs as _mlmd_prob, +) +from tests.metrics.utils import NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM, THRESHOLD + +torch.manual_seed(42) + +# Some additional inputs to test on +_mc_prob_2cls = Input(rand(NUM_BATCHES, BATCH_SIZE, 2), randint(high=2, size=(NUM_BATCHES, BATCH_SIZE))) +_mdmc_prob_many_dims = Input( + rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM, EXTRA_DIM), + randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM, EXTRA_DIM)), +) +_mdmc_prob_many_dims1 = Input( + rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM, EXTRA_DIM, NUM_CLASSES), + randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM, EXTRA_DIM)), +) +_mdmc_prob_2cls = Input( + rand(NUM_BATCHES, BATCH_SIZE, 2, EXTRA_DIM), randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)) +) +_mdmc_prob_2cls1 = Input( + rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM, 2), randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)) +) + +# Some utils +T = torch.Tensor + + +def idn(x): + return x + + +def usq(x): + return x.unsqueeze(-1) + + +def toint(x): + return x.int() + + +def thrs(x): + return x >= THRESHOLD + + +def rshp1(x): + return x.reshape(x.shape[0], -1) + + +def rshp2(x): + return x.reshape(x.shape[0], x.shape[1], -1) + + +def onehot(x): + return to_onehot(x, NUM_CLASSES) + + +def onehot2(x): + return to_onehot(x, 2) + + +def top1(x): + return select_topk(x, 1) + + +def top2(x): + return select_topk(x, 2) + + +def mvdim(x): + return torch.movedim(x, -1, 1) + + +# To avoid ugly black line wrapping +def ml_preds_tr(x): + return rshp1(toint(thrs(x))) + + +def onehot_rshp1(x): + return onehot(rshp1(x)) + + +def onehot2_rshp1(x): + return onehot2(rshp1(x)) + + +def top1_rshp2(x): + return top1(rshp2(x)) + + +def top2_rshp2(x): + return top2(rshp2(x)) + + +def mdmc1_top1_tr(x): + return top1(rshp2(mvdim(x))) + + +def mdmc1_top2_tr(x): + return top2(rshp2(mvdim(x))) + + +def probs_to_mc_preds_tr(x): + return toint(onehot2(thrs(x))) + + +def mlmd_prob_to_mc_preds_tr(x): + return onehot2(rshp1(toint(thrs(x)))) + + +def mdmc_prob_to_ml_preds_tr(x): + return top1(mvdim(x))[:, 1] + + +######################## +# Test correct inputs +######################## + + +@pytest.mark.parametrize( + "inputs, threshold, num_classes, is_multiclass, top_k, exp_mode, post_preds, post_target", + [ + ############################# + # Test usual expected cases + (_bin, THRESHOLD, None, False, 1, "multi-class", usq, usq), + (_bin_prob, THRESHOLD, None, None, 1, "binary", lambda x: usq(toint(thrs(x))), usq), + (_ml_prob, THRESHOLD, None, None, 1, "multi-label", lambda x: toint(thrs(x)), idn), + (_ml, THRESHOLD, None, False, 1, "multi-dim multi-class", idn, idn), + (_ml_prob, THRESHOLD, None, None, 1, "multi-label", ml_preds_tr, rshp1), + (_mlmd, THRESHOLD, None, False, 1, "multi-dim multi-class", rshp1, rshp1), + (_mc, THRESHOLD, NUM_CLASSES, None, 1, "multi-class", onehot, onehot), + (_mc_prob, THRESHOLD, None, None, 1, "multi-class", top1, onehot), + (_mc_prob, THRESHOLD, None, None, 2, "multi-class", top2, onehot), + (_mdmc, THRESHOLD, NUM_CLASSES, None, 1, "multi-dim multi-class", onehot, onehot), + (_mdmc_prob, THRESHOLD, None, None, 1, "multi-dim multi-class", top1_rshp2, onehot), + (_mdmc_prob, THRESHOLD, None, None, 2, "multi-dim multi-class", top2_rshp2, onehot), + (_mdmc_prob_many_dims, THRESHOLD, None, None, 1, "multi-dim multi-class", top1_rshp2, onehot_rshp1), + (_mdmc_prob_many_dims, THRESHOLD, None, None, 2, "multi-dim multi-class", top2_rshp2, onehot_rshp1), + # Test with C dim in last place + (_mdmc_prob1, THRESHOLD, None, None, 1, "multi-dim multi-class", mdmc1_top1_tr, onehot), + (_mdmc_prob1, THRESHOLD, None, None, 2, "multi-dim multi-class", mdmc1_top2_tr, onehot), + (_mdmc_prob_many_dims1, THRESHOLD, None, None, 1, "multi-dim multi-class", mdmc1_top1_tr, onehot_rshp1), + (_mdmc_prob_many_dims1, THRESHOLD, None, None, 2, "multi-dim multi-class", mdmc1_top2_tr, onehot_rshp1), + ########################### + # Test some special cases + # Binary as multiclass + (_bin, THRESHOLD, None, None, 1, "multi-class", onehot2, onehot2), + # Binary probs as multiclass + (_bin_prob, THRESHOLD, None, True, 1, "binary", probs_to_mc_preds_tr, onehot2), + # Multilabel as multiclass + (_ml, THRESHOLD, None, True, 1, "multi-dim multi-class", onehot2, onehot2), + # Multilabel probs as multiclass + (_ml_prob, THRESHOLD, None, True, 1, "multi-label", probs_to_mc_preds_tr, onehot2), + # Multidim multilabel as multiclass + (_mlmd, THRESHOLD, None, True, 1, "multi-dim multi-class", onehot2_rshp1, onehot2_rshp1), + # Multidim multilabel probs as multiclass + (_mlmd_prob, THRESHOLD, None, True, 1, "multi-label", mlmd_prob_to_mc_preds_tr, onehot2_rshp1), + # Multiclass prob with 2 classes as binary + (_mc_prob_2cls, THRESHOLD, None, False, 1, "multi-class", lambda x: top1(x)[:, [1]], usq), + # Multi-dim multi-class with 2 classes as multi-label + (_mdmc_prob_2cls, THRESHOLD, None, False, 1, "multi-dim multi-class", lambda x: top1(x)[:, 1], idn), + (_mdmc_prob_2cls1, THRESHOLD, None, False, 1, "multi-dim multi-class", mdmc_prob_to_ml_preds_tr, idn), + ], +) +def test_usual_cases(inputs, threshold, num_classes, is_multiclass, top_k, exp_mode, post_preds, post_target): + preds_out, target_out, mode = _input_format_classification( + preds=inputs.preds[0], + target=inputs.target[0], + threshold=threshold, + num_classes=num_classes, + is_multiclass=is_multiclass, + top_k=top_k, + ) + + assert mode == exp_mode + assert torch.equal(preds_out, post_preds(inputs.preds[0])) + assert torch.equal(target_out, post_target(inputs.target[0])) + + # Test that things work when batch_size = 1 + preds_out, target_out, mode = _input_format_classification( + preds=inputs.preds[0][[0], ...], + target=inputs.target[0][[0], ...], + threshold=threshold, + num_classes=num_classes, + is_multiclass=is_multiclass, + top_k=top_k, + ) + + assert mode == exp_mode + assert torch.equal(preds_out, post_preds(inputs.preds[0][[0], ...])) + assert torch.equal(target_out, post_target(inputs.target[0][[0], ...])) + + +# Test that threshold is correctly applied +def test_threshold(): + target = T([1, 1, 1]).int() + preds_probs = T([0.5 - 1e-5, 0.5, 0.5 + 1e-5]) + + preds_probs_out, _, _ = _input_format_classification(preds_probs, target, threshold=0.5) + + assert torch.equal(torch.tensor([0, 1, 1]), preds_probs_out.squeeze().long()) + + +######################################################################## +# Test incorrect inputs +######################################################################## + + +@pytest.mark.parametrize( + "preds, target, threshold, num_classes, is_multiclass, top_k", + [ + # Target not integer + (randint(high=2, size=(7,)), randint(high=2, size=(7,)).float(), 0.5, None, None, 1), + # Target negative + (randint(high=2, size=(7,)), -randint(high=2, size=(7,)), 0.5, None, None, 1), + # Preds negative integers + (-randint(high=2, size=(7,)), randint(high=2, size=(7,)), 0.5, None, None, 1), + # Negative probabilities + (-rand(size=(7,)), randint(high=2, size=(7,)), 0.5, None, None, 1), + # Threshold outside of [0,1] + (rand(size=(7,)), randint(high=2, size=(7,)), 1.5, None, None, 1), + # is_multiclass=False and target > 1 + (rand(size=(7,)), randint(low=2, high=4, size=(7,)), 0.5, None, False, 1), + # is_multiclass=False and preds integers with > 1 + (randint(low=2, high=4, size=(7,)), randint(high=2, size=(7,)), 0.5, None, False, 1), + # Wrong batch size + (randint(high=2, size=(8,)), randint(high=2, size=(7,)), 0.5, None, None, 1), + # Completely wrong shape + (randint(high=2, size=(7,)), randint(high=2, size=(7, 4)), 0.5, None, None, 1), + # Same #dims, different shape + (randint(high=2, size=(7, 3)), randint(high=2, size=(7, 4)), 0.5, None, None, 1), + # Same shape and preds floats, target not binary + (rand(size=(7, 3)), randint(low=2, high=4, size=(7, 3)), 0.5, None, None, 1), + # #dims in preds = 1 + #dims in target, C shape not second or last + (rand(size=(7, 3, 4, 3)), randint(high=4, size=(7, 3, 3)), 0.5, None, None, 1), + # #dims in preds = 1 + #dims in target, preds not float + (randint(high=2, size=(7, 3, 3, 4)), randint(high=4, size=(7, 3, 3)), 0.5, None, None, 1), + # is_multiclass=False, with C dimension > 2 + (rand(size=(7, 3, 5)), randint(high=2, size=(7, 5)), 0.5, None, False, 1), + # Max target larger or equal to C dimension + (rand(size=(7, 3)), randint(low=4, high=6, size=(7,)), 0.5, None, None, 1), + # C dimension not equal to num_classes + (rand(size=(7, 3, 4)), randint(high=4, size=(7, 3)), 0.5, 7, None, 1), + # Max target larger than num_classes (with #dim preds = 1 + #dims target) + (rand(size=(7, 3, 4)), randint(low=5, high=7, size=(7, 3)), 0.5, 4, None, 1), + # Max target larger than num_classes (with #dim preds = #dims target) + (randint(high=4, size=(7, 3)), randint(low=5, high=7, size=(7, 3)), 0.5, 4, None, 1), + # Max preds larger than num_classes (with #dim preds = #dims target) + (randint(low=5, high=7, size=(7, 3)), randint(high=4, size=(7, 3)), 0.5, 4, None, 1), + # Num_classes=1, but is_multiclass not false + (randint(high=2, size=(7,)), randint(high=2, size=(7,)), 0.5, 1, None, 1), + # is_multiclass=False, but implied class dimension (for multi-label, from shape) != num_classes + (randint(high=2, size=(7, 3, 3)), randint(high=2, size=(7, 3, 3)), 0.5, 4, False, 1), + # Multilabel input with implied class dimension != num_classes + (rand(size=(7, 3, 3)), randint(high=2, size=(7, 3, 3)), 0.5, 4, False, 1), + # Multilabel input with is_multiclass=True, but num_classes != 2 (or None) + (rand(size=(7, 3)), randint(high=2, size=(7, 3)), 0.5, 4, True, 1), + # Binary input, num_classes > 2 + (rand(size=(7,)), randint(high=2, size=(7,)), 0.5, 4, None, 1), + # Binary input, num_classes == 2 and is_multiclass not True + (rand(size=(7,)), randint(high=2, size=(7,)), 0.5, 2, None, 1), + (rand(size=(7,)), randint(high=2, size=(7,)), 0.5, 2, False, 1), + # Binary input, num_classes == 1 and is_multiclass=True + (rand(size=(7,)), randint(high=2, size=(7,)), 0.5, 1, True, 1), + # Topk > 1 with non (md)mc prob data + (_bin.preds[0], _bin.target[0], 0.5, None, None, 2), + (_bin_prob.preds[0], _bin_prob.target[0], 0.5, None, None, 2), + (_mc.preds[0], _mc.target[0], 0.5, None, None, 2), + (_ml.preds[0], _ml.target[0], 0.5, None, None, 2), + (_mlmd.preds[0], _mlmd.target[0], 0.5, None, None, 2), + (_ml_prob.preds[0], _ml_prob.target[0], 0.5, None, None, 2), + (_mlmd_prob.preds[0], _mlmd_prob.target[0], 0.5, None, None, 2), + (_mdmc.preds[0], _mdmc.target[0], 0.5, None, None, 2), + ], +) +def test_incorrect_inputs(preds, target, threshold, num_classes, is_multiclass, top_k): + with pytest.raises(ValueError): + _input_format_classification( + preds=preds, + target=target, + threshold=threshold, + num_classes=num_classes, + is_multiclass=is_multiclass, + top_k=top_k, + ) From 06790157ed025ba98950b9616ef43835f9b6c4e7 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 24 Nov 2020 21:47:30 +0100 Subject: [PATCH 02/38] Change metrics documentation layout --- docs/source/metrics.rst | 188 ++++++++++++++++++++++++++-------------- 1 file changed, 121 insertions(+), 67 deletions(-) diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index d47c872f35047..407b64d3d2948 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -196,12 +196,71 @@ Metric API .. autoclass:: pytorch_lightning.metrics.Metric :noindex: -************* -Class metrics -************* +*************************** +Class vs Functional Metrics +*************************** +The functional metrics follow the simple paradigm input in, output out. This means, they don't provide any advanced mechanisms for syncing across DDP nodes or aggregation over batches. They simply compute the metric value based on the given inputs. + +Also the integration within other parts of PyTorch Lightning will never be as tight as with the class-based interface. +If you look for just computing the values, the functional metrics are the way to go. However, if you are looking for the best integration and user experience, please consider also to use the class interface. + +********************** Classification Metrics ----------------------- +********************** + +Input types +----------- + +For the purposes of classification metrics, inputs (predictions and targets) are split +into these categories (``N`` stands for the batch size and ``C`` for number of classes): + +.. csv-table:: \*dtype ``binary`` means integers that are either 0 or 1 + :header: "Type", "preds shape", "preds dtype", "target shape", "target dtype" + :widths: 20, 10, 10, 10, 10 + + "Binary", "(N,)", "``float``", "(N,)", "``binary``\*" + "Multi-class", "(N,)", "``int``", "(N,)", "``int``" + "Multi-class with probabilities", "(N, C)", "``float``", "(N,)", "``int``" + "Multi-label", "(N, ...)", "``float``", "(N, ...)", "``binary``\*" + "Multi-dimensional multi-class", "(N, ...)", "``int``", "(N, ...)", "``int``" + "Multi-dimensional multi-class with probabilities", "(N, C, ...) or (N, ..., C)", "``float``", "(N, ...)", "``int``" + +.. note:: + All dimensions of size 1 (except ``N``) are "squeezed out" at the beginning, so + that, for example, a tensor of shape ``(N, 1)`` is treated as ``(N, )``. + +When predictions or targets are integers, it is assumed that class labels start at , i.e. +the possible class labels are 0, 1, 2, 3, etc. Below are some examples of different input types + +.. code-block:: python + + # Binary inputs + binary_preds = torch.tensor([0.6, 0.1, 0.9]) + binary_target = torch.tensor([1, 0, 2]) + + # Multi-class inputs + mc_preds = torch.tensor([0, 2, 1]) + mc_target = torch.tensor([0, 1, 2]) + + # Multi-class inputs with probabilities + mc_preds_probs = torch.tensor([[0.8, 0.2, 0], [0.1, 0.2, 0.7], [0.3, 0.6, 0.1]]) + mc_target_probs = torch.tensor([0, 1, 2]) + + # Multi-label inputs + ml_preds = torch.tensor([[0.2, 0.8, 0.9], [0.5, 0.6, 0.1], [0.3, 0.1, 0.1]]) + ml_target = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 0]]) + +In some rare cases, you might have inputs which appear to be (multi-dimensional) multi-class, +but are actually binary/multi-label. For example, if both predictions and targets are 1d +binary tensors. Or it could be the other way around, you want to treat binary/multi-label +inputs as 2-class (multi-dimensional) multi-class inputs. + +For these cases, the metrics where this distinction would make a difference, expose the +``is_multiclass`` argument. + +Class Metrics (Classification) +------------------------------ Accuracy ~~~~~~~~ @@ -239,61 +298,8 @@ ConfusionMatrix .. autoclass:: pytorch_lightning.metrics.classification.ConfusionMatrix :noindex: -Regression Metrics ------------------- - -MeanSquaredError -~~~~~~~~~~~~~~~~ - -.. autoclass:: pytorch_lightning.metrics.regression.MeanSquaredError - :noindex: - - -MeanAbsoluteError -~~~~~~~~~~~~~~~~~ - -.. autoclass:: pytorch_lightning.metrics.regression.MeanAbsoluteError - :noindex: - - -MeanSquaredLogError -~~~~~~~~~~~~~~~~~~~ - -.. autoclass:: pytorch_lightning.metrics.regression.MeanSquaredLogError - :noindex: - - -ExplainedVariance -~~~~~~~~~~~~~~~~~ - -.. autoclass:: pytorch_lightning.metrics.regression.ExplainedVariance - :noindex: - - -PSNR -~~~~ - -.. autoclass:: pytorch_lightning.metrics.regression.PSNR - :noindex: - - -SSIM -~~~~ - -.. autoclass:: pytorch_lightning.metrics.regression.SSIM - :noindex: - -****************** -Functional Metrics -****************** - -The functional metrics follow the simple paradigm input in, output out. This means, they don't provide any advanced mechanisms for syncing across DDP nodes or aggregation over batches. They simply compute the metric value based on the given inputs. - -Also the integration within other parts of PyTorch Lightning will never be as tight as with the class-based interface. -If you look for just computing the values, the functional metrics are the way to go. However, if you are looking for the best integration and user experience, please consider also to use the class interface. - -Classification --------------- +Functional Metrics (Classification) +----------------------------------- accuracy [func] ~~~~~~~~~~~~~~~ @@ -434,9 +440,57 @@ to_onehot [func] .. autofunction:: pytorch_lightning.metrics.functional.classification.to_onehot :noindex: +****************** +Regression Metrics +****************** + +Class Metrics (Regression) +-------------------------- + +MeanSquaredError +~~~~~~~~~~~~~~~~ + +.. autoclass:: pytorch_lightning.metrics.regression.MeanSquaredError + :noindex: + + +MeanAbsoluteError +~~~~~~~~~~~~~~~~~ + +.. autoclass:: pytorch_lightning.metrics.regression.MeanAbsoluteError + :noindex: + + +MeanSquaredLogError +~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: pytorch_lightning.metrics.regression.MeanSquaredLogError + :noindex: + + +ExplainedVariance +~~~~~~~~~~~~~~~~~ + +.. autoclass:: pytorch_lightning.metrics.regression.ExplainedVariance + :noindex: + + +PSNR +~~~~ + +.. autoclass:: pytorch_lightning.metrics.regression.PSNR + :noindex: + -Regression ----------- +SSIM +~~~~ + +.. autoclass:: pytorch_lightning.metrics.regression.SSIM + :noindex: + + +Functional Metrics (Regression) +------------------------------- explained_variance [func] ~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -479,22 +533,22 @@ ssim [func] .. autofunction:: pytorch_lightning.metrics.functional.ssim :noindex: - +*** NLP ---- +*** bleu_score [func] -~~~~~~~~~~~~~~~~~ +----------------- .. autofunction:: pytorch_lightning.metrics.functional.nlp.bleu_score :noindex: - +******** Pairwise --------- +******** embedding_similarity [func] -~~~~~~~~~~~~~~~~~~~~~~~~~~~ +--------------------------- .. autofunction:: pytorch_lightning.metrics.functional.self_supervised.embedding_similarity :noindex: From 55fdaaf16a185d5ddf2cf37871f389658b8de2c4 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 24 Nov 2020 22:41:43 +0100 Subject: [PATCH 03/38] Change testing utils --- tests/metrics/utils.py | 161 +++++++++++++++++++++-------------------- 1 file changed, 82 insertions(+), 79 deletions(-) diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py index 34abee8473863..8ec14c41b1360 100644 --- a/tests/metrics/utils.py +++ b/tests/metrics/utils.py @@ -21,10 +21,10 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ - os.environ["MASTER_ADDR"] = 'localhost' - os.environ['MASTER_PORT'] = '8088' + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "8088" - if torch.distributed.is_available() and sys.platform not in ('win32', 'cygwin'): + if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) @@ -41,23 +41,23 @@ def _class_test( check_batch: bool = True, atol: float = 1e-8, ): - """ Utility function doing the actual comparison between lightning class metric - and reference metric. - - Args: - rank: rank of current process - worldsize: number of processes - preds: torch tensor with predictions - target: torch tensor with targets - metric_class: lightning metric class that should be tested - sk_metric: callable function that is used for comparison - dist_sync_on_step: bool, if true will synchronize metric state across - processes at each ``forward()`` - metric_args: dict with additional arguments used for class initialization - check_dist_sync_on_step: bool, if true will check if the metric is also correctly - calculated per batch per device (and not just at the end) - check_batch: bool, if true will check if the metric is also correctly - calculated across devices for each batch (and not just at the end) + """Utility function doing the actual comparison between lightning class metric + and reference metric. + + Args: + rank: rank of current process + worldsize: number of processes + preds: torch tensor with predictions + target: torch tensor with targets + metric_class: lightning metric class that should be tested + sk_metric: callable function that is used for comparison + dist_sync_on_step: bool, if true will synchronize metric state across + processes at each ``forward()`` + metric_args: dict with additional arguments used for class initialization + check_dist_sync_on_step: bool, if true will check if the metric is also correctly + calculated per batch per device (and not just at the end) + check_batch: bool, if true will check if the metric is also correctly + calculated across devices for each batch (and not just at the end) """ # Instanciate lightning metric metric = metric_class(compute_on_step=True, dist_sync_on_step=dist_sync_on_step, **metric_args) @@ -71,28 +71,28 @@ def _class_test( if metric.dist_sync_on_step: if rank == 0: - ddp_preds = torch.stack([preds[i + r] for r in range(worldsize)]) - ddp_target = torch.stack([target[i + r] for r in range(worldsize)]) + ddp_preds = torch.cat([preds[i + r] for r in range(worldsize)]) + ddp_target = torch.cat([target[i + r] for r in range(worldsize)]) sk_batch_result = sk_metric(ddp_preds, ddp_target) # assert for dist_sync_on_step if check_dist_sync_on_step: - assert np.allclose(batch_result.numpy(), sk_batch_result, atol=atol) + assert np.allclose(batch_result.numpy(), sk_batch_result, atol=atol, equal_nan=True) else: sk_batch_result = sk_metric(preds[i], target[i]) # assert for batch if check_batch: - assert np.allclose(batch_result.numpy(), sk_batch_result, atol=atol) + assert np.allclose(batch_result.numpy(), sk_batch_result, atol=atol, equal_nan=True) # check on all batches on all ranks result = metric.compute() assert isinstance(result, torch.Tensor) - total_preds = torch.stack([preds[i] for i in range(NUM_BATCHES)]) - total_target = torch.stack([target[i] for i in range(NUM_BATCHES)]) + total_preds = torch.cat([preds[i] for i in range(NUM_BATCHES)]) + total_target = torch.cat([target[i] for i in range(NUM_BATCHES)]) sk_result = sk_metric(total_preds, total_target) # assert after aggregation - assert np.allclose(result.numpy(), sk_result, atol=atol) + assert np.allclose(result.numpy(), sk_result, atol=atol, equal_nan=True) def _functional_test( @@ -101,17 +101,17 @@ def _functional_test( metric_functional: Callable, sk_metric: Callable, metric_args: dict = {}, - atol: float = 1e-8 + atol: float = 1e-8, ): - """ Utility function doing the actual comparison between lightning functional metric - and reference metric. - - Args: - preds: torch tensor with predictions - target: torch tensor with targets - metric_functional: lightning metric functional that should be tested - sk_metric: callable function that is used for comparison - metric_args: dict with additional arguments used for class initialization + """Utility function doing the actual comparison between lightning functional metric + and reference metric. + + Args: + preds: torch tensor with predictions + target: torch tensor with targets + metric_functional: lightning metric functional that should be tested + sk_metric: callable function that is used for comparison + metric_args: dict with additional arguments used for class initialization """ metric = partial(metric_functional, **metric_args) @@ -120,26 +120,27 @@ def _functional_test( sk_result = sk_metric(preds[i], target[i]) # assert its the same - assert np.allclose(lightning_result.numpy(), sk_result, atol=atol) + assert np.allclose(lightning_result.numpy(), sk_result, atol=atol, equal_nan=True) class MetricTester: - """ Class used for efficiently run alot of parametrized tests in ddp mode. - Makes sure that ddp is only setup once and that pool of processes are - used for all tests. + """Class used for efficiently run alot of parametrized tests in ddp mode. + Makes sure that ddp is only setup once and that pool of processes are + used for all tests. - All tests should subclass from this and implement a new method called - `test_metric_name` - where the method `self.run_metric_test` is called inside. + All tests should subclass from this and implement a new method called + `test_metric_name` + where the method `self.run_metric_test` is called inside. """ + atol = 1e-8 def setup_class(self): - """ Setup the metric class. This will spawn the pool of workers that are - used for metric testing and setup_ddp + """Setup the metric class. This will spawn the pool of workers that are + used for metric testing and setup_ddp """ try: - set_start_method('spawn') + set_start_method("spawn") except RuntimeError: pass self.poolSize = NUM_PROCESSES @@ -157,24 +158,26 @@ def run_functional_metric_test( target: torch.Tensor, metric_functional: Callable, sk_metric: Callable, - metric_args: dict = {} + metric_args: dict = {}, ): - """ Main method that should be used for testing functions. Call this inside - testing method - - Args: - preds: torch tensor with predictions - target: torch tensor with targets - metric_functional: lightning metric class that should be tested - sk_metric: callable function that is used for comparison - metric_args: dict with additional arguments used for class initialization + """Main method that should be used for testing functions. Call this inside + testing method + + Args: + preds: torch tensor with predictions + target: torch tensor with targets + metric_functional: lightning metric class that should be tested + sk_metric: callable function that is used for comparison + metric_args: dict with additional arguments used for class initialization """ - _functional_test(preds=preds, - target=target, - metric_functional=metric_functional, - sk_metric=sk_metric, - metric_args=metric_args, - atol=self.atol) + _functional_test( + preds=preds, + target=target, + metric_functional=metric_functional, + sk_metric=sk_metric, + metric_args=metric_args, + atol=self.atol, + ) def run_class_metric_test( self, @@ -188,22 +191,22 @@ def run_class_metric_test( check_dist_sync_on_step: bool = True, check_batch: bool = True, ): - """ Main method that should be used for testing class. Call this inside testing - methods. - - Args: - ddp: bool, if running in ddp mode or not - preds: torch tensor with predictions - target: torch tensor with targets - metric_class: lightning metric class that should be tested - sk_metric: callable function that is used for comparison - dist_sync_on_step: bool, if true will synchronize metric state across - processes at each ``forward()`` - metric_args: dict with additional arguments used for class initialization - check_dist_sync_on_step: bool, if true will check if the metric is also correctly - calculated per batch per device (and not just at the end) - check_batch: bool, if true will check if the metric is also correctly - calculated across devices for each batch (and not just at the end) + """Main method that should be used for testing class. Call this inside testing + methods. + + Args: + ddp: bool, if running in ddp mode or not + preds: torch tensor with predictions + target: torch tensor with targets + metric_class: lightning metric class that should be tested + sk_metric: callable function that is used for comparison + dist_sync_on_step: bool, if true will synchronize metric state across + processes at each ``forward()`` + metric_args: dict with additional arguments used for class initialization + check_dist_sync_on_step: bool, if true will check if the metric is also correctly + calculated per batch per device (and not just at the end) + check_batch: bool, if true will check if the metric is also correctly + calculated across devices for each batch (and not just at the end) """ if ddp: if sys.platform == "win32": From 5cbf56a5422d0efc29ad84d85d5768863692496b Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 24 Nov 2020 22:50:50 +0100 Subject: [PATCH 04/38] Replace len(*.shape) with *.ndim --- .../metrics/classification/utils.py | 20 +++++++++---------- pytorch_lightning/metrics/utils.py | 16 +++++++-------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index b8f5af2e988d8..62fe3e2095a78 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -90,25 +90,25 @@ def _check_classification_inputs( raise ValueError("If you set is_multiclass=False and preds are integers, then preds should not exceed 1.") # Check that shape/types fall into one of the cases - if len(preds.shape) == len(target.shape): + if preds.ndim == target.ndim: if preds.shape != target.shape: raise ValueError("if preds and target have the same number of dimensions, they should have the same shape") if preds_float and target.max() > 1: raise ValueError("if preds and target are of shape (N, ...) and preds are floats, target should be binary") # Get the case - if len(preds.shape) == 1 and preds_float: + if preds.ndim == 1 and preds_float: case = "binary" - elif len(preds.shape) == 1 and not preds_float: + elif preds.ndim == 1 and not preds_float: case = "multi-class" - elif len(preds.shape) > 1 and preds_float: + elif preds.ndim > 1 and preds_float: case = "multi-label" else: case = "multi-dim multi-class" implied_classes = torch.prod(torch.Tensor(list(preds.shape[1:]))) - elif len(preds.shape) == len(target.shape) + 1: + elif preds.ndim == target.ndim + 1: if not preds_float: raise ValueError("if preds have one dimension more than target, preds should be a float tensor") if not preds.shape[:-1] == target.shape: @@ -120,7 +120,7 @@ def _check_classification_inputs( extra_dim_size = preds.shape[-1 if preds.shape[:-1] == target.shape else 1] - if len(preds.shape) == 2: + if preds.ndim == 2: case = "multi-class" else: case = "multi-dim multi-class" @@ -299,7 +299,7 @@ def _input_format_classification( preds_float = preds.is_floating_point() - if len(preds.shape) == len(target.shape) == 1 and preds_float: + if preds.ndim == target.ndim == 1 and preds_float: mode = "binary" preds = (preds >= threshold).int() @@ -310,7 +310,7 @@ def _input_format_classification( preds = preds.unsqueeze(-1) target = target.unsqueeze(-1) - elif len(preds.shape) == len(target.shape) and preds_float: + elif preds.ndim == target.ndim and preds_float: mode = "multi-label" preds = (preds >= threshold).int() @@ -321,7 +321,7 @@ def _input_format_classification( preds = preds.reshape(preds.shape[0], -1) target = target.reshape(target.shape[0], -1) - elif len(preds.shape) == len(target.shape) + 1 == 2: + elif preds.ndim == target.ndim + 1 == 2: mode = "multi-class" if not num_classes: num_classes = preds.shape[1] @@ -334,7 +334,7 @@ def _input_format_classification( target = target[:, [1]] preds = preds[:, [1]] - elif len(preds.shape) == len(target.shape) == 1 and not preds_float: + elif preds.ndim == target.ndim == 1 and not preds_float: mode = "multi-class" if not num_classes: diff --git a/pytorch_lightning/metrics/utils.py b/pytorch_lightning/metrics/utils.py index 1ce56b30cf9e5..0f71c531fe6ac 100644 --- a/pytorch_lightning/metrics/utils.py +++ b/pytorch_lightning/metrics/utils.py @@ -109,14 +109,14 @@ def _input_format_classification( preds: tensor with labels target: tensor with labels """ - if not (len(preds.shape) == len(target.shape) or len(preds.shape) == len(target.shape) + 1): + if not (preds.ndim == target.ndim or preds.ndim == target.ndim + 1): raise ValueError("preds and target must have same number of dimensions, or one additional dimension for preds") - if len(preds.shape) == len(target.shape) + 1: + if preds.ndim == target.ndim + 1: # multi class probabilites preds = torch.argmax(preds, dim=1) - if len(preds.shape) == len(target.shape) and preds.dtype == torch.float: + if preds.ndim == target.ndim and preds.dtype == torch.float: # binary or multilabel probablities preds = (preds >= threshold).long() return preds, target @@ -139,24 +139,24 @@ def _input_format_classification_one_hot( preds: one hot tensor of shape [num_classes, -1] with predicted labels target: one hot tensors of shape [num_classes, -1] with true labels """ - if not (len(preds.shape) == len(target.shape) or len(preds.shape) == len(target.shape) + 1): + if not (preds.ndim == target.ndim or preds.ndim == target.ndim + 1): raise ValueError("preds and target must have same number of dimensions, or one additional dimension for preds") - if len(preds.shape) == len(target.shape) + 1: + if preds.ndim == target.ndim + 1: # multi class probabilites preds = torch.argmax(preds, dim=1) - if len(preds.shape) == len(target.shape) and preds.dtype == torch.long and num_classes > 1 and not multilabel: + if preds.ndim == target.ndim and preds.dtype == torch.long and num_classes > 1 and not multilabel: # multi-class preds = to_onehot(preds, num_classes=num_classes) target = to_onehot(target, num_classes=num_classes) - elif len(preds.shape) == len(target.shape) and preds.dtype == torch.float: + elif preds.ndim == target.ndim and preds.dtype == torch.float: # binary or multilabel probablities preds = (preds >= threshold).long() # transpose class as first dim and reshape - if len(preds.shape) > 1: + if preds.ndim > 1: preds = preds.transpose(1, 0) target = target.transpose(1, 0) From 9c33d0b3c14e50a984c59596eff7f6a513f19e2c Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 24 Nov 2020 22:55:56 +0100 Subject: [PATCH 05/38] More descriptive error message for input formatting --- pytorch_lightning/metrics/classification/utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index 62fe3e2095a78..6a47c7adbbae1 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -76,9 +76,7 @@ def _check_classification_inputs( if preds_float: if preds.min() < 0 or preds.max() > 1: - raise ValueError( - "preds should be probabilities, but values were detected outside of [0,1] range" - ) + raise ValueError("preds should be probabilities, but values were detected outside of [0,1] range") if threshold > 1 or threshold < 0: raise ValueError("Threshold should be a probability in [0,1]") @@ -92,7 +90,10 @@ def _check_classification_inputs( # Check that shape/types fall into one of the cases if preds.ndim == target.ndim: if preds.shape != target.shape: - raise ValueError("if preds and target have the same number of dimensions, they should have the same shape") + raise ValueError( + "preds and targets should have the same shape", + f" got preds shape = {preds.shape} and target shape = {target.shape}.", + ) if preds_float and target.max() > 1: raise ValueError("if preds and target are of shape (N, ...) and preds are floats, target should be binary") From 65622058778a8186bdc8b7f8852fc4690405de09 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 24 Nov 2020 23:24:22 +0100 Subject: [PATCH 06/38] Replace movedim with permute --- pytorch_lightning/metrics/classification/utils.py | 10 +++++++--- tests/metrics/classification/test_inputs.py | 7 ++++++- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index 6a47c7adbbae1..f58e61dcd3127 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -13,7 +13,6 @@ # limitations under the License. from typing import Tuple, Optional -import numpy as np import torch from pytorch_lightning.metrics.utils import to_onehot, select_topk @@ -256,7 +255,8 @@ def _input_format_classification( dimension is ambiguous (e.g. if targets are a ``(7, 3)`` tensor, while predictions are a ``(7, 3, 3)`` tensor), it will be assumed that the ``C`` dimension is the second dimension. If this is not the case, you should move it from the last to second place using - ``torch.movedim(preds, -1, 1)``. + ``torch.movedim(preds, -1, 1)``, or using ``preds.permute``, if you are using an older + version of Pytorch. Note that where a one-hot transformation needs to be performed and the number of classes is not implicitly given by a ``C`` dimension, the new ``C`` dimension will either be @@ -370,7 +370,11 @@ def _input_format_classification( else: mode = "multi-dim multi-class" if preds.shape[:-1] == target.shape: - preds = torch.movedim(preds, -1, 1) + shape_permute = list(range(preds.ndim)) + shape_permute[1] = shape_permute[-1] + shape_permute[2:] = range(1, len(shape_permute) - 1) + + preds = preds.permute(*shape_permute) num_classes = preds.shape[1] diff --git a/tests/metrics/classification/test_inputs.py b/tests/metrics/classification/test_inputs.py index 8d17d5624fac0..19828dc07eb34 100644 --- a/tests/metrics/classification/test_inputs.py +++ b/tests/metrics/classification/test_inputs.py @@ -84,7 +84,12 @@ def top2(x): def mvdim(x): - return torch.movedim(x, -1, 1) + """ Equivalent of torch.movedim(x, -1, 1) """ + shape_permute = list(range(x.ndim)) + shape_permute[1] = shape_permute[-1] + shape_permute[2:] = range(1, len(shape_permute) - 1) + + return x.permute(*shape_permute) # To avoid ugly black line wrapping From a04a71ea195c601d59b22c7295c6d1389d7155fe Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Wed, 25 Nov 2020 09:39:18 +0100 Subject: [PATCH 07/38] Style changes in error messages --- .../metrics/classification/utils.py | 79 ++++++++++--------- 1 file changed, 40 insertions(+), 39 deletions(-) diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index f58e61dcd3127..398acbf2bc5fc 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -62,39 +62,40 @@ def _check_classification_inputs( """ if target.is_floating_point(): - raise ValueError("target has to be an integer tensor") + raise ValueError("`target` has to be an integer tensor") elif target.min() < 0: - raise ValueError("target has to be a non-negative tensor") + raise ValueError("`target` has to be a non-negative tensor") preds_float = preds.is_floating_point() if not preds_float and preds.min() < 0: - raise ValueError("if preds are integers, they have to be non-negative") + raise ValueError("if `preds` are integers, they have to be non-negative") if not preds.shape[0] == target.shape[0]: - raise ValueError("preds and target should have the same first dimension.") + raise ValueError("`preds` and `target` should have the same first dimension.") if preds_float: if preds.min() < 0 or preds.max() > 1: - raise ValueError("preds should be probabilities, but values were detected outside of [0,1] range") + raise ValueError("`preds` should be probabilities, but values were detected outside of [0,1] range") if threshold > 1 or threshold < 0: raise ValueError("Threshold should be a probability in [0,1]") if is_multiclass is False and target.max() > 1: - raise ValueError("If you set is_multiclass=False, then target should not exceed 1.") + raise ValueError("If you set `is_multiclass=False`, then `target` should not exceed 1.") if is_multiclass is False and not preds_float and preds.max() > 1: - raise ValueError("If you set is_multiclass=False and preds are integers, then preds should not exceed 1.") + raise ValueError("If you set `is_multiclass=False` and `preds` are integers, then `preds` should not exceed 1.") # Check that shape/types fall into one of the cases if preds.ndim == target.ndim: if preds.shape != target.shape: raise ValueError( - "preds and targets should have the same shape", - f" got preds shape = {preds.shape} and target shape = {target.shape}.", + "`preds` and `target` should have the same shape", + f" got `preds shape = {preds.shape} and `target` shape = {target.shape}.", ) if preds_float and target.max() > 1: - raise ValueError("if preds and target are of shape (N, ...) and preds are floats, target should be binary") + raise ValueError("if `preds` and `target` are of shape (N, ...)" + " and `preds` are floats, `target` should be binary") # Get the case if preds.ndim == 1 and preds_float: @@ -110,12 +111,12 @@ def _check_classification_inputs( elif preds.ndim == target.ndim + 1: if not preds_float: - raise ValueError("if preds have one dimension more than target, preds should be a float tensor") + raise ValueError("if `preds` have one dimension more than `target`, `preds` should be a float tensor") if not preds.shape[:-1] == target.shape: if preds.shape[2:] != target.shape[1:]: raise ValueError( - "if preds if preds have one dimension more than target, the shape of preds should be" - "either of shape (N, C, ...) or (N, ..., C), and of targets of shape (N, ...)" + "if `preds` have one dimension more than `target`, the shape of `preds` should be" + " either of shape (N, C, ...) or (N, ..., C), and of `target` of shape (N, ...)" ) extra_dim_size = preds.shape[-1 if preds.shape[:-1] == target.shape else 1] @@ -126,64 +127,64 @@ def _check_classification_inputs( case = "multi-dim multi-class" else: raise ValueError( - "preds and target should both have the (same) shape (N, ...), or target (N, ...)" - " and preds (N, C, ...) or (N, ..., C)" + "`preds` and `target` should both have the (same) shape (N, ...), or `target` (N, ...)" + " and `preds` (N, C, ...) or (N, ..., C)" ) if preds.shape != target.shape and is_multiclass is False and extra_dim_size != 2: raise ValueError( - "You have set is_multiclass=False, but have more than 2 classes in your data," - " based on the C dimension of preds." + "You have set `is_multiclass=False`, but have more than 2 classes in your data," + " based on the C dimension of `preds`." ) # Check that num_classes is consistent if not num_classes: if preds.shape != target.shape and target.max() >= extra_dim_size: - raise ValueError("The highest label in targets should be smaller than the size of C dimension") + raise ValueError("The highest label in `target` should be smaller than the size of C dimension") else: if case == "binary": if num_classes > 2: - raise ValueError("Your data is binary, but num_classes is larger than 2.") + raise ValueError("Your data is binary, but `num_classes` is larger than 2.") elif num_classes == 2 and not is_multiclass: raise ValueError( - "Your data is binary and num_classes=2, but is_multiclass is not True." - "Set it to True if you want to transform binary data to multi-class format." + "Your data is binary and `num_classes=2`, but `is_multiclass` is not True." + " Set it to True if you want to transform binary data to multi-class format." ) elif num_classes == 1 and is_multiclass: raise ValueError( - "You have binary data and have set is_multiclass=True, but num_classes is 1." - "Either leave is_multiclass unset or set it to 2 to transform binary data to multi-class format." + "You have binary data and have set `is_multiclass=True`, but `num_classes` is 1." + " Either leave `is_multiclass` unset or set it to 2 to transform binary data to multi-class format." ) elif "multi-class" in case: if num_classes == 1 and is_multiclass is not False: raise ValueError( - "You have set num_classes=1, but predictions are integers." - "If you want to convert (multi-dimensional) multi-class data with 2 classes" - "to binary/multi-label, set is_multiclass=False." + "You have set `num_classes=1`, but predictions are integers." + " If you want to convert (multi-dimensional) multi-class data with 2 classes" + " to binary/multi-label, set `is_multiclass=False`." ) elif num_classes > 1: if is_multiclass is False: if implied_classes != num_classes: raise ValueError( - "You have set is_multiclass=False, but the implied number of classes " - "(from shape of inputs) does not match num_classes. If you are trying to" - "transform multi-dim multi-class data with 2 classes to multi-label, num_classes" - "should be either None or the product of the size of extra dimensions (...)." - "See Input Types in Metrics documentation." + "You have set `is_multiclass=False`, but the implied number of classes " + " (from shape of inputs) does not match `num_classes`. If you are trying to" + " transform multi-dim multi-class data with 2 classes to multi-label, `num_classes`" + " should be either None or the product of the size of extra dimensions (...)." + " See Input Types in Metrics documentation." ) if num_classes <= target.max(): - raise ValueError("The highest label in targets should be smaller than num_classes") + raise ValueError("The highest label in `target` should be smaller than `num_classes`") if num_classes <= preds.max(): - raise ValueError("The highest label in preds should be smaller than num_classes") + raise ValueError("The highest label in `preds` should be smaller than `num_classes`") if preds.shape != target.shape and num_classes != extra_dim_size: - raise ValueError("The size of C dimension of preds does not match num_classes") + raise ValueError("The size of C dimension of `preds` does not match `num_classes`") elif case == "multi-label": if is_multiclass and num_classes != 2: raise ValueError( - "Your have set is_multiclass=True, but num_classes is not equal to 2." - "If you are trying to transform multi-label data to 2 class multi-dimensional" - "multi-class, you should set num_classes to either 2 or None." + "Your have set `is_multiclass=True`, but `num_classes` is not equal to 2." + " If you are trying to transform multi-label data to 2 class multi-dimensional" + " multi-class, you should set `num_classes` to either 2 or None." ) if not is_multiclass and num_classes != implied_classes: raise ValueError("The implied number of classes (from shape of inputs) does not match num_classes.") @@ -192,8 +193,8 @@ def _check_classification_inputs( if top_k > 1: if preds.shape == target.shape: raise ValueError( - "You have set top_k above 1, but your data is not (multi-dimensional) multi-class" - "with probability predictions." + "You have set `top_k` above 1, but your data is not (multi-dimensional) multi-class" + " with probability predictions." ) From eaac5d74cffb2980450a95e15e9ce13f13802605 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Wed, 25 Nov 2020 09:42:52 +0100 Subject: [PATCH 08/38] More error message style improvements --- .../metrics/classification/utils.py | 29 ++++++++++--------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index 398acbf2bc5fc..54bcd840f3621 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -62,23 +62,23 @@ def _check_classification_inputs( """ if target.is_floating_point(): - raise ValueError("`target` has to be an integer tensor") + raise ValueError("`target` has to be an integer tensor.") elif target.min() < 0: - raise ValueError("`target` has to be a non-negative tensor") + raise ValueError("`target` has to be a non-negative tensor.") preds_float = preds.is_floating_point() if not preds_float and preds.min() < 0: - raise ValueError("if `preds` are integers, they have to be non-negative") + raise ValueError("if `preds` are integers, they have to be non-negative.") if not preds.shape[0] == target.shape[0]: raise ValueError("`preds` and `target` should have the same first dimension.") if preds_float: if preds.min() < 0 or preds.max() > 1: - raise ValueError("`preds` should be probabilities, but values were detected outside of [0,1] range") + raise ValueError("`preds` should be probabilities, but values were detected outside of [0,1] range.") if threshold > 1 or threshold < 0: - raise ValueError("Threshold should be a probability in [0,1]") + raise ValueError("`threshold` should be a probability in [0,1].") if is_multiclass is False and target.max() > 1: raise ValueError("If you set `is_multiclass=False`, then `target` should not exceed 1.") @@ -94,8 +94,9 @@ def _check_classification_inputs( f" got `preds shape = {preds.shape} and `target` shape = {target.shape}.", ) if preds_float and target.max() > 1: - raise ValueError("if `preds` and `target` are of shape (N, ...)" - " and `preds` are floats, `target` should be binary") + raise ValueError( + "if `preds` and `target` are of shape (N, ...) and `preds` are floats, `target` should be binary." + ) # Get the case if preds.ndim == 1 and preds_float: @@ -111,12 +112,12 @@ def _check_classification_inputs( elif preds.ndim == target.ndim + 1: if not preds_float: - raise ValueError("if `preds` have one dimension more than `target`, `preds` should be a float tensor") + raise ValueError("if `preds` have one dimension more than `target`, `preds` should be a float tensor.") if not preds.shape[:-1] == target.shape: if preds.shape[2:] != target.shape[1:]: raise ValueError( "if `preds` have one dimension more than `target`, the shape of `preds` should be" - " either of shape (N, C, ...) or (N, ..., C), and of `target` of shape (N, ...)" + " either of shape (N, C, ...) or (N, ..., C), and of `target` of shape (N, ...)." ) extra_dim_size = preds.shape[-1 if preds.shape[:-1] == target.shape else 1] @@ -128,7 +129,7 @@ def _check_classification_inputs( else: raise ValueError( "`preds` and `target` should both have the (same) shape (N, ...), or `target` (N, ...)" - " and `preds` (N, C, ...) or (N, ..., C)" + " and `preds` (N, C, ...) or (N, ..., C)." ) if preds.shape != target.shape and is_multiclass is False and extra_dim_size != 2: @@ -140,7 +141,7 @@ def _check_classification_inputs( # Check that num_classes is consistent if not num_classes: if preds.shape != target.shape and target.max() >= extra_dim_size: - raise ValueError("The highest label in `target` should be smaller than the size of C dimension") + raise ValueError("The highest label in `target` should be smaller than the size of C dimension.") else: if case == "binary": if num_classes > 2: @@ -173,11 +174,11 @@ def _check_classification_inputs( " See Input Types in Metrics documentation." ) if num_classes <= target.max(): - raise ValueError("The highest label in `target` should be smaller than `num_classes`") + raise ValueError("The highest label in `target` should be smaller than `num_classes`.") if num_classes <= preds.max(): - raise ValueError("The highest label in `preds` should be smaller than `num_classes`") + raise ValueError("The highest label in `preds` should be smaller than `num_classes`.") if preds.shape != target.shape and num_classes != extra_dim_size: - raise ValueError("The size of C dimension of `preds` does not match `num_classes`") + raise ValueError("The size of C dimension of `preds` does not match `num_classes`.") elif case == "multi-label": if is_multiclass and num_classes != 2: From c1108f0cf81359008f5b462c3ba0bd851fd50a0d Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Wed, 25 Nov 2020 09:44:28 +0100 Subject: [PATCH 09/38] Fix typo in docs --- docs/source/metrics.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 407b64d3d2948..831ac67922114 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -230,7 +230,7 @@ into these categories (``N`` stands for the batch size and ``C`` for number of c All dimensions of size 1 (except ``N``) are "squeezed out" at the beginning, so that, for example, a tensor of shape ``(N, 1)`` is treated as ``(N, )``. -When predictions or targets are integers, it is assumed that class labels start at , i.e. +When predictions or targets are integers, it is assumed that class labels start at 0, i.e. the possible class labels are 0, 1, 2, 3, etc. Below are some examples of different input types .. code-block:: python From 277769b7c4e3aa84a83b3e12eb724ea3318e44c0 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Wed, 25 Nov 2020 09:50:02 +0100 Subject: [PATCH 10/38] Add more descriptive variable names in utils --- pytorch_lightning/metrics/utils.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/metrics/utils.py b/pytorch_lightning/metrics/utils.py index 0f71c531fe6ac..b7f8d492ce01d 100644 --- a/pytorch_lightning/metrics/utils.py +++ b/pytorch_lightning/metrics/utils.py @@ -37,14 +37,14 @@ def _flatten(x): def to_onehot( - tensor: torch.Tensor, + label_tensor: torch.Tensor, num_classes: int, ) -> torch.Tensor: """ Converts a dense label tensor to one-hot format Args: - tensor: dense label tensor, with shape [N, d1, d2, ...] + label_tensor: dense label tensor, with shape [N, d1, d2, ...] num_classes: number of classes C Output: @@ -57,18 +57,18 @@ def to_onehot( [0, 0, 1, 0], [0, 0, 0, 1]]) """ - dtype, device, shape = tensor.dtype, tensor.device, tensor.shape + dtype, device, shape = label_tensor.dtype, label_tensor.device, label_tensor.shape tensor_onehot = torch.zeros(shape[0], num_classes, *shape[1:], dtype=dtype, device=device) - index = tensor.long().unsqueeze(1).expand_as(tensor_onehot) + index = label_tensor.long().unsqueeze(1).expand_as(tensor_onehot) return tensor_onehot.scatter_(1, index, 1.0) -def select_topk(tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch.Tensor: +def select_topk(prob_tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch.Tensor: """ Convert a probability tensor to binary by selecting top-k highest entries. Args: - tensor: dense tensor of shape ``[..., C, ...]``, where ``C`` is in the + prob_tensor: dense tensor of shape ``[..., C, ...]``, where ``C`` is in the position defined by the ``dim`` argument topk: number of highest entries to turn into 1s dim: dimension on which to compare entries @@ -82,8 +82,8 @@ def select_topk(tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch.Tens tensor([[0, 1, 1], [1, 1, 0]], dtype=torch.int32) """ - zeros = torch.zeros_like(tensor, device=tensor.device) - topk_tensor = zeros.scatter(1, tensor.topk(k=topk, dim=dim).indices, 1.0) + zeros = torch.zeros_like(prob_tensor, device=prob_tensor.device) + topk_tensor = zeros.scatter(1, prob_tensor.topk(k=topk, dim=dim).indices, 1.0) return topk_tensor.int() From 484929861c54922cf8f748eaf677efae4cff5fcb Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Wed, 25 Nov 2020 09:59:31 +0100 Subject: [PATCH 11/38] Change internal var names --- tests/metrics/classification/test_inputs.py | 118 ++++++++++---------- 1 file changed, 57 insertions(+), 61 deletions(-) diff --git a/tests/metrics/classification/test_inputs.py b/tests/metrics/classification/test_inputs.py index 19828dc07eb34..6ec6c8dbbc498 100644 --- a/tests/metrics/classification/test_inputs.py +++ b/tests/metrics/classification/test_inputs.py @@ -43,47 +43,43 @@ T = torch.Tensor -def idn(x): +def _idn(x): return x -def usq(x): +def _usq(x): return x.unsqueeze(-1) -def toint(x): - return x.int() - - -def thrs(x): +def _thrs(x): return x >= THRESHOLD -def rshp1(x): +def _rshp1(x): return x.reshape(x.shape[0], -1) -def rshp2(x): +def _rshp2(x): return x.reshape(x.shape[0], x.shape[1], -1) -def onehot(x): +def _onehot(x): return to_onehot(x, NUM_CLASSES) -def onehot2(x): +def _onehot2(x): return to_onehot(x, 2) -def top1(x): +def _top1(x): return select_topk(x, 1) -def top2(x): +def _top2(x): return select_topk(x, 2) -def mvdim(x): +def _mvdim(x): """ Equivalent of torch.movedim(x, -1, 1) """ shape_permute = list(range(x.ndim)) shape_permute[1] = shape_permute[-1] @@ -93,44 +89,44 @@ def mvdim(x): # To avoid ugly black line wrapping -def ml_preds_tr(x): - return rshp1(toint(thrs(x))) +def _ml_preds_tr(x): + return _rshp1(_thrs(x).int()) -def onehot_rshp1(x): - return onehot(rshp1(x)) +def _onehot_rshp1(x): + return _onehot(_rshp1(x)) -def onehot2_rshp1(x): - return onehot2(rshp1(x)) +def _onehot2_rshp1(x): + return _onehot2(_rshp1(x)) -def top1_rshp2(x): - return top1(rshp2(x)) +def _top1_rshp2(x): + return _top1(_rshp2(x)) -def top2_rshp2(x): - return top2(rshp2(x)) +def _top2_rshp2(x): + return _top2(_rshp2(x)) -def mdmc1_top1_tr(x): - return top1(rshp2(mvdim(x))) +def _mdmc1_top1_tr(x): + return _top1(_rshp2(_mvdim(x))) -def mdmc1_top2_tr(x): - return top2(rshp2(mvdim(x))) +def _mdmc1_top2_tr(x): + return _top2(_rshp2(_mvdim(x))) -def probs_to_mc_preds_tr(x): - return toint(onehot2(thrs(x))) +def _probs_to_mc_preds_tr(x): + return _onehot2(_thrs(x)).int() -def mlmd_prob_to_mc_preds_tr(x): - return onehot2(rshp1(toint(thrs(x)))) +def _mlmd_prob_to_mc_preds_tr(x): + return _onehot2(_rshp1(_thrs(x).int())) -def mdmc_prob_to_ml_preds_tr(x): - return top1(mvdim(x))[:, 1] +def _mdmc_prob_to__ml_preds_tr(x): + return _top1(_mvdim(x))[:, 1] ######################## @@ -143,44 +139,44 @@ def mdmc_prob_to_ml_preds_tr(x): [ ############################# # Test usual expected cases - (_bin, THRESHOLD, None, False, 1, "multi-class", usq, usq), - (_bin_prob, THRESHOLD, None, None, 1, "binary", lambda x: usq(toint(thrs(x))), usq), - (_ml_prob, THRESHOLD, None, None, 1, "multi-label", lambda x: toint(thrs(x)), idn), - (_ml, THRESHOLD, None, False, 1, "multi-dim multi-class", idn, idn), - (_ml_prob, THRESHOLD, None, None, 1, "multi-label", ml_preds_tr, rshp1), - (_mlmd, THRESHOLD, None, False, 1, "multi-dim multi-class", rshp1, rshp1), - (_mc, THRESHOLD, NUM_CLASSES, None, 1, "multi-class", onehot, onehot), - (_mc_prob, THRESHOLD, None, None, 1, "multi-class", top1, onehot), - (_mc_prob, THRESHOLD, None, None, 2, "multi-class", top2, onehot), - (_mdmc, THRESHOLD, NUM_CLASSES, None, 1, "multi-dim multi-class", onehot, onehot), - (_mdmc_prob, THRESHOLD, None, None, 1, "multi-dim multi-class", top1_rshp2, onehot), - (_mdmc_prob, THRESHOLD, None, None, 2, "multi-dim multi-class", top2_rshp2, onehot), - (_mdmc_prob_many_dims, THRESHOLD, None, None, 1, "multi-dim multi-class", top1_rshp2, onehot_rshp1), - (_mdmc_prob_many_dims, THRESHOLD, None, None, 2, "multi-dim multi-class", top2_rshp2, onehot_rshp1), + (_bin, THRESHOLD, None, False, 1, "multi-class", _usq, _usq), + (_bin_prob, THRESHOLD, None, None, 1, "binary", lambda x: _usq(_thrs(x).int()), _usq), + (_ml_prob, THRESHOLD, None, None, 1, "multi-label", lambda x: _thrs(x).int(), _idn), + (_ml, THRESHOLD, None, False, 1, "multi-dim multi-class", _idn, _idn), + (_ml_prob, THRESHOLD, None, None, 1, "multi-label", _ml_preds_tr, _rshp1), + (_mlmd, THRESHOLD, None, False, 1, "multi-dim multi-class", _rshp1, _rshp1), + (_mc, THRESHOLD, NUM_CLASSES, None, 1, "multi-class", _onehot, _onehot), + (_mc_prob, THRESHOLD, None, None, 1, "multi-class", _top1, _onehot), + (_mc_prob, THRESHOLD, None, None, 2, "multi-class", _top2, _onehot), + (_mdmc, THRESHOLD, NUM_CLASSES, None, 1, "multi-dim multi-class", _onehot, _onehot), + (_mdmc_prob, THRESHOLD, None, None, 1, "multi-dim multi-class", _top1_rshp2, _onehot), + (_mdmc_prob, THRESHOLD, None, None, 2, "multi-dim multi-class", _top2_rshp2, _onehot), + (_mdmc_prob_many_dims, THRESHOLD, None, None, 1, "multi-dim multi-class", _top1_rshp2, _onehot_rshp1), + (_mdmc_prob_many_dims, THRESHOLD, None, None, 2, "multi-dim multi-class", _top2_rshp2, _onehot_rshp1), # Test with C dim in last place - (_mdmc_prob1, THRESHOLD, None, None, 1, "multi-dim multi-class", mdmc1_top1_tr, onehot), - (_mdmc_prob1, THRESHOLD, None, None, 2, "multi-dim multi-class", mdmc1_top2_tr, onehot), - (_mdmc_prob_many_dims1, THRESHOLD, None, None, 1, "multi-dim multi-class", mdmc1_top1_tr, onehot_rshp1), - (_mdmc_prob_many_dims1, THRESHOLD, None, None, 2, "multi-dim multi-class", mdmc1_top2_tr, onehot_rshp1), + (_mdmc_prob1, THRESHOLD, None, None, 1, "multi-dim multi-class", _mdmc1_top1_tr, _onehot), + (_mdmc_prob1, THRESHOLD, None, None, 2, "multi-dim multi-class", _mdmc1_top2_tr, _onehot), + (_mdmc_prob_many_dims1, THRESHOLD, None, None, 1, "multi-dim multi-class", _mdmc1_top1_tr, _onehot_rshp1), + (_mdmc_prob_many_dims1, THRESHOLD, None, None, 2, "multi-dim multi-class", _mdmc1_top2_tr, _onehot_rshp1), ########################### # Test some special cases # Binary as multiclass - (_bin, THRESHOLD, None, None, 1, "multi-class", onehot2, onehot2), + (_bin, THRESHOLD, None, None, 1, "multi-class", _onehot2, _onehot2), # Binary probs as multiclass - (_bin_prob, THRESHOLD, None, True, 1, "binary", probs_to_mc_preds_tr, onehot2), + (_bin_prob, THRESHOLD, None, True, 1, "binary", _probs_to_mc_preds_tr, _onehot2), # Multilabel as multiclass - (_ml, THRESHOLD, None, True, 1, "multi-dim multi-class", onehot2, onehot2), + (_ml, THRESHOLD, None, True, 1, "multi-dim multi-class", _onehot2, _onehot2), # Multilabel probs as multiclass - (_ml_prob, THRESHOLD, None, True, 1, "multi-label", probs_to_mc_preds_tr, onehot2), + (_ml_prob, THRESHOLD, None, True, 1, "multi-label", _probs_to_mc_preds_tr, _onehot2), # Multidim multilabel as multiclass - (_mlmd, THRESHOLD, None, True, 1, "multi-dim multi-class", onehot2_rshp1, onehot2_rshp1), + (_mlmd, THRESHOLD, None, True, 1, "multi-dim multi-class", _onehot2_rshp1, _onehot2_rshp1), # Multidim multilabel probs as multiclass - (_mlmd_prob, THRESHOLD, None, True, 1, "multi-label", mlmd_prob_to_mc_preds_tr, onehot2_rshp1), + (_mlmd_prob, THRESHOLD, None, True, 1, "multi-label", _mlmd_prob_to_mc_preds_tr, _onehot2_rshp1), # Multiclass prob with 2 classes as binary - (_mc_prob_2cls, THRESHOLD, None, False, 1, "multi-class", lambda x: top1(x)[:, [1]], usq), + (_mc_prob_2cls, THRESHOLD, None, False, 1, "multi-class", lambda x: _top1(x)[:, [1]], _usq), # Multi-dim multi-class with 2 classes as multi-label - (_mdmc_prob_2cls, THRESHOLD, None, False, 1, "multi-dim multi-class", lambda x: top1(x)[:, 1], idn), - (_mdmc_prob_2cls1, THRESHOLD, None, False, 1, "multi-dim multi-class", mdmc_prob_to_ml_preds_tr, idn), + (_mdmc_prob_2cls, THRESHOLD, None, False, 1, "multi-dim multi-class", lambda x: _top1(x)[:, 1], _idn), + (_mdmc_prob_2cls1, THRESHOLD, None, False, 1, "multi-dim multi-class", _mdmc_prob_to__ml_preds_tr, _idn), ], ) def test_usual_cases(inputs, threshold, num_classes, is_multiclass, top_k, exp_mode, post_preds, post_target): From 02bd636b35306e01b427876fb2e80a3291154ebe Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Wed, 25 Nov 2020 10:28:23 +0100 Subject: [PATCH 12/38] Break down error checking for inputs into separate functions --- .../metrics/classification/utils.py | 225 +++++++++++------- 1 file changed, 137 insertions(+), 88 deletions(-) diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index 54bcd840f3621..50892af13ba9e 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -18,6 +18,137 @@ from pytorch_lightning.metrics.utils import to_onehot, select_topk +def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor) -> Tuple[str, int]: + """ + This checks that the shape and type of inputs are consistent with + each other and fall into one of the allowed input types (see the + documentation of docstring of _input_format_classification). It does + not check for consistency of number of classes, other functions take + care of that. + + It returns the name of the case in which the inputs fall, and the implied + number of classes (from the C dim for multi-class data, or extra dim(s) for + multi-label data). + """ + + preds_float = preds.is_floating_point() + + if preds.ndim == target.ndim: + if preds.shape != target.shape: + raise ValueError( + "`preds` and `target` should have the same shape", + f" got `preds shape = {preds.shape} and `target` shape = {target.shape}.", + ) + if preds_float and target.max() > 1: + raise ValueError( + "if `preds` and `target` are of shape (N, ...) and `preds` are floats, `target` should be binary." + ) + + # Get the case + if preds.ndim == 1 and preds_float: + case = "binary" + elif preds.ndim == 1 and not preds_float: + case = "multi-class" + elif preds.ndim > 1 and preds_float: + case = "multi-label" + else: + case = "multi-dim multi-class" + + implied_classes = torch.prod(torch.Tensor(list(preds.shape[1:]))) + + elif preds.ndim == target.ndim + 1: + if not preds_float: + raise ValueError("if `preds` have one dimension more than `target`, `preds` should be a float tensor.") + if not preds.shape[:-1] == target.shape: + if preds.shape[2:] != target.shape[1:]: + raise ValueError( + "if `preds` have one dimension more than `target`, the shape of `preds` should be" + " either of shape (N, C, ...) or (N, ..., C), and of `target` of shape (N, ...)." + ) + + implied_classes = preds.shape[-1 if preds.shape[:-1] == target.shape else 1] + + if preds.ndim == 2: + case = "multi-class" + else: + case = "multi-dim multi-class" + else: + raise ValueError( + "`preds` and `target` should both have the (same) shape (N, ...), or `target` (N, ...)" + " and `preds` (N, C, ...) or (N, ..., C)." + ) + + return case, implied_classes + + +def _check_num_classes_binary(num_classes: int, is_multiclass: bool): + """ + This checks that the consistency of `num_classes` with the data + and `is_multiclass` param for binary data. + """ + + if num_classes > 2: + raise ValueError("Your data is binary, but `num_classes` is larger than 2.") + elif num_classes == 2 and not is_multiclass: + raise ValueError( + "Your data is binary and `num_classes=2`, but `is_multiclass` is not True." + " Set it to True if you want to transform binary data to multi-class format." + ) + elif num_classes == 1 and is_multiclass: + raise ValueError( + "You have binary data and have set `is_multiclass=True`, but `num_classes` is 1." + " Either leave `is_multiclass` unset or set it to 2 to transform binary data to multi-class format." + ) + + +def _check_num_classes_mc( + preds: torch.Tensor, target: torch.Tensor, num_classes: int, is_multiclass: bool, implied_classes: int +): + """ + This checks that the consistency of `num_classes` with the data + and `is_multiclass` param for (multi-dimensional) multi-class data. + """ + + if num_classes == 1 and is_multiclass is not False: + raise ValueError( + "You have set `num_classes=1`, but predictions are integers." + " If you want to convert (multi-dimensional) multi-class data with 2 classes" + " to binary/multi-label, set `is_multiclass=False`." + ) + elif num_classes > 1: + if is_multiclass is False: + if implied_classes != num_classes: + raise ValueError( + "You have set `is_multiclass=False`, but the implied number of classes " + " (from shape of inputs) does not match `num_classes`. If you are trying to" + " transform multi-dim multi-class data with 2 classes to multi-label, `num_classes`" + " should be either None or the product of the size of extra dimensions (...)." + " See Input Types in Metrics documentation." + ) + if num_classes <= target.max(): + raise ValueError("The highest label in `target` should be smaller than `num_classes`.") + if num_classes <= preds.max(): + raise ValueError("The highest label in `preds` should be smaller than `num_classes`.") + if preds.shape != target.shape and num_classes != implied_classes: + raise ValueError("The size of C dimension of `preds` does not match `num_classes`.") + + +def _check_num_classes_ml(num_classes: int, is_multiclass: bool, implied_classes: int): + """ + This checks that the consistency of `num_classes` with the data + and `is_multiclass` param for multi-label data. + """ + + if is_multiclass and num_classes != 2: + raise ValueError( + "Your have set `is_multiclass=True`, but `num_classes` is not equal to 2." + " If you are trying to transform multi-label data to 2 class multi-dimensional" + " multi-class, you should set `num_classes` to either 2 or None." + ) + if not is_multiclass and num_classes != implied_classes: + raise ValueError("The implied number of classes (from shape of inputs) does not match num_classes.") + + def _check_classification_inputs( preds: torch.Tensor, target: torch.Tensor, @@ -87,52 +218,9 @@ def _check_classification_inputs( raise ValueError("If you set `is_multiclass=False` and `preds` are integers, then `preds` should not exceed 1.") # Check that shape/types fall into one of the cases - if preds.ndim == target.ndim: - if preds.shape != target.shape: - raise ValueError( - "`preds` and `target` should have the same shape", - f" got `preds shape = {preds.shape} and `target` shape = {target.shape}.", - ) - if preds_float and target.max() > 1: - raise ValueError( - "if `preds` and `target` are of shape (N, ...) and `preds` are floats, `target` should be binary." - ) - - # Get the case - if preds.ndim == 1 and preds_float: - case = "binary" - elif preds.ndim == 1 and not preds_float: - case = "multi-class" - elif preds.ndim > 1 and preds_float: - case = "multi-label" - else: - case = "multi-dim multi-class" - - implied_classes = torch.prod(torch.Tensor(list(preds.shape[1:]))) - - elif preds.ndim == target.ndim + 1: - if not preds_float: - raise ValueError("if `preds` have one dimension more than `target`, `preds` should be a float tensor.") - if not preds.shape[:-1] == target.shape: - if preds.shape[2:] != target.shape[1:]: - raise ValueError( - "if `preds` have one dimension more than `target`, the shape of `preds` should be" - " either of shape (N, C, ...) or (N, ..., C), and of `target` of shape (N, ...)." - ) - - extra_dim_size = preds.shape[-1 if preds.shape[:-1] == target.shape else 1] - - if preds.ndim == 2: - case = "multi-class" - else: - case = "multi-dim multi-class" - else: - raise ValueError( - "`preds` and `target` should both have the (same) shape (N, ...), or `target` (N, ...)" - " and `preds` (N, C, ...) or (N, ..., C)." - ) + case, implied_classes = _check_shape_and_type_consistency(preds, target) - if preds.shape != target.shape and is_multiclass is False and extra_dim_size != 2: + if preds.shape != target.shape and is_multiclass is False and implied_classes != 2: raise ValueError( "You have set `is_multiclass=False`, but have more than 2 classes in your data," " based on the C dimension of `preds`." @@ -140,55 +228,16 @@ def _check_classification_inputs( # Check that num_classes is consistent if not num_classes: - if preds.shape != target.shape and target.max() >= extra_dim_size: + if preds.shape != target.shape and target.max() >= implied_classes: raise ValueError("The highest label in `target` should be smaller than the size of C dimension.") else: if case == "binary": - if num_classes > 2: - raise ValueError("Your data is binary, but `num_classes` is larger than 2.") - elif num_classes == 2 and not is_multiclass: - raise ValueError( - "Your data is binary and `num_classes=2`, but `is_multiclass` is not True." - " Set it to True if you want to transform binary data to multi-class format." - ) - elif num_classes == 1 and is_multiclass: - raise ValueError( - "You have binary data and have set `is_multiclass=True`, but `num_classes` is 1." - " Either leave `is_multiclass` unset or set it to 2 to transform binary data to multi-class format." - ) + _check_num_classes_binary(num_classes, is_multiclass) elif "multi-class" in case: - if num_classes == 1 and is_multiclass is not False: - raise ValueError( - "You have set `num_classes=1`, but predictions are integers." - " If you want to convert (multi-dimensional) multi-class data with 2 classes" - " to binary/multi-label, set `is_multiclass=False`." - ) - elif num_classes > 1: - if is_multiclass is False: - if implied_classes != num_classes: - raise ValueError( - "You have set `is_multiclass=False`, but the implied number of classes " - " (from shape of inputs) does not match `num_classes`. If you are trying to" - " transform multi-dim multi-class data with 2 classes to multi-label, `num_classes`" - " should be either None or the product of the size of extra dimensions (...)." - " See Input Types in Metrics documentation." - ) - if num_classes <= target.max(): - raise ValueError("The highest label in `target` should be smaller than `num_classes`.") - if num_classes <= preds.max(): - raise ValueError("The highest label in `preds` should be smaller than `num_classes`.") - if preds.shape != target.shape and num_classes != extra_dim_size: - raise ValueError("The size of C dimension of `preds` does not match `num_classes`.") + _check_num_classes_mc(preds, target, num_classes, is_multiclass, implied_classes) elif case == "multi-label": - if is_multiclass and num_classes != 2: - raise ValueError( - "Your have set `is_multiclass=True`, but `num_classes` is not equal to 2." - " If you are trying to transform multi-label data to 2 class multi-dimensional" - " multi-class, you should set `num_classes` to either 2 or None." - ) - if not is_multiclass and num_classes != implied_classes: - raise ValueError("The implied number of classes (from shape of inputs) does not match num_classes.") + _check_num_classes_ml(num_classes, is_multiclass, implied_classes) # Check that if top_k > 1, we have (multi-class) multi-dim with probabilities if top_k > 1: From f97145bbc599aa5bf75ba08a3567153fb2508086 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Wed, 25 Nov 2020 10:53:37 +0100 Subject: [PATCH 13/38] Remove the (N, ..., C) option in MD-MC --- docs/source/metrics.rst | 2 +- .../metrics/classification/utils.py | 39 ++++++------------- tests/metrics/classification/inputs.py | 6 --- tests/metrics/classification/test_inputs.py | 35 ----------------- 4 files changed, 13 insertions(+), 69 deletions(-) diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 831ac67922114..082e84e8c79f4 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -224,7 +224,7 @@ into these categories (``N`` stands for the batch size and ``C`` for number of c "Multi-class with probabilities", "(N, C)", "``float``", "(N,)", "``int``" "Multi-label", "(N, ...)", "``float``", "(N, ...)", "``binary``\*" "Multi-dimensional multi-class", "(N, ...)", "``int``", "(N, ...)", "``int``" - "Multi-dimensional multi-class with probabilities", "(N, C, ...) or (N, ..., C)", "``float``", "(N, ...)", "``int``" + "Multi-dimensional multi-class with probabilities", "(N, C, ...)", "``float``", "(N, ...)", "``int``" .. note:: All dimensions of size 1 (except ``N``) are "squeezed out" at the beginning, so diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index 50892af13ba9e..d665bd88266b8 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -59,14 +59,13 @@ def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor) elif preds.ndim == target.ndim + 1: if not preds_float: raise ValueError("if `preds` have one dimension more than `target`, `preds` should be a float tensor.") - if not preds.shape[:-1] == target.shape: - if preds.shape[2:] != target.shape[1:]: - raise ValueError( - "if `preds` have one dimension more than `target`, the shape of `preds` should be" - " either of shape (N, C, ...) or (N, ..., C), and of `target` of shape (N, ...)." - ) + if preds.shape[2:] != target.shape[1:]: + raise ValueError( + "if `preds` have one dimension more than `target`, the shape of `preds` should be" + " of shape (N, C, ...), and `target` of shape (N, ...)." + ) - implied_classes = preds.shape[-1 if preds.shape[:-1] == target.shape else 1] + implied_classes = preds.shape[1] if preds.ndim == 2: case = "multi-class" @@ -263,15 +262,15 @@ def _input_format_classification( * Both preds and target are of shape ``(N,)``, and both are integers (multi-class) * Both preds and target are of shape ``(N,)``, and target is binary, while preds - are a float (binary) + are a float (binary) * preds are of shape ``(N, C)`` and are floats, and target is of shape ``(N,)`` and - is integer (multi-class) + is integer (multi-class) * preds and target are of shape ``(N, ...)``, target is binary and preds is a float - (multi-label) - * preds are of shape ``(N, ..., C)`` or ``(N, C, ...)`` and are floats, target is of - shape ``(N, ...)`` and is integer (multi-dimensional multi-class) + (multi-label) + * preds are of shape ``(N, C, ...)`` and are floats, target is of shape ``(N, ...)`` + and is integer (multi-dimensional multi-class) * preds and target are of shape ``(N, ...)`` both are integers (multi-dimensional - multi-class) + multi-class) To avoid ambiguities, all dimensions of size 1, except the first one, are squeezed out. @@ -302,13 +301,6 @@ def _input_format_classification( ``is_multiclass=False`` (and there are up to two classes), then the data is returned as ``(N, X)`` binary tensors (multi-label). - Also, in multi-dimensional multi-class case, if the position of the ``C`` - dimension is ambiguous (e.g. if targets are a ``(7, 3)`` tensor, while predictions are a - ``(7, 3, 3)`` tensor), it will be assumed that the ``C`` dimension is the second dimension. - If this is not the case, you should move it from the last to second place using - ``torch.movedim(preds, -1, 1)``, or using ``preds.permute``, if you are using an older - version of Pytorch. - Note that where a one-hot transformation needs to be performed and the number of classes is not implicitly given by a ``C`` dimension, the new ``C`` dimension will either be equal to ``num_classes``, if it is given, or the maximum label value in preds and @@ -420,13 +412,6 @@ def _input_format_classification( # Multi-dim multi-class (N, C, ...) and (N, ..., C) else: mode = "multi-dim multi-class" - if preds.shape[:-1] == target.shape: - shape_permute = list(range(preds.ndim)) - shape_permute[1] = shape_permute[-1] - shape_permute[2:] = range(1, len(shape_permute) - 1) - - preds = preds.permute(*shape_permute) - num_classes = preds.shape[1] if is_multiclass is False: diff --git a/tests/metrics/classification/inputs.py b/tests/metrics/classification/inputs.py index e648aaf10093e..48d3e85e3afeb 100644 --- a/tests/metrics/classification/inputs.py +++ b/tests/metrics/classification/inputs.py @@ -70,12 +70,6 @@ target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)) ) -# Class dimension last -_multidim_multiclass_prob_inputs1 = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM, NUM_CLASSES), - target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)) -) - _multidim_multiclass_inputs = Input( preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)) diff --git a/tests/metrics/classification/test_inputs.py b/tests/metrics/classification/test_inputs.py index 6ec6c8dbbc498..edce1232863f1 100644 --- a/tests/metrics/classification/test_inputs.py +++ b/tests/metrics/classification/test_inputs.py @@ -12,7 +12,6 @@ _multiclass_prob_inputs as _mc_prob, _multidim_multiclass_inputs as _mdmc, _multidim_multiclass_prob_inputs as _mdmc_prob, - _multidim_multiclass_prob_inputs1 as _mdmc_prob1, _multilabel_inputs as _ml, _multilabel_prob_inputs as _ml_prob, _multilabel_multidim_inputs as _mlmd, @@ -28,16 +27,9 @@ rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM, EXTRA_DIM), randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM, EXTRA_DIM)), ) -_mdmc_prob_many_dims1 = Input( - rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM, EXTRA_DIM, NUM_CLASSES), - randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM, EXTRA_DIM)), -) _mdmc_prob_2cls = Input( rand(NUM_BATCHES, BATCH_SIZE, 2, EXTRA_DIM), randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)) ) -_mdmc_prob_2cls1 = Input( - rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM, 2), randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)) -) # Some utils T = torch.Tensor @@ -79,15 +71,6 @@ def _top2(x): return select_topk(x, 2) -def _mvdim(x): - """ Equivalent of torch.movedim(x, -1, 1) """ - shape_permute = list(range(x.ndim)) - shape_permute[1] = shape_permute[-1] - shape_permute[2:] = range(1, len(shape_permute) - 1) - - return x.permute(*shape_permute) - - # To avoid ugly black line wrapping def _ml_preds_tr(x): return _rshp1(_thrs(x).int()) @@ -109,14 +92,6 @@ def _top2_rshp2(x): return _top2(_rshp2(x)) -def _mdmc1_top1_tr(x): - return _top1(_rshp2(_mvdim(x))) - - -def _mdmc1_top2_tr(x): - return _top2(_rshp2(_mvdim(x))) - - def _probs_to_mc_preds_tr(x): return _onehot2(_thrs(x)).int() @@ -125,10 +100,6 @@ def _mlmd_prob_to_mc_preds_tr(x): return _onehot2(_rshp1(_thrs(x).int())) -def _mdmc_prob_to__ml_preds_tr(x): - return _top1(_mvdim(x))[:, 1] - - ######################## # Test correct inputs ######################## @@ -153,11 +124,6 @@ def _mdmc_prob_to__ml_preds_tr(x): (_mdmc_prob, THRESHOLD, None, None, 2, "multi-dim multi-class", _top2_rshp2, _onehot), (_mdmc_prob_many_dims, THRESHOLD, None, None, 1, "multi-dim multi-class", _top1_rshp2, _onehot_rshp1), (_mdmc_prob_many_dims, THRESHOLD, None, None, 2, "multi-dim multi-class", _top2_rshp2, _onehot_rshp1), - # Test with C dim in last place - (_mdmc_prob1, THRESHOLD, None, None, 1, "multi-dim multi-class", _mdmc1_top1_tr, _onehot), - (_mdmc_prob1, THRESHOLD, None, None, 2, "multi-dim multi-class", _mdmc1_top2_tr, _onehot), - (_mdmc_prob_many_dims1, THRESHOLD, None, None, 1, "multi-dim multi-class", _mdmc1_top1_tr, _onehot_rshp1), - (_mdmc_prob_many_dims1, THRESHOLD, None, None, 2, "multi-dim multi-class", _mdmc1_top2_tr, _onehot_rshp1), ########################### # Test some special cases # Binary as multiclass @@ -176,7 +142,6 @@ def _mdmc_prob_to__ml_preds_tr(x): (_mc_prob_2cls, THRESHOLD, None, False, 1, "multi-class", lambda x: _top1(x)[:, [1]], _usq), # Multi-dim multi-class with 2 classes as multi-label (_mdmc_prob_2cls, THRESHOLD, None, False, 1, "multi-dim multi-class", lambda x: _top1(x)[:, 1], _idn), - (_mdmc_prob_2cls1, THRESHOLD, None, False, 1, "multi-dim multi-class", _mdmc_prob_to__ml_preds_tr, _idn), ], ) def test_usual_cases(inputs, threshold, num_classes, is_multiclass, top_k, exp_mode, post_preds, post_target): From 536feafd8945b78b3c781d9f5c96f3433fa8be8c Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Wed, 25 Nov 2020 10:55:12 +0100 Subject: [PATCH 14/38] Simplify select_topk --- pytorch_lightning/metrics/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/metrics/utils.py b/pytorch_lightning/metrics/utils.py index b7f8d492ce01d..e78ba055665d8 100644 --- a/pytorch_lightning/metrics/utils.py +++ b/pytorch_lightning/metrics/utils.py @@ -82,7 +82,7 @@ def select_topk(prob_tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch tensor([[0, 1, 1], [1, 1, 0]], dtype=torch.int32) """ - zeros = torch.zeros_like(prob_tensor, device=prob_tensor.device) + zeros = torch.zeros_like(prob_tensor) topk_tensor = zeros.scatter(1, prob_tensor.topk(k=topk, dim=dim).indices, 1.0) return topk_tensor.int() From 4241d7c281a160c10178a34375781f019348657a Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Wed, 25 Nov 2020 11:50:55 +0100 Subject: [PATCH 15/38] Remove detach for inputs --- pytorch_lightning/metrics/classification/utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index d665bd88266b8..16d51d5c35683 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -324,8 +324,6 @@ def _input_format_classification( preds: binary tensor of shape (N, C) or (N, C, X) target: binary tensor of shape (N, C) or (N, C, X) """ - preds, target = preds.clone().detach(), target.clone().detach() - # Remove excess dimensions if preds.shape[0] == 1: preds, target = preds.squeeze().unsqueeze(0), target.squeeze().unsqueeze(0) From 86d6c4d976cd6201652feaca8a5cb91554cb3603 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Wed, 25 Nov 2020 14:09:43 +0100 Subject: [PATCH 16/38] Fix typos --- pytorch_lightning/metrics/classification/utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index 16d51d5c35683..eb73987187820 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -36,8 +36,8 @@ def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor) if preds.ndim == target.ndim: if preds.shape != target.shape: raise ValueError( - "`preds` and `target` should have the same shape", - f" got `preds shape = {preds.shape} and `target` shape = {target.shape}.", + "`preds` and `target` should have the same shape,", + f" got `preds` shape = {preds.shape} and `target` shape = {target.shape}.", ) if preds_float and target.max() > 1: raise ValueError( @@ -62,7 +62,7 @@ def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor) if preds.shape[2:] != target.shape[1:]: raise ValueError( "if `preds` have one dimension more than `target`, the shape of `preds` should be" - " of shape (N, C, ...), and `target` of shape (N, ...)." + " (N, C, ...), and the shape of `target` should be (N, ...)." ) implied_classes = preds.shape[1] @@ -74,7 +74,7 @@ def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor) else: raise ValueError( "`preds` and `target` should both have the (same) shape (N, ...), or `target` (N, ...)" - " and `preds` (N, C, ...) or (N, ..., C)." + " and `preds` (N, C, ...)." ) return case, implied_classes From cde39970fc79af4dc6adba981957ddcdd5b2628b Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Thu, 26 Nov 2020 21:34:33 +0100 Subject: [PATCH 17/38] Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec --- pytorch_lightning/metrics/classification/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index eb73987187820..9a49c7c4e8cdb 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -22,7 +22,7 @@ def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor) """ This checks that the shape and type of inputs are consistent with each other and fall into one of the allowed input types (see the - documentation of docstring of _input_format_classification). It does + documentation of docstring of ``_input_format_classification``). It does not check for consistency of number of classes, other functions take care of that. From 05a54da4df1ab72d1ec0424d4bd3af87f8e6bee5 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Thu, 26 Nov 2020 21:40:54 +0100 Subject: [PATCH 18/38] Update docs/source/metrics.rst Co-authored-by: Jirka Borovec --- docs/source/metrics.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 082e84e8c79f4..b59fdc6c73009 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -233,7 +233,7 @@ into these categories (``N`` stands for the batch size and ``C`` for number of c When predictions or targets are integers, it is assumed that class labels start at 0, i.e. the possible class labels are 0, 1, 2, 3, etc. Below are some examples of different input types -.. code-block:: python +.. testcode:: # Binary inputs binary_preds = torch.tensor([0.6, 0.1, 0.9]) From 9a43a5eafe106db04652b2e80bbcf64f0f0308f7 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Thu, 26 Nov 2020 23:18:41 +0100 Subject: [PATCH 19/38] Minor error message changes --- .../metrics/classification/utils.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index 9a49c7c4e8cdb..9ddcfb9132c2c 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -36,12 +36,12 @@ def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor) if preds.ndim == target.ndim: if preds.shape != target.shape: raise ValueError( - "`preds` and `target` should have the same shape,", + "The `preds` and `target` should have the same shape,", f" got `preds` shape = {preds.shape} and `target` shape = {target.shape}.", ) if preds_float and target.max() > 1: raise ValueError( - "if `preds` and `target` are of shape (N, ...) and `preds` are floats, `target` should be binary." + "If `preds` and `target` are of shape (N, ...) and `preds` are floats, `target` should be binary." ) # Get the case @@ -58,10 +58,10 @@ def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor) elif preds.ndim == target.ndim + 1: if not preds_float: - raise ValueError("if `preds` have one dimension more than `target`, `preds` should be a float tensor.") + raise ValueError("If `preds` have one dimension more than `target`, `preds` should be a float tensor.") if preds.shape[2:] != target.shape[1:]: raise ValueError( - "if `preds` have one dimension more than `target`, the shape of `preds` should be" + "If `preds` have one dimension more than `target`, the shape of `preds` should be" " (N, C, ...), and the shape of `target` should be (N, ...)." ) @@ -73,7 +73,7 @@ def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor) case = "multi-dim multi-class" else: raise ValueError( - "`preds` and `target` should both have the (same) shape (N, ...), or `target` (N, ...)" + "The `preds` and `target` should both have the (same) shape (N, ...), or `target` (N, ...)" " and `preds` (N, C, ...)." ) @@ -192,23 +192,23 @@ def _check_classification_inputs( """ if target.is_floating_point(): - raise ValueError("`target` has to be an integer tensor.") + raise ValueError("The `target` has to be an integer tensor.") elif target.min() < 0: - raise ValueError("`target` has to be a non-negative tensor.") + raise ValueError("The `target` has to be a non-negative tensor.") preds_float = preds.is_floating_point() if not preds_float and preds.min() < 0: - raise ValueError("if `preds` are integers, they have to be non-negative.") + raise ValueError("If `preds` are integers, they have to be non-negative.") if not preds.shape[0] == target.shape[0]: - raise ValueError("`preds` and `target` should have the same first dimension.") + raise ValueError("The `preds` and `target` should have the same first dimension.") if preds_float: if preds.min() < 0 or preds.max() > 1: - raise ValueError("`preds` should be probabilities, but values were detected outside of [0,1] range.") + raise ValueError("The `preds` should be probabilities, but values were detected outside of [0,1] range.") if threshold > 1 or threshold < 0: - raise ValueError("`threshold` should be a probability in [0,1].") + raise ValueError("The `threshold` should be a probability in [0,1].") if is_multiclass is False and target.max() > 1: raise ValueError("If you set `is_multiclass=False`, then `target` should not exceed 1.") From 3f4ad3c5a25bc82ad61a671701202f4b269d3a6c Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Thu, 26 Nov 2020 23:19:10 +0100 Subject: [PATCH 20/38] Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec --- pytorch_lightning/metrics/utils.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/metrics/utils.py b/pytorch_lightning/metrics/utils.py index e78ba055665d8..170315aa22236 100644 --- a/pytorch_lightning/metrics/utils.py +++ b/pytorch_lightning/metrics/utils.py @@ -57,8 +57,13 @@ def to_onehot( [0, 0, 1, 0], [0, 0, 0, 1]]) """ - dtype, device, shape = label_tensor.dtype, label_tensor.device, label_tensor.shape - tensor_onehot = torch.zeros(shape[0], num_classes, *shape[1:], dtype=dtype, device=device) + tensor_onehot = torch.zeros( + label_tensor.shape[0], + num_classes, + *label_tensor.shape[1:], + dtype=label_tensor.dtype, + device=label_tensor.device, + ) index = label_tensor.long().unsqueeze(1).expand_as(tensor_onehot) return tensor_onehot.scatter_(1, index, 1.0) From a654e6a9b70322c70232b810f8c156ff3d055138 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Thu, 26 Nov 2020 23:27:10 +0100 Subject: [PATCH 21/38] Reuse case from validation in formatting --- .../metrics/classification/utils.py | 32 ++++++++++--------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index 9ddcfb9132c2c..4183130704e8f 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -155,7 +155,7 @@ def _check_classification_inputs( num_classes: Optional[int] = None, is_multiclass: bool = False, top_k: int = 1, -) -> None: +) -> str: """Performs error checking on inputs for classification. This ensures that preds and target take one of the shape/type combinations that are @@ -189,6 +189,10 @@ def _check_classification_inputs( multi-class with 2 classes, respectively. If False, treat multi-class and multi-dim multi-class inputs with 1 or 2 classes as binary and multi-label, respectively. Defaults to None, which treats inputs as they appear. + + Return: + case: The case the inputs fall in, one of 'binary', 'multi-class', 'multi-label' or + 'multi-dim multi-class' """ if target.is_floating_point(): @@ -246,6 +250,8 @@ def _check_classification_inputs( " with probability predictions." ) + return case + def _input_format_classification( preds: torch.Tensor, @@ -276,7 +282,7 @@ def _input_format_classification( The returned output tensors will be binary tensors of the same shape, either ``(N, C)`` of ``(N, C, X)``, the details for each case are described below. The function also returns - a ``mode`` string, which describes which of the above cases the inputs belonged to - regardless + a ``case`` string, which describes which of the above cases the inputs belonged to - regardless of whether this was "overridden" by other settings (like ``is_multiclass``). In binary case, targets are normally returned as ``(N,1)`` tensor, while preds are transformed @@ -323,6 +329,8 @@ def _input_format_classification( Returns: preds: binary tensor of shape (N, C) or (N, C, X) target: binary tensor of shape (N, C) or (N, C, X) + case: The case the inputs fall in, one of 'binary', 'multi-class', 'multi-label' or + 'multi-dim multi-class' """ # Remove excess dimensions if preds.shape[0] == 1: @@ -330,7 +338,7 @@ def _input_format_classification( else: preds, target = preds.squeeze(), target.squeeze() - _check_classification_inputs( + case = _check_classification_inputs( preds, target, threshold=threshold, @@ -341,8 +349,7 @@ def _input_format_classification( preds_float = preds.is_floating_point() - if preds.ndim == target.ndim == 1 and preds_float: - mode = "binary" + if case == "binary": preds = (preds >= threshold).int() if is_multiclass: @@ -352,8 +359,7 @@ def _input_format_classification( preds = preds.unsqueeze(-1) target = target.unsqueeze(-1) - elif preds.ndim == target.ndim and preds_float: - mode = "multi-label" + elif case == "multi-label": preds = (preds >= threshold).int() if is_multiclass: @@ -363,8 +369,8 @@ def _input_format_classification( preds = preds.reshape(preds.shape[0], -1) target = target.reshape(target.shape[0], -1) + # Multi-class with probabilities elif preds.ndim == target.ndim + 1 == 2: - mode = "multi-class" if not num_classes: num_classes = preds.shape[1] @@ -376,9 +382,8 @@ def _input_format_classification( target = target[:, [1]] preds = preds[:, [1]] + # Multi-class with labels elif preds.ndim == target.ndim == 1 and not preds_float: - mode = "multi-class" - if not num_classes: num_classes = max(preds.max(), target.max()) + 1 @@ -392,8 +397,6 @@ def _input_format_classification( # Multi-dim multi-class (N, ...) with integers elif preds.shape == target.shape and not preds_float: - mode = "multi-dim multi-class" - if not num_classes: num_classes = max(preds.max(), target.max()) + 1 @@ -407,9 +410,8 @@ def _input_format_classification( preds = to_onehot(preds, num_classes) preds = preds.reshape(preds.shape[0], preds.shape[1], -1) - # Multi-dim multi-class (N, C, ...) and (N, ..., C) + # Multi-dim multi-class (N, C, ...) else: - mode = "multi-dim multi-class" num_classes = preds.shape[1] if is_multiclass is False: @@ -421,4 +423,4 @@ def _input_format_classification( target = target.reshape(target.shape[0], target.shape[1], -1) preds = select_topk(preds, top_k).reshape(preds.shape[0], preds.shape[1], -1) - return preds, target, mode + return preds, target, case From 16ab8f784b0bdbc6c2b278678c44eacd62093c4e Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Fri, 27 Nov 2020 01:01:27 +0100 Subject: [PATCH 22/38] Refactor code in _input_format_classification --- .../metrics/classification/utils.py | 90 +++++-------------- tests/metrics/classification/test_inputs.py | 24 ++--- 2 files changed, 36 insertions(+), 78 deletions(-) diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index 4183130704e8f..c224f3f1caa18 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -249,6 +249,8 @@ def _check_classification_inputs( "You have set `top_k` above 1, but your data is not (multi-dimensional) multi-class" " with probability predictions." ) + if is_multiclass is False: + raise ValueError("If you set `is_multiclass` to False, you can not set `top_k` above 1.") return case @@ -330,7 +332,7 @@ def _input_format_classification( preds: binary tensor of shape (N, C) or (N, C, X) target: binary tensor of shape (N, C) or (N, C, X) case: The case the inputs fall in, one of 'binary', 'multi-class', 'multi-label' or - 'multi-dim multi-class' + 'multi-dim multi-class' """ # Remove excess dimensions if preds.shape[0] == 1: @@ -347,80 +349,34 @@ def _input_format_classification( top_k=top_k, ) - preds_float = preds.is_floating_point() - - if case == "binary": - preds = (preds >= threshold).int() - - if is_multiclass: - target = to_onehot(target, 2) - preds = to_onehot(preds, 2) - else: - preds = preds.unsqueeze(-1) - target = target.unsqueeze(-1) - - elif case == "multi-label": + if case in ["binary", "multi-label"]: preds = (preds >= threshold).int() + num_classes = num_classes if not is_multiclass else 2 - if is_multiclass: - preds = to_onehot(preds, 2).reshape(preds.shape[0], 2, -1) - target = to_onehot(target, 2).reshape(target.shape[0], 2, -1) - else: - preds = preds.reshape(preds.shape[0], -1) - target = target.reshape(target.shape[0], -1) - - # Multi-class with probabilities - elif preds.ndim == target.ndim + 1 == 2: - if not num_classes: + if "multi-class" in case or is_multiclass: + if preds.is_floating_point(): num_classes = preds.shape[1] - - target = to_onehot(target, num_classes) - preds = select_topk(preds, top_k) - - # If is_multiclass=False, force to binary - if is_multiclass is False: - target = target[:, [1]] - preds = preds[:, [1]] - - # Multi-class with labels - elif preds.ndim == target.ndim == 1 and not preds_float: - if not num_classes: - num_classes = max(preds.max(), target.max()) + 1 - - # If is_multiclass=False, force to binary - if is_multiclass is False: - preds = preds.unsqueeze(1) - target = target.unsqueeze(1) + preds = select_topk(preds, top_k) else: + num_classes = num_classes if num_classes else max(preds.max(), target.max()) + 1 preds = to_onehot(preds, num_classes) - target = to_onehot(target, num_classes) - # Multi-dim multi-class (N, ...) with integers - elif preds.shape == target.shape and not preds_float: - if not num_classes: - num_classes = max(preds.max(), target.max()) + 1 + target = to_onehot(target, num_classes) - # If is_multiclass=False, force to multi-label if is_multiclass is False: - preds = preds.reshape(preds.shape[0], -1) - target = target.reshape(target.shape[0], -1) - else: - target = to_onehot(target, num_classes) - target = target.reshape(target.shape[0], target.shape[1], -1) - preds = to_onehot(preds, num_classes) - preds = preds.reshape(preds.shape[0], preds.shape[1], -1) + preds, target = preds[:, 1, ...], target[:, 1, ...] - # Multi-dim multi-class (N, C, ...) - else: - num_classes = preds.shape[1] + if (case in ["binary", "multi-label"] and not is_multiclass) or is_multiclass is False: + preds = preds.reshape(preds.shape[0], -1) + target = target.reshape(target.shape[0], -1) - if is_multiclass is False: - target = target.reshape(target.shape[0], -1) - preds = select_topk(preds, 1)[:, 1, ...] - preds = preds.reshape(preds.shape[0], -1) - else: - target = to_onehot(target, num_classes) - target = target.reshape(target.shape[0], target.shape[1], -1) - preds = select_topk(preds, top_k).reshape(preds.shape[0], preds.shape[1], -1) + elif "multi-class" in case or is_multiclass: + target = target.reshape(target.shape[0], target.shape[1], -1) + preds = preds.reshape(preds.shape[0], preds.shape[1], -1) + + # Some operatins above create an extra dimension for MC/binary case - this removes it + if preds.ndim > 2: + preds = preds.squeeze(-1) + target = target.squeeze(-1) - return preds, target, case + return preds.int(), target.int(), case diff --git a/tests/metrics/classification/test_inputs.py b/tests/metrics/classification/test_inputs.py index edce1232863f1..79788a1c2ecf4 100644 --- a/tests/metrics/classification/test_inputs.py +++ b/tests/metrics/classification/test_inputs.py @@ -73,7 +73,7 @@ def _top2(x): # To avoid ugly black line wrapping def _ml_preds_tr(x): - return _rshp1(_thrs(x).int()) + return _rshp1(_thrs(x)) def _onehot_rshp1(x): @@ -93,11 +93,11 @@ def _top2_rshp2(x): def _probs_to_mc_preds_tr(x): - return _onehot2(_thrs(x)).int() + return _onehot2(_thrs(x)) def _mlmd_prob_to_mc_preds_tr(x): - return _onehot2(_rshp1(_thrs(x).int())) + return _onehot2(_rshp1(_thrs(x))) ######################## @@ -111,8 +111,8 @@ def _mlmd_prob_to_mc_preds_tr(x): ############################# # Test usual expected cases (_bin, THRESHOLD, None, False, 1, "multi-class", _usq, _usq), - (_bin_prob, THRESHOLD, None, None, 1, "binary", lambda x: _usq(_thrs(x).int()), _usq), - (_ml_prob, THRESHOLD, None, None, 1, "multi-label", lambda x: _thrs(x).int(), _idn), + (_bin_prob, THRESHOLD, None, None, 1, "binary", lambda x: _usq(_thrs(x)), _usq), + (_ml_prob, THRESHOLD, None, None, 1, "multi-label", lambda x: _thrs(x), _idn), (_ml, THRESHOLD, None, False, 1, "multi-dim multi-class", _idn, _idn), (_ml_prob, THRESHOLD, None, None, 1, "multi-label", _ml_preds_tr, _rshp1), (_mlmd, THRESHOLD, None, False, 1, "multi-dim multi-class", _rshp1, _rshp1), @@ -155,8 +155,8 @@ def test_usual_cases(inputs, threshold, num_classes, is_multiclass, top_k, exp_m ) assert mode == exp_mode - assert torch.equal(preds_out, post_preds(inputs.preds[0])) - assert torch.equal(target_out, post_target(inputs.target[0])) + assert torch.equal(preds_out, post_preds(inputs.preds[0]).int()) + assert torch.equal(target_out, post_target(inputs.target[0]).int()) # Test that things work when batch_size = 1 preds_out, target_out, mode = _input_format_classification( @@ -169,8 +169,8 @@ def test_usual_cases(inputs, threshold, num_classes, is_multiclass, top_k, exp_m ) assert mode == exp_mode - assert torch.equal(preds_out, post_preds(inputs.preds[0][[0], ...])) - assert torch.equal(target_out, post_target(inputs.target[0][[0], ...])) + assert torch.equal(preds_out, post_preds(inputs.preds[0][[0], ...]).int()) + assert torch.equal(target_out, post_target(inputs.target[0][[0], ...]).int()) # Test that threshold is correctly applied @@ -180,7 +180,7 @@ def test_threshold(): preds_probs_out, _, _ = _input_format_classification(preds_probs, target, threshold=0.5) - assert torch.equal(torch.tensor([0, 1, 1]), preds_probs_out.squeeze().long()) + assert torch.equal(torch.tensor([0, 1, 1], dtype=torch.int), preds_probs_out.squeeze().int()) ######################################################################## @@ -222,7 +222,7 @@ def test_threshold(): # Max target larger or equal to C dimension (rand(size=(7, 3)), randint(low=4, high=6, size=(7,)), 0.5, None, None, 1), # C dimension not equal to num_classes - (rand(size=(7, 3, 4)), randint(high=4, size=(7, 3)), 0.5, 7, None, 1), + (rand(size=(7, 4, 3)), randint(high=4, size=(7, 3)), 0.5, 7, None, 1), # Max target larger than num_classes (with #dim preds = 1 + #dims target) (rand(size=(7, 3, 4)), randint(low=5, high=7, size=(7, 3)), 0.5, 4, None, 1), # Max target larger than num_classes (with #dim preds = #dims target) @@ -253,6 +253,8 @@ def test_threshold(): (_ml_prob.preds[0], _ml_prob.target[0], 0.5, None, None, 2), (_mlmd_prob.preds[0], _mlmd_prob.target[0], 0.5, None, None, 2), (_mdmc.preds[0], _mdmc.target[0], 0.5, None, None, 2), + # Topk =2 with 2 classes, is_multiclass=False + (_mc_prob_2cls.preds[0], _mc_prob_2cls.target[0], 0.5, None, False, 2), ], ) def test_incorrect_inputs(preds, target, threshold, num_classes, is_multiclass, top_k): From ecffe18d3b1a9bf6d4c173af29d673274aaaa426 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Fri, 27 Nov 2020 22:48:24 +0100 Subject: [PATCH 23/38] Small improvements --- pytorch_lightning/metrics/classification/utils.py | 9 ++++----- tests/metrics/classification/test_inputs.py | 2 +- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index c224f3f1caa18..68b4a2da789da 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -366,13 +366,12 @@ def _input_format_classification( if is_multiclass is False: preds, target = preds[:, 1, ...], target[:, 1, ...] - if (case in ["binary", "multi-label"] and not is_multiclass) or is_multiclass is False: - preds = preds.reshape(preds.shape[0], -1) - target = target.reshape(target.shape[0], -1) - - elif "multi-class" in case or is_multiclass: + if ("multi-class" in case and is_multiclass is not False) or is_multiclass: target = target.reshape(target.shape[0], target.shape[1], -1) preds = preds.reshape(preds.shape[0], preds.shape[1], -1) + else: + preds = preds.reshape(preds.shape[0], -1) + target = target.reshape(target.shape[0], -1) # Some operatins above create an extra dimension for MC/binary case - this removes it if preds.ndim > 2: diff --git a/tests/metrics/classification/test_inputs.py b/tests/metrics/classification/test_inputs.py index 79788a1c2ecf4..430d844217cb5 100644 --- a/tests/metrics/classification/test_inputs.py +++ b/tests/metrics/classification/test_inputs.py @@ -112,7 +112,7 @@ def _mlmd_prob_to_mc_preds_tr(x): # Test usual expected cases (_bin, THRESHOLD, None, False, 1, "multi-class", _usq, _usq), (_bin_prob, THRESHOLD, None, None, 1, "binary", lambda x: _usq(_thrs(x)), _usq), - (_ml_prob, THRESHOLD, None, None, 1, "multi-label", lambda x: _thrs(x), _idn), + (_ml_prob, THRESHOLD, None, None, 1, "multi-label", _thrs, _idn), (_ml, THRESHOLD, None, False, 1, "multi-dim multi-class", _idn, _idn), (_ml_prob, THRESHOLD, None, None, 1, "multi-label", _ml_preds_tr, _rshp1), (_mlmd, THRESHOLD, None, False, 1, "multi-dim multi-class", _rshp1, _rshp1), From 725c7dd717199b8d35becd77ea1490b7ee3fbddb Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Fri, 27 Nov 2020 23:01:01 +0100 Subject: [PATCH 24/38] PEP 8 --- pytorch_lightning/metrics/classification/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index 68b4a2da789da..5335083e36b95 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -364,7 +364,7 @@ def _input_format_classification( target = to_onehot(target, num_classes) if is_multiclass is False: - preds, target = preds[:, 1, ...], target[:, 1, ...] + preds, target = preds[:, 1, ...], target[:, 1, ...] if ("multi-class" in case and is_multiclass is not False) or is_multiclass: target = target.reshape(target.shape[0], target.shape[1], -1) From 41ad0b7163229b7397daeababd1924a10d2db372 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Fri, 27 Nov 2020 23:04:18 +0100 Subject: [PATCH 25/38] Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta --- pytorch_lightning/metrics/classification/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index 5335083e36b95..79a382f5104a8 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -54,7 +54,7 @@ def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor) else: case = "multi-dim multi-class" - implied_classes = torch.prod(torch.Tensor(list(preds.shape[1:]))) + implied_classes = preds[0].numel() elif preds.ndim == target.ndim + 1: if not preds_float: From ca13e76713c9f5b5bfc5f315e2380040b9b726fb Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Fri, 27 Nov 2020 23:06:57 +0100 Subject: [PATCH 26/38] Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta --- pytorch_lightning/metrics/classification/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index 79a382f5104a8..2dac8b7b7eae9 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -73,8 +73,8 @@ def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor) case = "multi-dim multi-class" else: raise ValueError( - "The `preds` and `target` should both have the (same) shape (N, ...), or `target` (N, ...)" - " and `preds` (N, C, ...)." + "Either `preds` and `target` both should have the (same) shape (N, ...), or `target` should be (N, ...)" + " and `preds` should be (N, C, ...)." ) return case, implied_classes From ede2c7fa2ff98aba446591da11b9dac28df2824e Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Fri, 27 Nov 2020 23:07:14 +0100 Subject: [PATCH 27/38] Update docs/source/metrics.rst Co-authored-by: Rohit Gupta --- docs/source/metrics.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index b59fdc6c73009..4be5e67b7c447 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -251,7 +251,7 @@ the possible class labels are 0, 1, 2, 3, etc. Below are some examples of differ ml_preds = torch.tensor([[0.2, 0.8, 0.9], [0.5, 0.6, 0.1], [0.3, 0.1, 0.1]]) ml_target = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 0]]) -In some rare cases, you might have inputs which appear to be (multi-dimensional) multi-class, +In some rare cases, you might have inputs which appear to be (multi-dimensional) multi-class but are actually binary/multi-label. For example, if both predictions and targets are 1d binary tensors. Or it could be the other way around, you want to treat binary/multi-label inputs as 2-class (multi-dimensional) multi-class inputs. From c6e4de44ff93160fcc30467a708946de7185fa10 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Fri, 27 Nov 2020 23:07:35 +0100 Subject: [PATCH 28/38] Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta --- pytorch_lightning/metrics/classification/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index 2dac8b7b7eae9..76d4348978cc0 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -37,7 +37,7 @@ def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor) if preds.shape != target.shape: raise ValueError( "The `preds` and `target` should have the same shape,", - f" got `preds` shape = {preds.shape} and `target` shape = {target.shape}.", + f" got `preds` with shape={preds.shape} and `target` with shape={target.shape}.", ) if preds_float and target.max() > 1: raise ValueError( From 201d0debf8026681cf8756f52cbb3bd447935aa5 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Fri, 27 Nov 2020 23:17:24 +0100 Subject: [PATCH 29/38] Apply suggestions from code review Co-authored-by: Rohit Gupta --- docs/source/metrics.rst | 2 +- pytorch_lightning/metrics/classification/utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 4be5e67b7c447..6d70b92ca8a9f 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -203,7 +203,7 @@ Class vs Functional Metrics The functional metrics follow the simple paradigm input in, output out. This means, they don't provide any advanced mechanisms for syncing across DDP nodes or aggregation over batches. They simply compute the metric value based on the given inputs. Also the integration within other parts of PyTorch Lightning will never be as tight as with the class-based interface. -If you look for just computing the values, the functional metrics are the way to go. However, if you are looking for the best integration and user experience, please consider also to use the class interface. +If you look for just computing the values, the functional metrics are the way to go. However, if you are looking for the best integration and user experience, please consider also using the class interface. ********************** Classification Metrics diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index 76d4348978cc0..ac6ec711a2654 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -96,7 +96,7 @@ def _check_num_classes_binary(num_classes: int, is_multiclass: bool): elif num_classes == 1 and is_multiclass: raise ValueError( "You have binary data and have set `is_multiclass=True`, but `num_classes` is 1." - " Either leave `is_multiclass` unset or set it to 2 to transform binary data to multi-class format." + " Either set `is_multiclass=None`(default) or set `num_classes=2` to transform binary data to multi-class format." ) From f08edbcc6e8eaaa2eec655155db82ade05c0230e Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Fri, 27 Nov 2020 23:25:09 +0100 Subject: [PATCH 30/38] Alphabetical reordering of regression metrics --- docs/source/metrics.rst | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index b59fdc6c73009..50dc36fac6b25 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -202,7 +202,7 @@ Class vs Functional Metrics The functional metrics follow the simple paradigm input in, output out. This means, they don't provide any advanced mechanisms for syncing across DDP nodes or aggregation over batches. They simply compute the metric value based on the given inputs. -Also the integration within other parts of PyTorch Lightning will never be as tight as with the class-based interface. +Also, the integration within other parts of PyTorch Lightning will never be as tight as with the class-based interface. If you look for just computing the values, the functional metrics are the way to go. However, if you are looking for the best integration and user experience, please consider also to use the class interface. ********************** @@ -447,10 +447,10 @@ Regression Metrics Class Metrics (Regression) -------------------------- -MeanSquaredError -~~~~~~~~~~~~~~~~ +ExplainedVariance +~~~~~~~~~~~~~~~~~ -.. autoclass:: pytorch_lightning.metrics.regression.MeanSquaredError +.. autoclass:: pytorch_lightning.metrics.regression.ExplainedVariance :noindex: @@ -461,17 +461,17 @@ MeanAbsoluteError :noindex: -MeanSquaredLogError -~~~~~~~~~~~~~~~~~~~ +MeanSquaredError +~~~~~~~~~~~~~~~~ -.. autoclass:: pytorch_lightning.metrics.regression.MeanSquaredLogError +.. autoclass:: pytorch_lightning.metrics.regression.MeanSquaredError :noindex: -ExplainedVariance -~~~~~~~~~~~~~~~~~ +MeanSquaredLogError +~~~~~~~~~~~~~~~~~~~ -.. autoclass:: pytorch_lightning.metrics.regression.ExplainedVariance +.. autoclass:: pytorch_lightning.metrics.regression.MeanSquaredLogError :noindex: @@ -513,17 +513,17 @@ mean_squared_error [func] :noindex: -psnr [func] -~~~~~~~~~~~ +mean_squared_log_error [func] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autofunction:: pytorch_lightning.metrics.functional.psnr +.. autofunction:: pytorch_lightning.metrics.functional.mean_squared_log_error :noindex: -mean_squared_log_error [func] -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +psnr [func] +~~~~~~~~~~~ -.. autofunction:: pytorch_lightning.metrics.functional.mean_squared_log_error +.. autofunction:: pytorch_lightning.metrics.functional.psnr :noindex: From 35e3eff9e4f5dcf9bc4e32722becf9fbfba3fc53 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Sat, 28 Nov 2020 16:14:16 +0100 Subject: [PATCH 31/38] Change default value of top_k and add error checking --- .../metrics/classification/utils.py | 80 +++++++++------- tests/metrics/classification/test_inputs.py | 96 ++++++++++--------- 2 files changed, 96 insertions(+), 80 deletions(-) diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index ac6ec711a2654..af4319152942f 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -148,13 +148,25 @@ def _check_num_classes_ml(num_classes: int, is_multiclass: bool, implied_classes raise ValueError("The implied number of classes (from shape of inputs) does not match num_classes.") +def _check_top_k(top_k: int, case: str, implied_classes: int, is_multiclass: Optional[bool], preds_float: bool): + if "multi-class" not in case or not preds_float: + raise ValueError( + "You have set `top_k` above 1, but your data is not (multi-dimensional) multi-class" + " with probability predictions." + ) + if is_multiclass is False: + raise ValueError("If you set `is_multiclass=False`, you can not set `top_k`.") + if top_k >= implied_classes: + raise ValueError("The `top_k` has to be strictly smaller than the `C` dimension of `preds`.") + + def _check_classification_inputs( preds: torch.Tensor, target: torch.Tensor, threshold: float, - num_classes: Optional[int] = None, - is_multiclass: bool = False, - top_k: int = 1, + num_classes: Optional[int], + is_multiclass: bool, + top_k: Optional[int], ) -> str: """Performs error checking on inputs for classification. @@ -172,8 +184,9 @@ def _check_classification_inputs( When ``num_classes`` is not specified in these cases, consistency of the highest target value against ``C`` dimension is checked for (multi-dimensional) multi-class cases. - If ``top_k`` is larger than one, then an error is raised if the inputs are not (multi-dim) - multi-class with probability predictions. + If ``top_k`` is set (not None) for inputs which are not (multi-dimensional) multi class + with probabilities, then an error is raised. Similarly if ``top_k`` is set to a number + that is higher than or equal to the ``C`` dimension of ``preds``. Preds and target tensors are expected to be squeezed already - all dimensions should be greater than 1, except perhaps the first one (N). @@ -189,6 +202,8 @@ def _check_classification_inputs( multi-class with 2 classes, respectively. If False, treat multi-class and multi-dim multi-class inputs with 1 or 2 classes as binary and multi-label, respectively. Defaults to None, which treats inputs as they appear. + top_k: number of highest probability entries for each sample to convert to 1s, relevant + only for (multi-dimensional) multi-class cases. Return: case: The case the inputs fall in, one of 'binary', 'multi-class', 'multi-label' or @@ -197,7 +212,7 @@ def _check_classification_inputs( if target.is_floating_point(): raise ValueError("The `target` has to be an integer tensor.") - elif target.min() < 0: + if target.min() < 0: raise ValueError("The `target` has to be a non-negative tensor.") preds_float = preds.is_floating_point() @@ -207,9 +222,8 @@ def _check_classification_inputs( if not preds.shape[0] == target.shape[0]: raise ValueError("The `preds` and `target` should have the same first dimension.") - if preds_float: - if preds.min() < 0 or preds.max() > 1: - raise ValueError("The `preds` should be probabilities, but values were detected outside of [0,1] range.") + if preds_float and (preds.min() < 0 or preds.max() > 1): + raise ValueError("The `preds` should be probabilities, but values were detected outside of [0,1] range.") if threshold > 1 or threshold < 0: raise ValueError("The `threshold` should be a probability in [0,1].") @@ -223,34 +237,30 @@ def _check_classification_inputs( # Check that shape/types fall into one of the cases case, implied_classes = _check_shape_and_type_consistency(preds, target) - if preds.shape != target.shape and is_multiclass is False and implied_classes != 2: - raise ValueError( - "You have set `is_multiclass=False`, but have more than 2 classes in your data," - " based on the C dimension of `preds`." - ) + # Check consistency with the `C` dimension in case of multi-class data + if preds.shape != target.shape: + if is_multiclass is False and implied_classes != 2: + raise ValueError( + "You have set `is_multiclass=False`, but have more than 2 classes in your data," + " based on the C dimension of `preds`." + ) + if target.max() >= implied_classes: + raise ValueError( + "The highest label in `target` should be smaller than the size of the `C` dimension of `preds`." + ) # Check that num_classes is consistent - if not num_classes: - if preds.shape != target.shape and target.max() >= implied_classes: - raise ValueError("The highest label in `target` should be smaller than the size of C dimension.") - else: + if num_classes: if case == "binary": _check_num_classes_binary(num_classes, is_multiclass) elif "multi-class" in case: _check_num_classes_mc(preds, target, num_classes, is_multiclass, implied_classes) - elif case == "multi-label": _check_num_classes_ml(num_classes, is_multiclass, implied_classes) - # Check that if top_k > 1, we have (multi-class) multi-dim with probabilities - if top_k > 1: - if preds.shape == target.shape: - raise ValueError( - "You have set `top_k` above 1, but your data is not (multi-dimensional) multi-class" - " with probability predictions." - ) - if is_multiclass is False: - raise ValueError("If you set `is_multiclass` to False, you can not set `top_k` above 1.") + # Check that top_k is consistent + if top_k: + _check_top_k(top_k, case, implied_classes, is_multiclass, preds_float) return case @@ -259,7 +269,7 @@ def _input_format_classification( preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5, - top_k: int = 1, + top_k: Optional[int] = None, num_classes: Optional[int] = None, is_multiclass: Optional[bool] = None, ) -> Tuple[torch.Tensor, torch.Tensor, str]: @@ -322,7 +332,10 @@ def _input_format_classification( (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 num_classes: number of classes top_k: number of highest probability entries for each sample to convert to 1s, relevant - only for (multi-dimensional) multi-class cases. + only for (multi-dimensional) multi-class inputs with probability predictions. The + default value (``None``) will be interepreted as one for these inputs. + + Should be left unset (``None``) for all other types of inputs. is_multiclass: if True, treat binary and multi-label inputs as multi-class or multi-dim multi-class with 2 classes, respectively. If False, treat multi-class and multi-dim multi-class inputs with 1 or 2 classes as binary and multi-label, respectively. @@ -349,6 +362,8 @@ def _input_format_classification( top_k=top_k, ) + top_k = top_k if top_k else 1 + if case in ["binary", "multi-label"]: preds = (preds >= threshold).int() num_classes = num_classes if not is_multiclass else 2 @@ -370,12 +385,11 @@ def _input_format_classification( target = target.reshape(target.shape[0], target.shape[1], -1) preds = preds.reshape(preds.shape[0], preds.shape[1], -1) else: - preds = preds.reshape(preds.shape[0], -1) target = target.reshape(target.shape[0], -1) + preds = preds.reshape(preds.shape[0], -1) # Some operatins above create an extra dimension for MC/binary case - this removes it if preds.ndim > 2: - preds = preds.squeeze(-1) - target = target.squeeze(-1) + preds, target = preds.squeeze(-1), target.squeeze(-1) return preds.int(), target.int(), case diff --git a/tests/metrics/classification/test_inputs.py b/tests/metrics/classification/test_inputs.py index 430d844217cb5..058ec66c10ed6 100644 --- a/tests/metrics/classification/test_inputs.py +++ b/tests/metrics/classification/test_inputs.py @@ -110,38 +110,38 @@ def _mlmd_prob_to_mc_preds_tr(x): [ ############################# # Test usual expected cases - (_bin, THRESHOLD, None, False, 1, "multi-class", _usq, _usq), - (_bin_prob, THRESHOLD, None, None, 1, "binary", lambda x: _usq(_thrs(x)), _usq), - (_ml_prob, THRESHOLD, None, None, 1, "multi-label", _thrs, _idn), - (_ml, THRESHOLD, None, False, 1, "multi-dim multi-class", _idn, _idn), - (_ml_prob, THRESHOLD, None, None, 1, "multi-label", _ml_preds_tr, _rshp1), - (_mlmd, THRESHOLD, None, False, 1, "multi-dim multi-class", _rshp1, _rshp1), - (_mc, THRESHOLD, NUM_CLASSES, None, 1, "multi-class", _onehot, _onehot), - (_mc_prob, THRESHOLD, None, None, 1, "multi-class", _top1, _onehot), + (_bin, THRESHOLD, None, False, None, "multi-class", _usq, _usq), + (_bin_prob, THRESHOLD, None, None, None, "binary", lambda x: _usq(_thrs(x)), _usq), + (_ml_prob, THRESHOLD, None, None, None, "multi-label", _thrs, _idn), + (_ml, THRESHOLD, None, False, None, "multi-dim multi-class", _idn, _idn), + (_ml_prob, THRESHOLD, None, None, None, "multi-label", _ml_preds_tr, _rshp1), + (_mlmd, THRESHOLD, None, False, None, "multi-dim multi-class", _rshp1, _rshp1), + (_mc, THRESHOLD, NUM_CLASSES, None, None, "multi-class", _onehot, _onehot), + (_mc_prob, THRESHOLD, None, None, None, "multi-class", _top1, _onehot), (_mc_prob, THRESHOLD, None, None, 2, "multi-class", _top2, _onehot), - (_mdmc, THRESHOLD, NUM_CLASSES, None, 1, "multi-dim multi-class", _onehot, _onehot), - (_mdmc_prob, THRESHOLD, None, None, 1, "multi-dim multi-class", _top1_rshp2, _onehot), + (_mdmc, THRESHOLD, NUM_CLASSES, None, None, "multi-dim multi-class", _onehot, _onehot), + (_mdmc_prob, THRESHOLD, None, None, None, "multi-dim multi-class", _top1_rshp2, _onehot), (_mdmc_prob, THRESHOLD, None, None, 2, "multi-dim multi-class", _top2_rshp2, _onehot), - (_mdmc_prob_many_dims, THRESHOLD, None, None, 1, "multi-dim multi-class", _top1_rshp2, _onehot_rshp1), + (_mdmc_prob_many_dims, THRESHOLD, None, None, None, "multi-dim multi-class", _top1_rshp2, _onehot_rshp1), (_mdmc_prob_many_dims, THRESHOLD, None, None, 2, "multi-dim multi-class", _top2_rshp2, _onehot_rshp1), ########################### # Test some special cases # Binary as multiclass - (_bin, THRESHOLD, None, None, 1, "multi-class", _onehot2, _onehot2), + (_bin, THRESHOLD, None, None, None, "multi-class", _onehot2, _onehot2), # Binary probs as multiclass - (_bin_prob, THRESHOLD, None, True, 1, "binary", _probs_to_mc_preds_tr, _onehot2), + (_bin_prob, THRESHOLD, None, True, None, "binary", _probs_to_mc_preds_tr, _onehot2), # Multilabel as multiclass - (_ml, THRESHOLD, None, True, 1, "multi-dim multi-class", _onehot2, _onehot2), + (_ml, THRESHOLD, None, True, None, "multi-dim multi-class", _onehot2, _onehot2), # Multilabel probs as multiclass - (_ml_prob, THRESHOLD, None, True, 1, "multi-label", _probs_to_mc_preds_tr, _onehot2), + (_ml_prob, THRESHOLD, None, True, None, "multi-label", _probs_to_mc_preds_tr, _onehot2), # Multidim multilabel as multiclass - (_mlmd, THRESHOLD, None, True, 1, "multi-dim multi-class", _onehot2_rshp1, _onehot2_rshp1), + (_mlmd, THRESHOLD, None, True, None, "multi-dim multi-class", _onehot2_rshp1, _onehot2_rshp1), # Multidim multilabel probs as multiclass - (_mlmd_prob, THRESHOLD, None, True, 1, "multi-label", _mlmd_prob_to_mc_preds_tr, _onehot2_rshp1), + (_mlmd_prob, THRESHOLD, None, True, None, "multi-label", _mlmd_prob_to_mc_preds_tr, _onehot2_rshp1), # Multiclass prob with 2 classes as binary - (_mc_prob_2cls, THRESHOLD, None, False, 1, "multi-class", lambda x: _top1(x)[:, [1]], _usq), + (_mc_prob_2cls, THRESHOLD, None, False, None, "multi-class", lambda x: _top1(x)[:, [1]], _usq), # Multi-dim multi-class with 2 classes as multi-label - (_mdmc_prob_2cls, THRESHOLD, None, False, 1, "multi-dim multi-class", lambda x: _top1(x)[:, 1], _idn), + (_mdmc_prob_2cls, THRESHOLD, None, False, None, "multi-dim multi-class", lambda x: _top1(x)[:, 1], _idn), ], ) def test_usual_cases(inputs, threshold, num_classes, is_multiclass, top_k, exp_mode, post_preds, post_target): @@ -192,59 +192,59 @@ def test_threshold(): "preds, target, threshold, num_classes, is_multiclass, top_k", [ # Target not integer - (randint(high=2, size=(7,)), randint(high=2, size=(7,)).float(), 0.5, None, None, 1), + (randint(high=2, size=(7,)), randint(high=2, size=(7,)).float(), 0.5, None, None, None), # Target negative - (randint(high=2, size=(7,)), -randint(high=2, size=(7,)), 0.5, None, None, 1), + (randint(high=2, size=(7,)), -randint(high=2, size=(7,)), 0.5, None, None, None), # Preds negative integers - (-randint(high=2, size=(7,)), randint(high=2, size=(7,)), 0.5, None, None, 1), + (-randint(high=2, size=(7,)), randint(high=2, size=(7,)), 0.5, None, None, None), # Negative probabilities - (-rand(size=(7,)), randint(high=2, size=(7,)), 0.5, None, None, 1), + (-rand(size=(7,)), randint(high=2, size=(7,)), 0.5, None, None, None), # Threshold outside of [0,1] - (rand(size=(7,)), randint(high=2, size=(7,)), 1.5, None, None, 1), + (rand(size=(7,)), randint(high=2, size=(7,)), 1.5, None, None, None), # is_multiclass=False and target > 1 - (rand(size=(7,)), randint(low=2, high=4, size=(7,)), 0.5, None, False, 1), + (rand(size=(7,)), randint(low=2, high=4, size=(7,)), 0.5, None, False, None), # is_multiclass=False and preds integers with > 1 - (randint(low=2, high=4, size=(7,)), randint(high=2, size=(7,)), 0.5, None, False, 1), + (randint(low=2, high=4, size=(7,)), randint(high=2, size=(7,)), 0.5, None, False, None), # Wrong batch size - (randint(high=2, size=(8,)), randint(high=2, size=(7,)), 0.5, None, None, 1), + (randint(high=2, size=(8,)), randint(high=2, size=(7,)), 0.5, None, None, None), # Completely wrong shape - (randint(high=2, size=(7,)), randint(high=2, size=(7, 4)), 0.5, None, None, 1), + (randint(high=2, size=(7,)), randint(high=2, size=(7, 4)), 0.5, None, None, None), # Same #dims, different shape - (randint(high=2, size=(7, 3)), randint(high=2, size=(7, 4)), 0.5, None, None, 1), + (randint(high=2, size=(7, 3)), randint(high=2, size=(7, 4)), 0.5, None, None, None), # Same shape and preds floats, target not binary - (rand(size=(7, 3)), randint(low=2, high=4, size=(7, 3)), 0.5, None, None, 1), + (rand(size=(7, 3)), randint(low=2, high=4, size=(7, 3)), 0.5, None, None, None), # #dims in preds = 1 + #dims in target, C shape not second or last - (rand(size=(7, 3, 4, 3)), randint(high=4, size=(7, 3, 3)), 0.5, None, None, 1), + (rand(size=(7, 3, 4, 3)), randint(high=4, size=(7, 3, 3)), 0.5, None, None, None), # #dims in preds = 1 + #dims in target, preds not float - (randint(high=2, size=(7, 3, 3, 4)), randint(high=4, size=(7, 3, 3)), 0.5, None, None, 1), + (randint(high=2, size=(7, 3, 3, 4)), randint(high=4, size=(7, 3, 3)), 0.5, None, None, None), # is_multiclass=False, with C dimension > 2 - (rand(size=(7, 3, 5)), randint(high=2, size=(7, 5)), 0.5, None, False, 1), + (rand(size=(7, 3, 5)), randint(high=2, size=(7, 5)), 0.5, None, False, None), # Max target larger or equal to C dimension - (rand(size=(7, 3)), randint(low=4, high=6, size=(7,)), 0.5, None, None, 1), + (rand(size=(7, 3)), randint(low=4, high=6, size=(7,)), 0.5, None, None, None), # C dimension not equal to num_classes - (rand(size=(7, 4, 3)), randint(high=4, size=(7, 3)), 0.5, 7, None, 1), + (rand(size=(7, 4, 3)), randint(high=4, size=(7, 3)), 0.5, 7, None, None), # Max target larger than num_classes (with #dim preds = 1 + #dims target) - (rand(size=(7, 3, 4)), randint(low=5, high=7, size=(7, 3)), 0.5, 4, None, 1), + (rand(size=(7, 3, 4)), randint(low=5, high=7, size=(7, 3)), 0.5, 4, None, None), # Max target larger than num_classes (with #dim preds = #dims target) - (randint(high=4, size=(7, 3)), randint(low=5, high=7, size=(7, 3)), 0.5, 4, None, 1), + (randint(high=4, size=(7, 3)), randint(low=5, high=7, size=(7, 3)), 0.5, 4, None, None), # Max preds larger than num_classes (with #dim preds = #dims target) - (randint(low=5, high=7, size=(7, 3)), randint(high=4, size=(7, 3)), 0.5, 4, None, 1), + (randint(low=5, high=7, size=(7, 3)), randint(high=4, size=(7, 3)), 0.5, 4, None, None), # Num_classes=1, but is_multiclass not false (randint(high=2, size=(7,)), randint(high=2, size=(7,)), 0.5, 1, None, 1), # is_multiclass=False, but implied class dimension (for multi-label, from shape) != num_classes - (randint(high=2, size=(7, 3, 3)), randint(high=2, size=(7, 3, 3)), 0.5, 4, False, 1), + (randint(high=2, size=(7, 3, 3)), randint(high=2, size=(7, 3, 3)), 0.5, 4, False, None), # Multilabel input with implied class dimension != num_classes - (rand(size=(7, 3, 3)), randint(high=2, size=(7, 3, 3)), 0.5, 4, False, 1), + (rand(size=(7, 3, 3)), randint(high=2, size=(7, 3, 3)), 0.5, 4, False, None), # Multilabel input with is_multiclass=True, but num_classes != 2 (or None) - (rand(size=(7, 3)), randint(high=2, size=(7, 3)), 0.5, 4, True, 1), + (rand(size=(7, 3)), randint(high=2, size=(7, 3)), 0.5, 4, True, None), # Binary input, num_classes > 2 - (rand(size=(7,)), randint(high=2, size=(7,)), 0.5, 4, None, 1), + (rand(size=(7,)), randint(high=2, size=(7,)), 0.5, 4, None, None), # Binary input, num_classes == 2 and is_multiclass not True - (rand(size=(7,)), randint(high=2, size=(7,)), 0.5, 2, None, 1), - (rand(size=(7,)), randint(high=2, size=(7,)), 0.5, 2, False, 1), + (rand(size=(7,)), randint(high=2, size=(7,)), 0.5, 2, None, None), + (rand(size=(7,)), randint(high=2, size=(7,)), 0.5, 2, False, None), # Binary input, num_classes == 1 and is_multiclass=True - (rand(size=(7,)), randint(high=2, size=(7,)), 0.5, 1, True, 1), - # Topk > 1 with non (md)mc prob data + (rand(size=(7,)), randint(high=2, size=(7,)), 0.5, 1, True, None), + # Topk set with non (md)mc prob data (_bin.preds[0], _bin.target[0], 0.5, None, None, 2), (_bin_prob.preds[0], _bin_prob.target[0], 0.5, None, None, 2), (_mc.preds[0], _mc.target[0], 0.5, None, None, 2), @@ -253,8 +253,10 @@ def test_threshold(): (_ml_prob.preds[0], _ml_prob.target[0], 0.5, None, None, 2), (_mlmd_prob.preds[0], _mlmd_prob.target[0], 0.5, None, None, 2), (_mdmc.preds[0], _mdmc.target[0], 0.5, None, None, 2), - # Topk =2 with 2 classes, is_multiclass=False + # top_k =2 with 2 classes, is_multiclass=False (_mc_prob_2cls.preds[0], _mc_prob_2cls.target[0], 0.5, None, False, 2), + # top_k = number of classes (C dimension) + (_mc_prob.preds[0], _mc_prob.target[0], 0.5, None, None, NUM_CLASSES), ], ) def test_incorrect_inputs(preds, target, threshold, num_classes, is_multiclass, top_k): From c28aadf7e9774419326462949d5f371812c6bed4 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Sat, 28 Nov 2020 16:23:27 +0100 Subject: [PATCH 32/38] Extract basic validation into separate function --- .../metrics/classification/utils.py | 64 +++++++++++-------- 1 file changed, 37 insertions(+), 27 deletions(-) diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index af4319152942f..eda929fdd32ed 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -18,6 +18,37 @@ from pytorch_lightning.metrics.utils import to_onehot, select_topk +def _basic_input_validation(preds: torch.Tensor, target: torch.Tensor, threshold: float, is_multiclass: bool): + """ + Perform basic validation of inputs that does not require deducing any information + of the type of inputs. + """ + + if target.is_floating_point(): + raise ValueError("The `target` has to be an integer tensor.") + if target.min() < 0: + raise ValueError("The `target` has to be a non-negative tensor.") + + preds_float = preds.is_floating_point() + if not preds_float and preds.min() < 0: + raise ValueError("If `preds` are integers, they have to be non-negative.") + + if not preds.shape[0] == target.shape[0]: + raise ValueError("The `preds` and `target` should have the same first dimension.") + + if preds_float and (preds.min() < 0 or preds.max() > 1): + raise ValueError("The `preds` should be probabilities, but values were detected outside of [0,1] range.") + + if threshold > 1 or threshold < 0: + raise ValueError("The `threshold` should be a probability in [0,1].") + + if is_multiclass is False and target.max() > 1: + raise ValueError("If you set `is_multiclass=False`, then `target` should not exceed 1.") + + if is_multiclass is False and not preds_float and preds.max() > 1: + raise ValueError("If you set `is_multiclass=False` and `preds` are integers, then `preds` should not exceed 1.") + + def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor) -> Tuple[str, int]: """ This checks that the shape and type of inputs are consistent with @@ -88,12 +119,12 @@ def _check_num_classes_binary(num_classes: int, is_multiclass: bool): if num_classes > 2: raise ValueError("Your data is binary, but `num_classes` is larger than 2.") - elif num_classes == 2 and not is_multiclass: + if num_classes == 2 and not is_multiclass: raise ValueError( "Your data is binary and `num_classes=2`, but `is_multiclass` is not True." " Set it to True if you want to transform binary data to multi-class format." ) - elif num_classes == 1 and is_multiclass: + if num_classes == 1 and is_multiclass: raise ValueError( "You have binary data and have set `is_multiclass=True`, but `num_classes` is 1." " Either set `is_multiclass=None`(default) or set `num_classes=2` to transform binary data to multi-class format." @@ -114,7 +145,7 @@ def _check_num_classes_mc( " If you want to convert (multi-dimensional) multi-class data with 2 classes" " to binary/multi-label, set `is_multiclass=False`." ) - elif num_classes > 1: + if num_classes > 1: if is_multiclass is False: if implied_classes != num_classes: raise ValueError( @@ -210,29 +241,8 @@ def _check_classification_inputs( 'multi-dim multi-class' """ - if target.is_floating_point(): - raise ValueError("The `target` has to be an integer tensor.") - if target.min() < 0: - raise ValueError("The `target` has to be a non-negative tensor.") - - preds_float = preds.is_floating_point() - if not preds_float and preds.min() < 0: - raise ValueError("If `preds` are integers, they have to be non-negative.") - - if not preds.shape[0] == target.shape[0]: - raise ValueError("The `preds` and `target` should have the same first dimension.") - - if preds_float and (preds.min() < 0 or preds.max() > 1): - raise ValueError("The `preds` should be probabilities, but values were detected outside of [0,1] range.") - - if threshold > 1 or threshold < 0: - raise ValueError("The `threshold` should be a probability in [0,1].") - - if is_multiclass is False and target.max() > 1: - raise ValueError("If you set `is_multiclass=False`, then `target` should not exceed 1.") - - if is_multiclass is False and not preds_float and preds.max() > 1: - raise ValueError("If you set `is_multiclass=False` and `preds` are integers, then `preds` should not exceed 1.") + # Baisc validation (that does not need case/type information) + _basic_input_validation(preds, target, threshold, is_multiclass) # Check that shape/types fall into one of the cases case, implied_classes = _check_shape_and_type_consistency(preds, target) @@ -260,7 +270,7 @@ def _check_classification_inputs( # Check that top_k is consistent if top_k: - _check_top_k(top_k, case, implied_classes, is_multiclass, preds_float) + _check_top_k(top_k, case, implied_classes, is_multiclass, preds.is_floating_point()) return case From 0cb0eac3b1602ae9cacf094aebb10c20ec01bc90 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Sun, 29 Nov 2020 18:33:11 +0100 Subject: [PATCH 33/38] Update desciption of parameters in input formatting --- .../metrics/classification/utils.py | 72 ++++++++++++++----- 1 file changed, 54 insertions(+), 18 deletions(-) diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index eda929fdd32ed..8c4c6b6eb94a9 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -223,18 +223,37 @@ def _check_classification_inputs( greater than 1, except perhaps the first one (N). Args: - preds: tensor with predictions - target: tensor with ground truth labels, always integers + preds: Tensor with predictions (labels or probabilities) + target: Tensor with ground truth labels, always integers (labels) threshold: Threshold probability value for transforming probability predictions to binary (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 - num_classes: number of classes - is_multiclass: if True, treat binary and multi-label inputs as multi-class or multi-dim - multi-class with 2 classes, respectively. If False, treat multi-class and multi-dim - multi-class inputs with 1 or 2 classes as binary and multi-label, respectively. - Defaults to None, which treats inputs as they appear. - top_k: number of highest probability entries for each sample to convert to 1s, relevant - only for (multi-dimensional) multi-class cases. + num_classes: + Number of classes. If not explicitly set, the number of classes will be infered + either from the shape of inputs, or the maximum label in the ``target`` and ``preds`` + tensor, where applicable. + top_k: + Number of highest probability entries for each sample to convert to 1s - relevant + only for (multi-dimensional) multi-class inputs with probability predictions. The + default value (``None``) will be interepreted as 1 for these inputs. + + Should be left unset (``None``) for all other types of inputs. + is_multiclass: + Used only in certain special cases, where you want to treat inputs as a different type + than what they appear to be (see :ref:`metrics: Input types` documentation section for + input classification and examples of the use of this parameter). Should be left at default + value (``None``) in most cases. + + The special cases where this parameter should be set are: + + - When you want to treat binary or multi-label inputs as multi-class or multi-dimensional + multi-class with 2 classes, respectively. The probabilities are interpreted as the + probability of the "1" class, and thresholding still applies as usual. In this case + the parameter should be set to ``True``. + - When you want to treat multi-class or multi-dimensional mulit-class inputs with 2 classes + as binary or multi-label inputs, respectively. This is mainly meant for the case when + inputs are labels, but will work if they are probabilities as well. For this case the + parameter should be set to ``False``. Return: case: The case the inputs fall in, one of 'binary', 'multi-class', 'multi-label' or @@ -335,21 +354,38 @@ def _input_format_classification( target. Args: - preds: tensor with predictions - target: tensor with ground truth labels, always integers + preds: Tensor with predictions (labels or probabilities) + target: Tensor with ground truth labels, always integers (labels) threshold: Threshold probability value for transforming probability predictions to binary (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 - num_classes: number of classes - top_k: number of highest probability entries for each sample to convert to 1s, relevant + num_classes: + Number of classes. If not explicitly set, the number of classes will be infered + either from the shape of inputs, or the maximum label in the ``target`` and ``preds`` + tensor, where applicable. + top_k: + Number of highest probability entries for each sample to convert to 1s - relevant only for (multi-dimensional) multi-class inputs with probability predictions. The - default value (``None``) will be interepreted as one for these inputs. + default value (``None``) will be interepreted as 1 for these inputs. Should be left unset (``None``) for all other types of inputs. - is_multiclass: if True, treat binary and multi-label inputs as multi-class or multi-dim - multi-class with 2 classes, respectively. If False, treat multi-class and multi-dim - multi-class inputs with 1 or 2 classes as binary and multi-label, respectively. - Defaults to None, which treats inputs as they appear. + is_multiclass: + Used only in certain special cases, where you want to treat inputs as a different type + than what they appear to be (see :ref:`metrics: Input types` documentation section for + input classification and examples of the use of this parameter). Should be left at default + value (``None``) in most cases. + + The special cases where this parameter should be set are: + + - When you want to treat binary or multi-label inputs as multi-class or multi-dimensional + multi-class with 2 classes, respectively. The probabilities are interpreted as the + probability of the "1" class, and thresholding still applies as usual. In this case + the parameter should be set to ``True``. + - When you want to treat multi-class or multi-dimensional mulit-class inputs with 2 classes + as binary or multi-label inputs, respectively. This is mainly meant for the case when + inputs are labels, but will work if they are probabilities as well. For this case the + parameter should be set to ``False``. + Returns: preds: binary tensor of shape (N, C) or (N, C, X) From 8e7a85a78568160f8c6e94bf01dbd76c8c064e5c Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Mon, 30 Nov 2020 23:27:03 +0100 Subject: [PATCH 34/38] Apply suggestions from code review Co-authored-by: Nicki Skafte --- pytorch_lightning/metrics/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/metrics/utils.py b/pytorch_lightning/metrics/utils.py index 170315aa22236..c8505bba58e3b 100644 --- a/pytorch_lightning/metrics/utils.py +++ b/pytorch_lightning/metrics/utils.py @@ -121,7 +121,7 @@ def _input_format_classification( # multi class probabilites preds = torch.argmax(preds, dim=1) - if preds.ndim == target.ndim and preds.dtype == torch.float: + if preds.ndim == target.ndim and preds.is_floating_point(): # binary or multilabel probablities preds = (preds >= threshold).long() return preds, target @@ -151,12 +151,12 @@ def _input_format_classification_one_hot( # multi class probabilites preds = torch.argmax(preds, dim=1) - if preds.ndim == target.ndim and preds.dtype == torch.long and num_classes > 1 and not multilabel: + if preds.ndim == target.ndim and preds.dtype in (torch.long, torch.int) and num_classes > 1 and not multilabel: # multi-class preds = to_onehot(preds, num_classes=num_classes) target = to_onehot(target, num_classes=num_classes) - elif preds.ndim == target.ndim and preds.dtype == torch.float: + elif preds.ndim == target.ndim and preds.is_floating_point(): # binary or multilabel probablities preds = (preds >= threshold).long() From 829155efcf09e542cded5075cf2b543e588bc62f Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 1 Dec 2020 00:06:02 +0100 Subject: [PATCH 35/38] Check that probabilities in preds sum to 1 (for MC) --- .../metrics/classification/utils.py | 5 +++++ tests/metrics/classification/inputs.py | 9 +++++++-- tests/metrics/classification/test_inputs.py | 18 +++++++++++++----- 3 files changed, 25 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index 8c4c6b6eb94a9..7cfc505e6c673 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -266,6 +266,11 @@ def _check_classification_inputs( # Check that shape/types fall into one of the cases case, implied_classes = _check_shape_and_type_consistency(preds, target) + # For (multi-dim) multi-class case with prob preds, check that preds sum up to 1 + if "multi-class" in case and preds.is_floating_point(): + if not torch.isclose(preds.sum(dim=1), torch.ones_like(preds.sum(dim=1))).all(): + raise ValueError("Probabilities in `preds` must sum up to 1 accross the `C` dimension.") + # Check consistency with the `C` dimension in case of multi-class data if preds.shape != target.shape: if is_multiclass is False and implied_classes != 2: diff --git a/tests/metrics/classification/inputs.py b/tests/metrics/classification/inputs.py index 48d3e85e3afeb..9f70a80cd31a4 100644 --- a/tests/metrics/classification/inputs.py +++ b/tests/metrics/classification/inputs.py @@ -53,8 +53,11 @@ target=__temp_target ) +__mc_prob_preds = torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES) +__mc_prob_preds = __mc_prob_preds / __mc_prob_preds.sum(dim=2, keepdim=True) + _multiclass_prob_inputs = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES), + preds=__mc_prob_preds, target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)) ) @@ -64,9 +67,11 @@ target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)) ) +__mdmc_prob_preds = torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM) +__mdmc_prob_preds = __mdmc_prob_preds / __mdmc_prob_preds.sum(dim=2, keepdim=True) _multidim_multiclass_prob_inputs = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM), + preds=__mdmc_prob_preds, target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)) ) diff --git a/tests/metrics/classification/test_inputs.py b/tests/metrics/classification/test_inputs.py index 058ec66c10ed6..8fe3c01fe77c3 100644 --- a/tests/metrics/classification/test_inputs.py +++ b/tests/metrics/classification/test_inputs.py @@ -22,14 +22,20 @@ torch.manual_seed(42) # Some additional inputs to test on -_mc_prob_2cls = Input(rand(NUM_BATCHES, BATCH_SIZE, 2), randint(high=2, size=(NUM_BATCHES, BATCH_SIZE))) +_mc_prob_2cls_preds = rand(NUM_BATCHES, BATCH_SIZE, 2) +_mc_prob_2cls_preds /= _mc_prob_2cls_preds.sum(dim=2, keepdim=True) +_mc_prob_2cls = Input(_mc_prob_2cls_preds, randint(high=2, size=(NUM_BATCHES, BATCH_SIZE))) + +_mdmc_prob_many_dims_preds = rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM, EXTRA_DIM) +_mdmc_prob_many_dims_preds /= _mdmc_prob_many_dims_preds.sum(dim=2, keepdim=True) _mdmc_prob_many_dims = Input( - rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM, EXTRA_DIM), + _mdmc_prob_many_dims_preds, randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM, EXTRA_DIM)), ) -_mdmc_prob_2cls = Input( - rand(NUM_BATCHES, BATCH_SIZE, 2, EXTRA_DIM), randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)) -) + +_mdmc_prob_2cls_preds = rand(NUM_BATCHES, BATCH_SIZE, 2, EXTRA_DIM) +_mdmc_prob_2cls_preds /= _mdmc_prob_2cls_preds.sum(dim=2, keepdim=True) +_mdmc_prob_2cls = Input(_mdmc_prob_2cls_preds, randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM))) # Some utils T = torch.Tensor @@ -219,6 +225,8 @@ def test_threshold(): (randint(high=2, size=(7, 3, 3, 4)), randint(high=4, size=(7, 3, 3)), 0.5, None, None, None), # is_multiclass=False, with C dimension > 2 (rand(size=(7, 3, 5)), randint(high=2, size=(7, 5)), 0.5, None, False, None), + # Probs of multiclass preds do not sum up to 1 + (rand(size=(7, 3, 5)), randint(high=2, size=(7, 5)), 0.5, None, None, None), # Max target larger or equal to C dimension (rand(size=(7, 3)), randint(low=4, high=6, size=(7,)), 0.5, None, None, None), # C dimension not equal to num_classes From 768879db9b13c911b39f9f6ddc9e2f208d2279a7 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 1 Dec 2020 00:21:31 +0100 Subject: [PATCH 36/38] Fix coverage --- tests/metrics/classification/test_inputs.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/tests/metrics/classification/test_inputs.py b/tests/metrics/classification/test_inputs.py index 8fe3c01fe77c3..6b5a03fcf1ea6 100644 --- a/tests/metrics/classification/test_inputs.py +++ b/tests/metrics/classification/test_inputs.py @@ -224,15 +224,22 @@ def test_threshold(): # #dims in preds = 1 + #dims in target, preds not float (randint(high=2, size=(7, 3, 3, 4)), randint(high=4, size=(7, 3, 3)), 0.5, None, None, None), # is_multiclass=False, with C dimension > 2 - (rand(size=(7, 3, 5)), randint(high=2, size=(7, 5)), 0.5, None, False, None), + (_mc_prob.preds[0], randint(high=2, size=(BATCH_SIZE,)), 0.5, None, False, None), # Probs of multiclass preds do not sum up to 1 (rand(size=(7, 3, 5)), randint(high=2, size=(7, 5)), 0.5, None, None, None), # Max target larger or equal to C dimension - (rand(size=(7, 3)), randint(low=4, high=6, size=(7,)), 0.5, None, None, None), + (_mc_prob.preds[0], randint(low=NUM_CLASSES + 1, high=100, size=(BATCH_SIZE,)), 0.5, None, None, None), # C dimension not equal to num_classes - (rand(size=(7, 4, 3)), randint(high=4, size=(7, 3)), 0.5, 7, None, None), + (_mc_prob.preds[0], _mc_prob.target[0], 0.5, NUM_CLASSES + 1, None, None), # Max target larger than num_classes (with #dim preds = 1 + #dims target) - (rand(size=(7, 3, 4)), randint(low=5, high=7, size=(7, 3)), 0.5, 4, None, None), + ( + _mc_prob.preds[0], + randint(low=NUM_CLASSES + 1, high=100, size=(BATCH_SIZE, NUM_CLASSES)), + 0.5, + 4, + None, + None, + ), # Max target larger than num_classes (with #dim preds = #dims target) (randint(high=4, size=(7, 3)), randint(low=5, high=7, size=(7, 3)), 0.5, 4, None, None), # Max preds larger than num_classes (with #dim preds = #dims target) From 96d40c87d7de9d30af8a130328efdaac002bedff Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Sun, 6 Dec 2020 12:27:41 +0100 Subject: [PATCH 37/38] Minor changes --- docs/source/metrics.rst | 6 ++++++ .../metrics/classification/{utils.py => helpers.py} | 0 tests/metrics/classification/test_inputs.py | 2 +- 3 files changed, 7 insertions(+), 1 deletion(-) rename pytorch_lightning/metrics/classification/{utils.py => helpers.py} (100%) diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 6b9dd8307a457..ee141dc74a679 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -424,6 +424,12 @@ recall [func] .. autofunction:: pytorch_lightning.metrics.functional.classification.recall :noindex: +select_topk [func] +~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: pytorch_lightning.metrics.utils.select_topk + :noindex: + stat_scores [func] ~~~~~~~~~~~~~~~~~~ diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/helpers.py similarity index 100% rename from pytorch_lightning/metrics/classification/utils.py rename to pytorch_lightning/metrics/classification/helpers.py diff --git a/tests/metrics/classification/test_inputs.py b/tests/metrics/classification/test_inputs.py index 6b5a03fcf1ea6..8ad3dd99240c1 100644 --- a/tests/metrics/classification/test_inputs.py +++ b/tests/metrics/classification/test_inputs.py @@ -3,7 +3,7 @@ from torch import randint, rand from pytorch_lightning.metrics.utils import to_onehot, select_topk -from pytorch_lightning.metrics.classification.utils import _input_format_classification +from pytorch_lightning.metrics.classification.helpers import _input_format_classification from tests.metrics.classification.inputs import ( Input, _binary_inputs as _bin, From f3c47f980dbf00c5014c45a26e7a39233b039b03 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Sun, 6 Dec 2020 13:47:44 +0100 Subject: [PATCH 38/38] Fix edge case and simplify testing --- .../metrics/classification/helpers.py | 4 +- tests/metrics/classification/test_inputs.py | 159 +++++++++--------- 2 files changed, 86 insertions(+), 77 deletions(-) diff --git a/pytorch_lightning/metrics/classification/helpers.py b/pytorch_lightning/metrics/classification/helpers.py index 7cfc505e6c673..afb97e6e0a74f 100644 --- a/pytorch_lightning/metrics/classification/helpers.py +++ b/pytorch_lightning/metrics/classification/helpers.py @@ -425,9 +425,9 @@ def _input_format_classification( preds = select_topk(preds, top_k) else: num_classes = num_classes if num_classes else max(preds.max(), target.max()) + 1 - preds = to_onehot(preds, num_classes) + preds = to_onehot(preds, max(2,num_classes)) - target = to_onehot(target, num_classes) + target = to_onehot(target, max(2,num_classes)) if is_multiclass is False: preds, target = preds[:, 1, ...], target[:, 1, ...] diff --git a/tests/metrics/classification/test_inputs.py b/tests/metrics/classification/test_inputs.py index 8ad3dd99240c1..c4d01d282fa57 100644 --- a/tests/metrics/classification/test_inputs.py +++ b/tests/metrics/classification/test_inputs.py @@ -112,49 +112,50 @@ def _mlmd_prob_to_mc_preds_tr(x): @pytest.mark.parametrize( - "inputs, threshold, num_classes, is_multiclass, top_k, exp_mode, post_preds, post_target", + "inputs, num_classes, is_multiclass, top_k, exp_mode, post_preds, post_target", [ ############################# # Test usual expected cases - (_bin, THRESHOLD, None, False, None, "multi-class", _usq, _usq), - (_bin_prob, THRESHOLD, None, None, None, "binary", lambda x: _usq(_thrs(x)), _usq), - (_ml_prob, THRESHOLD, None, None, None, "multi-label", _thrs, _idn), - (_ml, THRESHOLD, None, False, None, "multi-dim multi-class", _idn, _idn), - (_ml_prob, THRESHOLD, None, None, None, "multi-label", _ml_preds_tr, _rshp1), - (_mlmd, THRESHOLD, None, False, None, "multi-dim multi-class", _rshp1, _rshp1), - (_mc, THRESHOLD, NUM_CLASSES, None, None, "multi-class", _onehot, _onehot), - (_mc_prob, THRESHOLD, None, None, None, "multi-class", _top1, _onehot), - (_mc_prob, THRESHOLD, None, None, 2, "multi-class", _top2, _onehot), - (_mdmc, THRESHOLD, NUM_CLASSES, None, None, "multi-dim multi-class", _onehot, _onehot), - (_mdmc_prob, THRESHOLD, None, None, None, "multi-dim multi-class", _top1_rshp2, _onehot), - (_mdmc_prob, THRESHOLD, None, None, 2, "multi-dim multi-class", _top2_rshp2, _onehot), - (_mdmc_prob_many_dims, THRESHOLD, None, None, None, "multi-dim multi-class", _top1_rshp2, _onehot_rshp1), - (_mdmc_prob_many_dims, THRESHOLD, None, None, 2, "multi-dim multi-class", _top2_rshp2, _onehot_rshp1), + (_bin, None, False, None, "multi-class", _usq, _usq), + (_bin, 1, False, None, "multi-class", _usq, _usq), + (_bin_prob, None, None, None, "binary", lambda x: _usq(_thrs(x)), _usq), + (_ml_prob, None, None, None, "multi-label", _thrs, _idn), + (_ml, None, False, None, "multi-dim multi-class", _idn, _idn), + (_ml_prob, None, None, None, "multi-label", _ml_preds_tr, _rshp1), + (_mlmd, None, False, None, "multi-dim multi-class", _rshp1, _rshp1), + (_mc, NUM_CLASSES, None, None, "multi-class", _onehot, _onehot), + (_mc_prob, None, None, None, "multi-class", _top1, _onehot), + (_mc_prob, None, None, 2, "multi-class", _top2, _onehot), + (_mdmc, NUM_CLASSES, None, None, "multi-dim multi-class", _onehot, _onehot), + (_mdmc_prob, None, None, None, "multi-dim multi-class", _top1_rshp2, _onehot), + (_mdmc_prob, None, None, 2, "multi-dim multi-class", _top2_rshp2, _onehot), + (_mdmc_prob_many_dims, None, None, None, "multi-dim multi-class", _top1_rshp2, _onehot_rshp1), + (_mdmc_prob_many_dims, None, None, 2, "multi-dim multi-class", _top2_rshp2, _onehot_rshp1), ########################### # Test some special cases # Binary as multiclass - (_bin, THRESHOLD, None, None, None, "multi-class", _onehot2, _onehot2), + (_bin, None, None, None, "multi-class", _onehot2, _onehot2), # Binary probs as multiclass - (_bin_prob, THRESHOLD, None, True, None, "binary", _probs_to_mc_preds_tr, _onehot2), + (_bin_prob, None, True, None, "binary", _probs_to_mc_preds_tr, _onehot2), # Multilabel as multiclass - (_ml, THRESHOLD, None, True, None, "multi-dim multi-class", _onehot2, _onehot2), + (_ml, None, True, None, "multi-dim multi-class", _onehot2, _onehot2), # Multilabel probs as multiclass - (_ml_prob, THRESHOLD, None, True, None, "multi-label", _probs_to_mc_preds_tr, _onehot2), + (_ml_prob, None, True, None, "multi-label", _probs_to_mc_preds_tr, _onehot2), # Multidim multilabel as multiclass - (_mlmd, THRESHOLD, None, True, None, "multi-dim multi-class", _onehot2_rshp1, _onehot2_rshp1), + (_mlmd, None, True, None, "multi-dim multi-class", _onehot2_rshp1, _onehot2_rshp1), # Multidim multilabel probs as multiclass - (_mlmd_prob, THRESHOLD, None, True, None, "multi-label", _mlmd_prob_to_mc_preds_tr, _onehot2_rshp1), + (_mlmd_prob, None, True, None, "multi-label", _mlmd_prob_to_mc_preds_tr, _onehot2_rshp1), # Multiclass prob with 2 classes as binary - (_mc_prob_2cls, THRESHOLD, None, False, None, "multi-class", lambda x: _top1(x)[:, [1]], _usq), + (_mc_prob_2cls, None, False, None, "multi-class", lambda x: _top1(x)[:, [1]], _usq), # Multi-dim multi-class with 2 classes as multi-label - (_mdmc_prob_2cls, THRESHOLD, None, False, None, "multi-dim multi-class", lambda x: _top1(x)[:, 1], _idn), + (_mdmc_prob_2cls, None, False, None, "multi-dim multi-class", lambda x: _top1(x)[:, 1], _idn), ], ) -def test_usual_cases(inputs, threshold, num_classes, is_multiclass, top_k, exp_mode, post_preds, post_target): +def test_usual_cases(inputs, num_classes, is_multiclass, top_k, exp_mode, post_preds, post_target): preds_out, target_out, mode = _input_format_classification( preds=inputs.preds[0], target=inputs.target[0], - threshold=threshold, + threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass, top_k=top_k, @@ -168,7 +169,7 @@ def test_usual_cases(inputs, threshold, num_classes, is_multiclass, top_k, exp_m preds_out, target_out, mode = _input_format_classification( preds=inputs.preds[0][[0], ...], target=inputs.target[0][[0], ...], - threshold=threshold, + threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass, top_k=top_k, @@ -194,92 +195,100 @@ def test_threshold(): ######################################################################## +def test_incorrect_threshold(): + with pytest.raises(ValueError): + _input_format_classification(preds=rand(size=(7,)), target=randint(high=2, size=(7,)), threshold=1.5) + + @pytest.mark.parametrize( - "preds, target, threshold, num_classes, is_multiclass, top_k", + "preds, target, num_classes, is_multiclass", [ # Target not integer - (randint(high=2, size=(7,)), randint(high=2, size=(7,)).float(), 0.5, None, None, None), + (randint(high=2, size=(7,)), randint(high=2, size=(7,)).float(), None, None), # Target negative - (randint(high=2, size=(7,)), -randint(high=2, size=(7,)), 0.5, None, None, None), + (randint(high=2, size=(7,)), -randint(high=2, size=(7,)), None, None), # Preds negative integers - (-randint(high=2, size=(7,)), randint(high=2, size=(7,)), 0.5, None, None, None), + (-randint(high=2, size=(7,)), randint(high=2, size=(7,)), None, None), # Negative probabilities - (-rand(size=(7,)), randint(high=2, size=(7,)), 0.5, None, None, None), - # Threshold outside of [0,1] - (rand(size=(7,)), randint(high=2, size=(7,)), 1.5, None, None, None), + (-rand(size=(7,)), randint(high=2, size=(7,)), None, None), # is_multiclass=False and target > 1 - (rand(size=(7,)), randint(low=2, high=4, size=(7,)), 0.5, None, False, None), + (rand(size=(7,)), randint(low=2, high=4, size=(7,)), None, False), # is_multiclass=False and preds integers with > 1 - (randint(low=2, high=4, size=(7,)), randint(high=2, size=(7,)), 0.5, None, False, None), + (randint(low=2, high=4, size=(7,)), randint(high=2, size=(7,)), None, False), # Wrong batch size - (randint(high=2, size=(8,)), randint(high=2, size=(7,)), 0.5, None, None, None), + (randint(high=2, size=(8,)), randint(high=2, size=(7,)), None, None), # Completely wrong shape - (randint(high=2, size=(7,)), randint(high=2, size=(7, 4)), 0.5, None, None, None), + (randint(high=2, size=(7,)), randint(high=2, size=(7, 4)), None, None), # Same #dims, different shape - (randint(high=2, size=(7, 3)), randint(high=2, size=(7, 4)), 0.5, None, None, None), + (randint(high=2, size=(7, 3)), randint(high=2, size=(7, 4)), None, None), # Same shape and preds floats, target not binary - (rand(size=(7, 3)), randint(low=2, high=4, size=(7, 3)), 0.5, None, None, None), + (rand(size=(7, 3)), randint(low=2, high=4, size=(7, 3)), None, None), # #dims in preds = 1 + #dims in target, C shape not second or last - (rand(size=(7, 3, 4, 3)), randint(high=4, size=(7, 3, 3)), 0.5, None, None, None), + (rand(size=(7, 3, 4, 3)), randint(high=4, size=(7, 3, 3)), None, None), # #dims in preds = 1 + #dims in target, preds not float - (randint(high=2, size=(7, 3, 3, 4)), randint(high=4, size=(7, 3, 3)), 0.5, None, None, None), + (randint(high=2, size=(7, 3, 3, 4)), randint(high=4, size=(7, 3, 3)), None, None), # is_multiclass=False, with C dimension > 2 - (_mc_prob.preds[0], randint(high=2, size=(BATCH_SIZE,)), 0.5, None, False, None), + (_mc_prob.preds[0], randint(high=2, size=(BATCH_SIZE,)), None, False), # Probs of multiclass preds do not sum up to 1 - (rand(size=(7, 3, 5)), randint(high=2, size=(7, 5)), 0.5, None, None, None), + (rand(size=(7, 3, 5)), randint(high=2, size=(7, 5)), None, None), # Max target larger or equal to C dimension - (_mc_prob.preds[0], randint(low=NUM_CLASSES + 1, high=100, size=(BATCH_SIZE,)), 0.5, None, None, None), + (_mc_prob.preds[0], randint(low=NUM_CLASSES + 1, high=100, size=(BATCH_SIZE,)), None, None), # C dimension not equal to num_classes - (_mc_prob.preds[0], _mc_prob.target[0], 0.5, NUM_CLASSES + 1, None, None), + (_mc_prob.preds[0], _mc_prob.target[0], NUM_CLASSES + 1, None), # Max target larger than num_classes (with #dim preds = 1 + #dims target) - ( - _mc_prob.preds[0], - randint(low=NUM_CLASSES + 1, high=100, size=(BATCH_SIZE, NUM_CLASSES)), - 0.5, - 4, - None, - None, - ), + (_mc_prob.preds[0], randint(low=NUM_CLASSES + 1, high=100, size=(BATCH_SIZE, NUM_CLASSES)), 4, None), # Max target larger than num_classes (with #dim preds = #dims target) - (randint(high=4, size=(7, 3)), randint(low=5, high=7, size=(7, 3)), 0.5, 4, None, None), + (randint(high=4, size=(7, 3)), randint(low=5, high=7, size=(7, 3)), 4, None), # Max preds larger than num_classes (with #dim preds = #dims target) - (randint(low=5, high=7, size=(7, 3)), randint(high=4, size=(7, 3)), 0.5, 4, None, None), + (randint(low=5, high=7, size=(7, 3)), randint(high=4, size=(7, 3)), 4, None), # Num_classes=1, but is_multiclass not false - (randint(high=2, size=(7,)), randint(high=2, size=(7,)), 0.5, 1, None, 1), + (randint(high=2, size=(7,)), randint(high=2, size=(7,)), 1, None), # is_multiclass=False, but implied class dimension (for multi-label, from shape) != num_classes - (randint(high=2, size=(7, 3, 3)), randint(high=2, size=(7, 3, 3)), 0.5, 4, False, None), + (randint(high=2, size=(7, 3, 3)), randint(high=2, size=(7, 3, 3)), 4, False), # Multilabel input with implied class dimension != num_classes - (rand(size=(7, 3, 3)), randint(high=2, size=(7, 3, 3)), 0.5, 4, False, None), + (rand(size=(7, 3, 3)), randint(high=2, size=(7, 3, 3)), 4, False), # Multilabel input with is_multiclass=True, but num_classes != 2 (or None) - (rand(size=(7, 3)), randint(high=2, size=(7, 3)), 0.5, 4, True, None), + (rand(size=(7, 3)), randint(high=2, size=(7, 3)), 4, True), # Binary input, num_classes > 2 - (rand(size=(7,)), randint(high=2, size=(7,)), 0.5, 4, None, None), + (rand(size=(7,)), randint(high=2, size=(7,)), 4, None), # Binary input, num_classes == 2 and is_multiclass not True - (rand(size=(7,)), randint(high=2, size=(7,)), 0.5, 2, None, None), - (rand(size=(7,)), randint(high=2, size=(7,)), 0.5, 2, False, None), + (rand(size=(7,)), randint(high=2, size=(7,)), 2, None), + (rand(size=(7,)), randint(high=2, size=(7,)), 2, False), # Binary input, num_classes == 1 and is_multiclass=True - (rand(size=(7,)), randint(high=2, size=(7,)), 0.5, 1, True, None), + (rand(size=(7,)), randint(high=2, size=(7,)), 1, True), + ], +) +def test_incorrect_inputs(preds, target, num_classes, is_multiclass): + with pytest.raises(ValueError): + _input_format_classification( + preds=preds, target=target, threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass + ) + + +@pytest.mark.parametrize( + "preds, target, num_classes, is_multiclass, top_k", + [ # Topk set with non (md)mc prob data - (_bin.preds[0], _bin.target[0], 0.5, None, None, 2), - (_bin_prob.preds[0], _bin_prob.target[0], 0.5, None, None, 2), - (_mc.preds[0], _mc.target[0], 0.5, None, None, 2), - (_ml.preds[0], _ml.target[0], 0.5, None, None, 2), - (_mlmd.preds[0], _mlmd.target[0], 0.5, None, None, 2), - (_ml_prob.preds[0], _ml_prob.target[0], 0.5, None, None, 2), - (_mlmd_prob.preds[0], _mlmd_prob.target[0], 0.5, None, None, 2), - (_mdmc.preds[0], _mdmc.target[0], 0.5, None, None, 2), + (_bin.preds[0], _bin.target[0], None, None, 2), + (_bin_prob.preds[0], _bin_prob.target[0], None, None, 2), + (_mc.preds[0], _mc.target[0], None, None, 2), + (_ml.preds[0], _ml.target[0], None, None, 2), + (_mlmd.preds[0], _mlmd.target[0], None, None, 2), + (_ml_prob.preds[0], _ml_prob.target[0], None, None, 2), + (_mlmd_prob.preds[0], _mlmd_prob.target[0], None, None, 2), + (_mdmc.preds[0], _mdmc.target[0], None, None, 2), # top_k =2 with 2 classes, is_multiclass=False - (_mc_prob_2cls.preds[0], _mc_prob_2cls.target[0], 0.5, None, False, 2), + (_mc_prob_2cls.preds[0], _mc_prob_2cls.target[0], None, False, 2), # top_k = number of classes (C dimension) - (_mc_prob.preds[0], _mc_prob.target[0], 0.5, None, None, NUM_CLASSES), + (_mc_prob.preds[0], _mc_prob.target[0], None, None, NUM_CLASSES), ], ) -def test_incorrect_inputs(preds, target, threshold, num_classes, is_multiclass, top_k): +def test_incorrect_inputs_topk(preds, target, num_classes, is_multiclass, top_k): with pytest.raises(ValueError): _input_format_classification( preds=preds, target=target, - threshold=threshold, + threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass, top_k=top_k,