diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 6cc5a9387a8ef..ee141dc74a679 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -196,53 +196,76 @@ Metric API .. autoclass:: pytorch_lightning.metrics.Metric :noindex: -************* -Class metrics -************* +*************************** +Class vs Functional Metrics +*************************** +The functional metrics follow the simple paradigm input in, output out. This means, they don't provide any advanced mechanisms for syncing across DDP nodes or aggregation over batches. They simply compute the metric value based on the given inputs. + +Also, the integration within other parts of PyTorch Lightning will never be as tight as with the class-based interface. +If you look for just computing the values, the functional metrics are the way to go. However, if you are looking for the best integration and user experience, please consider also using the class interface. + +********************** Classification Metrics ----------------------- +********************** -Accuracy -~~~~~~~~ +Input types +----------- -.. autoclass:: pytorch_lightning.metrics.classification.Accuracy - :noindex: +For the purposes of classification metrics, inputs (predictions and targets) are split +into these categories (``N`` stands for the batch size and ``C`` for number of classes): -Precision -~~~~~~~~~ +.. csv-table:: \*dtype ``binary`` means integers that are either 0 or 1 + :header: "Type", "preds shape", "preds dtype", "target shape", "target dtype" + :widths: 20, 10, 10, 10, 10 -.. autoclass:: pytorch_lightning.metrics.classification.Precision - :noindex: + "Binary", "(N,)", "``float``", "(N,)", "``binary``\*" + "Multi-class", "(N,)", "``int``", "(N,)", "``int``" + "Multi-class with probabilities", "(N, C)", "``float``", "(N,)", "``int``" + "Multi-label", "(N, ...)", "``float``", "(N, ...)", "``binary``\*" + "Multi-dimensional multi-class", "(N, ...)", "``int``", "(N, ...)", "``int``" + "Multi-dimensional multi-class with probabilities", "(N, C, ...)", "``float``", "(N, ...)", "``int``" -Recall -~~~~~~ +.. note:: + All dimensions of size 1 (except ``N``) are "squeezed out" at the beginning, so + that, for example, a tensor of shape ``(N, 1)`` is treated as ``(N, )``. -.. autoclass:: pytorch_lightning.metrics.classification.Recall - :noindex: +When predictions or targets are integers, it is assumed that class labels start at 0, i.e. +the possible class labels are 0, 1, 2, 3, etc. Below are some examples of different input types -FBeta -~~~~~ +.. testcode:: -.. autoclass:: pytorch_lightning.metrics.classification.FBeta - :noindex: + # Binary inputs + binary_preds = torch.tensor([0.6, 0.1, 0.9]) + binary_target = torch.tensor([1, 0, 2]) -F1 -~~ + # Multi-class inputs + mc_preds = torch.tensor([0, 2, 1]) + mc_target = torch.tensor([0, 1, 2]) -.. autoclass:: pytorch_lightning.metrics.classification.F1 - :noindex: + # Multi-class inputs with probabilities + mc_preds_probs = torch.tensor([[0.8, 0.2, 0], [0.1, 0.2, 0.7], [0.3, 0.6, 0.1]]) + mc_target_probs = torch.tensor([0, 1, 2]) -ConfusionMatrix -~~~~~~~~~~~~~~~ + # Multi-label inputs + ml_preds = torch.tensor([[0.2, 0.8, 0.9], [0.5, 0.6, 0.1], [0.3, 0.1, 0.1]]) + ml_target = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 0]]) -.. autoclass:: pytorch_lightning.metrics.classification.ConfusionMatrix - :noindex: +In some rare cases, you might have inputs which appear to be (multi-dimensional) multi-class +but are actually binary/multi-label. For example, if both predictions and targets are 1d +binary tensors. Or it could be the other way around, you want to treat binary/multi-label +inputs as 2-class (multi-dimensional) multi-class inputs. -PrecisionRecallCurve -~~~~~~~~~~~~~~~~~~~~ +For these cases, the metrics where this distinction would make a difference, expose the +``is_multiclass`` argument. -.. autoclass:: pytorch_lightning.metrics.classification.PrecisionRecallCurve +Class Metrics (Classification) +------------------------------ + +Accuracy +~~~~~~~~ + +.. autoclass:: pytorch_lightning.metrics.classification.Accuracy :noindex: AveragePrecision @@ -251,67 +274,51 @@ AveragePrecision .. autoclass:: pytorch_lightning.metrics.classification.AveragePrecision :noindex: -ROC -~~~ +ConfusionMatrix +~~~~~~~~~~~~~~~ -.. autoclass:: pytorch_lightning.metrics.classification.ROC +.. autoclass:: pytorch_lightning.metrics.classification.ConfusionMatrix :noindex: -Regression Metrics ------------------- - -MeanSquaredError -~~~~~~~~~~~~~~~~ +F1 +~~ -.. autoclass:: pytorch_lightning.metrics.regression.MeanSquaredError +.. autoclass:: pytorch_lightning.metrics.classification.F1 :noindex: +FBeta +~~~~~ -MeanAbsoluteError -~~~~~~~~~~~~~~~~~ - -.. autoclass:: pytorch_lightning.metrics.regression.MeanAbsoluteError +.. autoclass:: pytorch_lightning.metrics.classification.FBeta :noindex: +Precision +~~~~~~~~~ -MeanSquaredLogError -~~~~~~~~~~~~~~~~~~~ - -.. autoclass:: pytorch_lightning.metrics.regression.MeanSquaredLogError +.. autoclass:: pytorch_lightning.metrics.classification.Precision :noindex: +PrecisionRecallCurve +~~~~~~~~~~~~~~~~~~~~ -ExplainedVariance -~~~~~~~~~~~~~~~~~ - -.. autoclass:: pytorch_lightning.metrics.regression.ExplainedVariance +.. autoclass:: pytorch_lightning.metrics.classification.PrecisionRecallCurve :noindex: +Recall +~~~~~~ -PSNR -~~~~ - -.. autoclass:: pytorch_lightning.metrics.regression.PSNR +.. autoclass:: pytorch_lightning.metrics.classification.Recall :noindex: +ROC +~~~ -SSIM -~~~~ - -.. autoclass:: pytorch_lightning.metrics.regression.SSIM +.. autoclass:: pytorch_lightning.metrics.classification.ROC :noindex: -****************** -Functional Metrics -****************** - -The functional metrics follow the simple paradigm input in, output out. This means, they don't provide any advanced mechanisms for syncing across DDP nodes or aggregation over batches. They simply compute the metric value based on the given inputs. - -Also the integration within other parts of PyTorch Lightning will never be as tight as with the class-based interface. -If you look for just computing the values, the functional metrics are the way to go. However, if you are looking for the best integration and user experience, please consider also to use the class interface. -Classification --------------- +Functional Metrics (Classification) +----------------------------------- accuracy [func] ~~~~~~~~~~~~~~~ @@ -417,6 +424,12 @@ recall [func] .. autofunction:: pytorch_lightning.metrics.functional.classification.recall :noindex: +select_topk [func] +~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: pytorch_lightning.metrics.utils.select_topk + :noindex: + stat_scores [func] ~~~~~~~~~~~~~~~~~~ @@ -445,9 +458,57 @@ to_onehot [func] .. autofunction:: pytorch_lightning.metrics.utils.to_onehot :noindex: +****************** +Regression Metrics +****************** + +Class Metrics (Regression) +-------------------------- -Regression ----------- +ExplainedVariance +~~~~~~~~~~~~~~~~~ + +.. autoclass:: pytorch_lightning.metrics.regression.ExplainedVariance + :noindex: + + +MeanAbsoluteError +~~~~~~~~~~~~~~~~~ + +.. autoclass:: pytorch_lightning.metrics.regression.MeanAbsoluteError + :noindex: + + +MeanSquaredError +~~~~~~~~~~~~~~~~ + +.. autoclass:: pytorch_lightning.metrics.regression.MeanSquaredError + :noindex: + + +MeanSquaredLogError +~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: pytorch_lightning.metrics.regression.MeanSquaredLogError + :noindex: + + +PSNR +~~~~ + +.. autoclass:: pytorch_lightning.metrics.regression.PSNR + :noindex: + + +SSIM +~~~~ + +.. autoclass:: pytorch_lightning.metrics.regression.SSIM + :noindex: + + +Functional Metrics (Regression) +------------------------------- explained_variance [func] ~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -470,17 +531,17 @@ mean_squared_error [func] :noindex: -psnr [func] -~~~~~~~~~~~ +mean_squared_log_error [func] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autofunction:: pytorch_lightning.metrics.functional.psnr +.. autofunction:: pytorch_lightning.metrics.functional.mean_squared_log_error :noindex: -mean_squared_log_error [func] -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +psnr [func] +~~~~~~~~~~~ -.. autofunction:: pytorch_lightning.metrics.functional.mean_squared_log_error +.. autofunction:: pytorch_lightning.metrics.functional.psnr :noindex: @@ -490,22 +551,22 @@ ssim [func] .. autofunction:: pytorch_lightning.metrics.functional.ssim :noindex: - +*** NLP ---- +*** bleu_score [func] -~~~~~~~~~~~~~~~~~ +----------------- .. autofunction:: pytorch_lightning.metrics.functional.nlp.bleu_score :noindex: - +******** Pairwise --------- +******** embedding_similarity [func] -~~~~~~~~~~~~~~~~~~~~~~~~~~~ +--------------------------- .. autofunction:: pytorch_lightning.metrics.functional.self_supervised.embedding_similarity :noindex: diff --git a/pytorch_lightning/metrics/classification/helpers.py b/pytorch_lightning/metrics/classification/helpers.py new file mode 100644 index 0000000000000..afb97e6e0a74f --- /dev/null +++ b/pytorch_lightning/metrics/classification/helpers.py @@ -0,0 +1,446 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Tuple, Optional + +import torch + +from pytorch_lightning.metrics.utils import to_onehot, select_topk + + +def _basic_input_validation(preds: torch.Tensor, target: torch.Tensor, threshold: float, is_multiclass: bool): + """ + Perform basic validation of inputs that does not require deducing any information + of the type of inputs. + """ + + if target.is_floating_point(): + raise ValueError("The `target` has to be an integer tensor.") + if target.min() < 0: + raise ValueError("The `target` has to be a non-negative tensor.") + + preds_float = preds.is_floating_point() + if not preds_float and preds.min() < 0: + raise ValueError("If `preds` are integers, they have to be non-negative.") + + if not preds.shape[0] == target.shape[0]: + raise ValueError("The `preds` and `target` should have the same first dimension.") + + if preds_float and (preds.min() < 0 or preds.max() > 1): + raise ValueError("The `preds` should be probabilities, but values were detected outside of [0,1] range.") + + if threshold > 1 or threshold < 0: + raise ValueError("The `threshold` should be a probability in [0,1].") + + if is_multiclass is False and target.max() > 1: + raise ValueError("If you set `is_multiclass=False`, then `target` should not exceed 1.") + + if is_multiclass is False and not preds_float and preds.max() > 1: + raise ValueError("If you set `is_multiclass=False` and `preds` are integers, then `preds` should not exceed 1.") + + +def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor) -> Tuple[str, int]: + """ + This checks that the shape and type of inputs are consistent with + each other and fall into one of the allowed input types (see the + documentation of docstring of ``_input_format_classification``). It does + not check for consistency of number of classes, other functions take + care of that. + + It returns the name of the case in which the inputs fall, and the implied + number of classes (from the C dim for multi-class data, or extra dim(s) for + multi-label data). + """ + + preds_float = preds.is_floating_point() + + if preds.ndim == target.ndim: + if preds.shape != target.shape: + raise ValueError( + "The `preds` and `target` should have the same shape,", + f" got `preds` with shape={preds.shape} and `target` with shape={target.shape}.", + ) + if preds_float and target.max() > 1: + raise ValueError( + "If `preds` and `target` are of shape (N, ...) and `preds` are floats, `target` should be binary." + ) + + # Get the case + if preds.ndim == 1 and preds_float: + case = "binary" + elif preds.ndim == 1 and not preds_float: + case = "multi-class" + elif preds.ndim > 1 and preds_float: + case = "multi-label" + else: + case = "multi-dim multi-class" + + implied_classes = preds[0].numel() + + elif preds.ndim == target.ndim + 1: + if not preds_float: + raise ValueError("If `preds` have one dimension more than `target`, `preds` should be a float tensor.") + if preds.shape[2:] != target.shape[1:]: + raise ValueError( + "If `preds` have one dimension more than `target`, the shape of `preds` should be" + " (N, C, ...), and the shape of `target` should be (N, ...)." + ) + + implied_classes = preds.shape[1] + + if preds.ndim == 2: + case = "multi-class" + else: + case = "multi-dim multi-class" + else: + raise ValueError( + "Either `preds` and `target` both should have the (same) shape (N, ...), or `target` should be (N, ...)" + " and `preds` should be (N, C, ...)." + ) + + return case, implied_classes + + +def _check_num_classes_binary(num_classes: int, is_multiclass: bool): + """ + This checks that the consistency of `num_classes` with the data + and `is_multiclass` param for binary data. + """ + + if num_classes > 2: + raise ValueError("Your data is binary, but `num_classes` is larger than 2.") + if num_classes == 2 and not is_multiclass: + raise ValueError( + "Your data is binary and `num_classes=2`, but `is_multiclass` is not True." + " Set it to True if you want to transform binary data to multi-class format." + ) + if num_classes == 1 and is_multiclass: + raise ValueError( + "You have binary data and have set `is_multiclass=True`, but `num_classes` is 1." + " Either set `is_multiclass=None`(default) or set `num_classes=2` to transform binary data to multi-class format." + ) + + +def _check_num_classes_mc( + preds: torch.Tensor, target: torch.Tensor, num_classes: int, is_multiclass: bool, implied_classes: int +): + """ + This checks that the consistency of `num_classes` with the data + and `is_multiclass` param for (multi-dimensional) multi-class data. + """ + + if num_classes == 1 and is_multiclass is not False: + raise ValueError( + "You have set `num_classes=1`, but predictions are integers." + " If you want to convert (multi-dimensional) multi-class data with 2 classes" + " to binary/multi-label, set `is_multiclass=False`." + ) + if num_classes > 1: + if is_multiclass is False: + if implied_classes != num_classes: + raise ValueError( + "You have set `is_multiclass=False`, but the implied number of classes " + " (from shape of inputs) does not match `num_classes`. If you are trying to" + " transform multi-dim multi-class data with 2 classes to multi-label, `num_classes`" + " should be either None or the product of the size of extra dimensions (...)." + " See Input Types in Metrics documentation." + ) + if num_classes <= target.max(): + raise ValueError("The highest label in `target` should be smaller than `num_classes`.") + if num_classes <= preds.max(): + raise ValueError("The highest label in `preds` should be smaller than `num_classes`.") + if preds.shape != target.shape and num_classes != implied_classes: + raise ValueError("The size of C dimension of `preds` does not match `num_classes`.") + + +def _check_num_classes_ml(num_classes: int, is_multiclass: bool, implied_classes: int): + """ + This checks that the consistency of `num_classes` with the data + and `is_multiclass` param for multi-label data. + """ + + if is_multiclass and num_classes != 2: + raise ValueError( + "Your have set `is_multiclass=True`, but `num_classes` is not equal to 2." + " If you are trying to transform multi-label data to 2 class multi-dimensional" + " multi-class, you should set `num_classes` to either 2 or None." + ) + if not is_multiclass and num_classes != implied_classes: + raise ValueError("The implied number of classes (from shape of inputs) does not match num_classes.") + + +def _check_top_k(top_k: int, case: str, implied_classes: int, is_multiclass: Optional[bool], preds_float: bool): + if "multi-class" not in case or not preds_float: + raise ValueError( + "You have set `top_k` above 1, but your data is not (multi-dimensional) multi-class" + " with probability predictions." + ) + if is_multiclass is False: + raise ValueError("If you set `is_multiclass=False`, you can not set `top_k`.") + if top_k >= implied_classes: + raise ValueError("The `top_k` has to be strictly smaller than the `C` dimension of `preds`.") + + +def _check_classification_inputs( + preds: torch.Tensor, + target: torch.Tensor, + threshold: float, + num_classes: Optional[int], + is_multiclass: bool, + top_k: Optional[int], +) -> str: + """Performs error checking on inputs for classification. + + This ensures that preds and target take one of the shape/type combinations that are + specified in ``_input_format_classification`` docstring. It also checks the cases of + over-rides with ``is_multiclass`` by checking (for multi-class and multi-dim multi-class + cases) that there are only up to 2 distinct labels. + + In case where preds are floats (probabilities), it is checked whether they are in [0,1] interval. + + When ``num_classes`` is given, it is checked that it is consitent with input cases (binary, + multi-label, ...), and that, if availible, the implied number of classes in the ``C`` + dimension is consistent with it (as well as that max label in target is smaller than it). + + When ``num_classes`` is not specified in these cases, consistency of the highest target + value against ``C`` dimension is checked for (multi-dimensional) multi-class cases. + + If ``top_k`` is set (not None) for inputs which are not (multi-dimensional) multi class + with probabilities, then an error is raised. Similarly if ``top_k`` is set to a number + that is higher than or equal to the ``C`` dimension of ``preds``. + + Preds and target tensors are expected to be squeezed already - all dimensions should be + greater than 1, except perhaps the first one (N). + + Args: + preds: Tensor with predictions (labels or probabilities) + target: Tensor with ground truth labels, always integers (labels) + threshold: + Threshold probability value for transforming probability predictions to binary + (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + num_classes: + Number of classes. If not explicitly set, the number of classes will be infered + either from the shape of inputs, or the maximum label in the ``target`` and ``preds`` + tensor, where applicable. + top_k: + Number of highest probability entries for each sample to convert to 1s - relevant + only for (multi-dimensional) multi-class inputs with probability predictions. The + default value (``None``) will be interepreted as 1 for these inputs. + + Should be left unset (``None``) for all other types of inputs. + is_multiclass: + Used only in certain special cases, where you want to treat inputs as a different type + than what they appear to be (see :ref:`metrics: Input types` documentation section for + input classification and examples of the use of this parameter). Should be left at default + value (``None``) in most cases. + + The special cases where this parameter should be set are: + + - When you want to treat binary or multi-label inputs as multi-class or multi-dimensional + multi-class with 2 classes, respectively. The probabilities are interpreted as the + probability of the "1" class, and thresholding still applies as usual. In this case + the parameter should be set to ``True``. + - When you want to treat multi-class or multi-dimensional mulit-class inputs with 2 classes + as binary or multi-label inputs, respectively. This is mainly meant for the case when + inputs are labels, but will work if they are probabilities as well. For this case the + parameter should be set to ``False``. + + Return: + case: The case the inputs fall in, one of 'binary', 'multi-class', 'multi-label' or + 'multi-dim multi-class' + """ + + # Baisc validation (that does not need case/type information) + _basic_input_validation(preds, target, threshold, is_multiclass) + + # Check that shape/types fall into one of the cases + case, implied_classes = _check_shape_and_type_consistency(preds, target) + + # For (multi-dim) multi-class case with prob preds, check that preds sum up to 1 + if "multi-class" in case and preds.is_floating_point(): + if not torch.isclose(preds.sum(dim=1), torch.ones_like(preds.sum(dim=1))).all(): + raise ValueError("Probabilities in `preds` must sum up to 1 accross the `C` dimension.") + + # Check consistency with the `C` dimension in case of multi-class data + if preds.shape != target.shape: + if is_multiclass is False and implied_classes != 2: + raise ValueError( + "You have set `is_multiclass=False`, but have more than 2 classes in your data," + " based on the C dimension of `preds`." + ) + if target.max() >= implied_classes: + raise ValueError( + "The highest label in `target` should be smaller than the size of the `C` dimension of `preds`." + ) + + # Check that num_classes is consistent + if num_classes: + if case == "binary": + _check_num_classes_binary(num_classes, is_multiclass) + elif "multi-class" in case: + _check_num_classes_mc(preds, target, num_classes, is_multiclass, implied_classes) + elif case == "multi-label": + _check_num_classes_ml(num_classes, is_multiclass, implied_classes) + + # Check that top_k is consistent + if top_k: + _check_top_k(top_k, case, implied_classes, is_multiclass, preds.is_floating_point()) + + return case + + +def _input_format_classification( + preds: torch.Tensor, + target: torch.Tensor, + threshold: float = 0.5, + top_k: Optional[int] = None, + num_classes: Optional[int] = None, + is_multiclass: Optional[bool] = None, +) -> Tuple[torch.Tensor, torch.Tensor, str]: + """Convert preds and target tensors into common format. + + Preds and targets are supposed to fall into one of these categories (and are + validated to make sure this is the case): + + * Both preds and target are of shape ``(N,)``, and both are integers (multi-class) + * Both preds and target are of shape ``(N,)``, and target is binary, while preds + are a float (binary) + * preds are of shape ``(N, C)`` and are floats, and target is of shape ``(N,)`` and + is integer (multi-class) + * preds and target are of shape ``(N, ...)``, target is binary and preds is a float + (multi-label) + * preds are of shape ``(N, C, ...)`` and are floats, target is of shape ``(N, ...)`` + and is integer (multi-dimensional multi-class) + * preds and target are of shape ``(N, ...)`` both are integers (multi-dimensional + multi-class) + + To avoid ambiguities, all dimensions of size 1, except the first one, are squeezed out. + + The returned output tensors will be binary tensors of the same shape, either ``(N, C)`` + of ``(N, C, X)``, the details for each case are described below. The function also returns + a ``case`` string, which describes which of the above cases the inputs belonged to - regardless + of whether this was "overridden" by other settings (like ``is_multiclass``). + + In binary case, targets are normally returned as ``(N,1)`` tensor, while preds are transformed + into a binary tensor (elements become 1 if the probability is greater than or equal to + ``threshold`` or 0 otherwise). If ``is_multiclass=True``, then then both targets are preds + become ``(N, 2)`` tensors by a one-hot transformation; with the thresholding being applied to + preds first. + + In multi-class case, normally both preds and targets become ``(N, C)`` binary tensors; targets + by a one-hot transformation and preds by selecting ``top_k`` largest entries (if their original + shape was ``(N,C)``). However, if ``is_multiclass=False``, then targets and preds will be + returned as ``(N,1)`` tensor. + + In multi-label case, normally targets and preds are returned as ``(N, C)`` binary tensors, with + preds being binarized as in the binary case. Here the ``C`` dimension is obtained by flattening + all dimensions after the first one. However if ``is_multiclass=True``, then both are returned as + ``(N, 2, C)``, by an equivalent transformation as in the binary case. + + In multi-dimensional multi-class case, normally both target and preds are returned as + ``(N, C, X)`` tensors, with ``X`` resulting from flattening of all dimensions except ``N`` and + ``C``. The transformations performed here are equivalent to the multi-class case. However, if + ``is_multiclass=False`` (and there are up to two classes), then the data is returned as + ``(N, X)`` binary tensors (multi-label). + + Note that where a one-hot transformation needs to be performed and the number of classes + is not implicitly given by a ``C`` dimension, the new ``C`` dimension will either be + equal to ``num_classes``, if it is given, or the maximum label value in preds and + target. + + Args: + preds: Tensor with predictions (labels or probabilities) + target: Tensor with ground truth labels, always integers (labels) + threshold: + Threshold probability value for transforming probability predictions to binary + (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + num_classes: + Number of classes. If not explicitly set, the number of classes will be infered + either from the shape of inputs, or the maximum label in the ``target`` and ``preds`` + tensor, where applicable. + top_k: + Number of highest probability entries for each sample to convert to 1s - relevant + only for (multi-dimensional) multi-class inputs with probability predictions. The + default value (``None``) will be interepreted as 1 for these inputs. + + Should be left unset (``None``) for all other types of inputs. + is_multiclass: + Used only in certain special cases, where you want to treat inputs as a different type + than what they appear to be (see :ref:`metrics: Input types` documentation section for + input classification and examples of the use of this parameter). Should be left at default + value (``None``) in most cases. + + The special cases where this parameter should be set are: + + - When you want to treat binary or multi-label inputs as multi-class or multi-dimensional + multi-class with 2 classes, respectively. The probabilities are interpreted as the + probability of the "1" class, and thresholding still applies as usual. In this case + the parameter should be set to ``True``. + - When you want to treat multi-class or multi-dimensional mulit-class inputs with 2 classes + as binary or multi-label inputs, respectively. This is mainly meant for the case when + inputs are labels, but will work if they are probabilities as well. For this case the + parameter should be set to ``False``. + + + Returns: + preds: binary tensor of shape (N, C) or (N, C, X) + target: binary tensor of shape (N, C) or (N, C, X) + case: The case the inputs fall in, one of 'binary', 'multi-class', 'multi-label' or + 'multi-dim multi-class' + """ + # Remove excess dimensions + if preds.shape[0] == 1: + preds, target = preds.squeeze().unsqueeze(0), target.squeeze().unsqueeze(0) + else: + preds, target = preds.squeeze(), target.squeeze() + + case = _check_classification_inputs( + preds, + target, + threshold=threshold, + num_classes=num_classes, + is_multiclass=is_multiclass, + top_k=top_k, + ) + + top_k = top_k if top_k else 1 + + if case in ["binary", "multi-label"]: + preds = (preds >= threshold).int() + num_classes = num_classes if not is_multiclass else 2 + + if "multi-class" in case or is_multiclass: + if preds.is_floating_point(): + num_classes = preds.shape[1] + preds = select_topk(preds, top_k) + else: + num_classes = num_classes if num_classes else max(preds.max(), target.max()) + 1 + preds = to_onehot(preds, max(2,num_classes)) + + target = to_onehot(target, max(2,num_classes)) + + if is_multiclass is False: + preds, target = preds[:, 1, ...], target[:, 1, ...] + + if ("multi-class" in case and is_multiclass is not False) or is_multiclass: + target = target.reshape(target.shape[0], target.shape[1], -1) + preds = preds.reshape(preds.shape[0], preds.shape[1], -1) + else: + target = target.reshape(target.shape[0], -1) + preds = preds.reshape(preds.shape[0], -1) + + # Some operatins above create an extra dimension for MC/binary case - this removes it + if preds.ndim > 2: + preds, target = preds.squeeze(-1), target.squeeze(-1) + + return preds.int(), target.int(), case diff --git a/pytorch_lightning/metrics/utils.py b/pytorch_lightning/metrics/utils.py index 9aaa5578edb80..92faca200d0aa 100644 --- a/pytorch_lightning/metrics/utils.py +++ b/pytorch_lightning/metrics/utils.py @@ -21,6 +21,7 @@ def dim_zero_cat(x): + x = x if isinstance(x, (list, tuple)) else [x] return torch.cat(x, dim=0) @@ -39,15 +40,13 @@ def _flatten(x): def _check_same_shape(pred: torch.Tensor, target: torch.Tensor): """ Check that predictions and target have the same shape, else raise error """ if pred.shape != target.shape: - raise RuntimeError('Predictions and targets are expected to have the same shape') + raise RuntimeError("Predictions and targets are expected to have the same shape") def _input_format_classification( - preds: torch.Tensor, - target: torch.Tensor, - threshold: float = 0.5 + preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5 ) -> Tuple[torch.Tensor, torch.Tensor]: - """ Convert preds and target tensors into label tensors + """Convert preds and target tensors into label tensors Args: preds: either tensor with labels, tensor with probabilities/logits or @@ -59,29 +58,23 @@ def _input_format_classification( preds: tensor with labels target: tensor with labels """ - if not (len(preds.shape) == len(target.shape) or len(preds.shape) == len(target.shape) + 1): - raise ValueError( - "preds and target must have same number of dimensions, or one additional dimension for preds" - ) + if not (preds.ndim == target.ndim or preds.ndim == target.ndim + 1): + raise ValueError("preds and target must have same number of dimensions, or one additional dimension for preds") - if len(preds.shape) == len(target.shape) + 1: + if preds.ndim == target.ndim + 1: # multi class probabilites preds = torch.argmax(preds, dim=1) - if len(preds.shape) == len(target.shape) and preds.dtype == torch.float: + if preds.ndim == target.ndim and preds.is_floating_point(): # binary or multilabel probablities preds = (preds >= threshold).long() return preds, target def _input_format_classification_one_hot( - num_classes: int, - preds: torch.Tensor, - target: torch.Tensor, - threshold: float = 0.5, - multilabel: bool = False + num_classes: int, preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5, multilabel: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: - """ Convert preds and target tensors into one hot spare label tensors + """Convert preds and target tensors into one hot spare label tensors Args: num_classes: number of classes @@ -95,26 +88,24 @@ def _input_format_classification_one_hot( preds: one hot tensor of shape [num_classes, -1] with predicted labels target: one hot tensors of shape [num_classes, -1] with true labels """ - if not (len(preds.shape) == len(target.shape) or len(preds.shape) == len(target.shape) + 1): - raise ValueError( - "preds and target must have same number of dimensions, or one additional dimension for preds" - ) + if not (preds.ndim == target.ndim or preds.ndim == target.ndim + 1): + raise ValueError("preds and target must have same number of dimensions, or one additional dimension for preds") - if len(preds.shape) == len(target.shape) + 1: + if preds.ndim == target.ndim + 1: # multi class probabilites preds = torch.argmax(preds, dim=1) - if len(preds.shape) == len(target.shape) and preds.dtype == torch.long and num_classes > 1 and not multilabel: + if preds.ndim == target.ndim and preds.dtype in (torch.long, torch.int) and num_classes > 1 and not multilabel: # multi-class preds = to_onehot(preds, num_classes=num_classes) target = to_onehot(target, num_classes=num_classes) - elif len(preds.shape) == len(target.shape) and preds.dtype == torch.float: + elif preds.ndim == target.ndim and preds.is_floating_point(): # binary or multilabel probablities preds = (preds >= threshold).long() # transpose class as first dim and reshape - if len(preds.shape) > 1: + if preds.ndim > 1: preds = preds.transpose(1, 0) target = target.transpose(1, 0) @@ -122,14 +113,14 @@ def _input_format_classification_one_hot( def to_onehot( - tensor: torch.Tensor, - num_classes: Optional[int] = None, + label_tensor: torch.Tensor, + num_classes: Optional[int] = None, ) -> torch.Tensor: """ Converts a dense label tensor to one-hot format Args: - tensor: dense label tensor, with shape [N, d1, d2, ...] + label_tensor: dense label tensor, with shape [N, d1, d2, ...] num_classes: number of classes C Output: @@ -145,18 +136,45 @@ def to_onehot( """ if num_classes is None: - num_classes = int(tensor.max().detach().item() + 1) - dtype, device, shape = tensor.dtype, tensor.device, tensor.shape - tensor_onehot = torch.zeros(shape[0], num_classes, *shape[1:], - dtype=dtype, device=device) - index = tensor.long().unsqueeze(1).expand_as(tensor_onehot) + num_classes = int(label_tensor.max().detach().item() + 1) + + tensor_onehot = torch.zeros( + label_tensor.shape[0], + num_classes, + *label_tensor.shape[1:], + dtype=label_tensor.dtype, + device=label_tensor.device, + ) + index = label_tensor.long().unsqueeze(1).expand_as(tensor_onehot) return tensor_onehot.scatter_(1, index, 1.0) -def to_categorical( - tensor: torch.Tensor, - argmax_dim: int = 1 -) -> torch.Tensor: +def select_topk(prob_tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch.Tensor: + """ + Convert a probability tensor to binary by selecting top-k highest entries. + + Args: + prob_tensor: dense tensor of shape ``[..., C, ...]``, where ``C`` is in the + position defined by the ``dim`` argument + topk: number of highest entries to turn into 1s + dim: dimension on which to compare entries + + Output: + A binary tensor of the same shape as the input tensor of type torch.int32 + + Example: + >>> x = torch.tensor([[1.1, 2.0, 3.0], [2.0, 1.0, 0.5]]) + >>> select_topk(x, topk=2) + tensor([[0, 1, 1], + [1, 1, 0]], dtype=torch.int32) + """ + zeros = torch.zeros_like(prob_tensor) + topk_tensor = zeros.scatter(1, prob_tensor.topk(k=topk, dim=dim).indices, 1.0) + + return topk_tensor.int() + + +def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor: """ Converts a tensor of probabilities to a dense label tensor @@ -178,9 +196,9 @@ def to_categorical( def get_num_classes( - pred: torch.Tensor, - target: torch.Tensor, - num_classes: Optional[int] = None, + pred: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, ) -> int: """ Calculates the number of classes for a given prediction and target tensor. @@ -200,10 +218,12 @@ def get_num_classes( if num_classes is None: num_classes = num_all_classes elif num_classes != num_all_classes: - rank_zero_warn(f'You have set {num_classes} number of classes which is' - f' different from predicted ({num_pred_classes}) and' - f' target ({num_target_classes}) number of classes', - RuntimeWarning) + rank_zero_warn( + f"You have set {num_classes} number of classes which is" + f" different from predicted ({num_pred_classes}) and" + f" target ({num_target_classes}) number of classes", + RuntimeWarning, + ) return num_classes @@ -221,19 +241,18 @@ def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor: Raise: ValueError if an invalid reduction parameter was given """ - if reduction == 'elementwise_mean': + if reduction == "elementwise_mean": return torch.mean(to_reduce) - if reduction == 'none': + if reduction == "none": return to_reduce - if reduction == 'sum': + if reduction == "sum": return torch.sum(to_reduce) - raise ValueError('Reduction parameter unknown.') + raise ValueError("Reduction parameter unknown.") -def class_reduce(num: torch.Tensor, - denom: torch.Tensor, - weights: torch.Tensor, - class_reduction: str = 'none') -> torch.Tensor: +def class_reduce( + num: torch.Tensor, denom: torch.Tensor, weights: torch.Tensor, class_reduction: str = "none" +) -> torch.Tensor: """ Function used to reduce classification metrics of the form `num / denom * weights`. For example for calculating standard accuracy the num would be number of @@ -252,8 +271,8 @@ def class_reduce(num: torch.Tensor, - ``'none'`` or ``None``: returns calculated metric per class """ - valid_reduction = ('micro', 'macro', 'weighted', 'none', None) - if class_reduction == 'micro': + valid_reduction = ("micro", "macro", "weighted", "none", None) + if class_reduction == "micro": fraction = torch.sum(num) / torch.sum(denom) else: fraction = num / denom @@ -262,14 +281,15 @@ def class_reduce(num: torch.Tensor, # for some (or all) classes which will produce nans fraction[fraction != fraction] = 0 - if class_reduction == 'micro': + if class_reduction == "micro": return fraction - elif class_reduction == 'macro': + elif class_reduction == "macro": return torch.mean(fraction) - elif class_reduction == 'weighted': + elif class_reduction == "weighted": return torch.sum(fraction * (weights.float() / torch.sum(weights))) - elif class_reduction == 'none' or class_reduction is None: + elif class_reduction == "none" or class_reduction is None: return fraction - raise ValueError(f'Reduction parameter {class_reduction} unknown.' - f' Choose between one of these: {valid_reduction}') + raise ValueError( + f"Reduction parameter {class_reduction} unknown." f" Choose between one of these: {valid_reduction}" + ) diff --git a/tests/metrics/classification/inputs.py b/tests/metrics/classification/inputs.py index 9613df3b6f8ca..9f70a80cd31a4 100644 --- a/tests/metrics/classification/inputs.py +++ b/tests/metrics/classification/inputs.py @@ -29,12 +29,21 @@ target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)) ) +_multilabel_multidim_prob_inputs = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)) +) _multilabel_inputs = Input( preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)) ) +_multilabel_multidim_inputs = Input( + preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)) +) + # Generate edge multilabel edge case, where nothing matches (scores are undefined) __temp_preds = torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)) __temp_target = abs(__temp_preds - 1) @@ -44,8 +53,11 @@ target=__temp_target ) +__mc_prob_preds = torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES) +__mc_prob_preds = __mc_prob_preds / __mc_prob_preds.sum(dim=2, keepdim=True) + _multiclass_prob_inputs = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES), + preds=__mc_prob_preds, target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)) ) @@ -55,14 +67,15 @@ target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)) ) +__mdmc_prob_preds = torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM) +__mdmc_prob_preds = __mdmc_prob_preds / __mdmc_prob_preds.sum(dim=2, keepdim=True) _multidim_multiclass_prob_inputs = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM), + preds=__mdmc_prob_preds, target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)) ) - _multidim_multiclass_inputs = Input( - preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, EXTRA_DIM, BATCH_SIZE)), - target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, EXTRA_DIM, BATCH_SIZE)) + preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)) ) diff --git a/tests/metrics/classification/test_inputs.py b/tests/metrics/classification/test_inputs.py new file mode 100644 index 0000000000000..c4d01d282fa57 --- /dev/null +++ b/tests/metrics/classification/test_inputs.py @@ -0,0 +1,295 @@ +import pytest +import torch +from torch import randint, rand + +from pytorch_lightning.metrics.utils import to_onehot, select_topk +from pytorch_lightning.metrics.classification.helpers import _input_format_classification +from tests.metrics.classification.inputs import ( + Input, + _binary_inputs as _bin, + _binary_prob_inputs as _bin_prob, + _multiclass_inputs as _mc, + _multiclass_prob_inputs as _mc_prob, + _multidim_multiclass_inputs as _mdmc, + _multidim_multiclass_prob_inputs as _mdmc_prob, + _multilabel_inputs as _ml, + _multilabel_prob_inputs as _ml_prob, + _multilabel_multidim_inputs as _mlmd, + _multilabel_multidim_prob_inputs as _mlmd_prob, +) +from tests.metrics.utils import NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM, THRESHOLD + +torch.manual_seed(42) + +# Some additional inputs to test on +_mc_prob_2cls_preds = rand(NUM_BATCHES, BATCH_SIZE, 2) +_mc_prob_2cls_preds /= _mc_prob_2cls_preds.sum(dim=2, keepdim=True) +_mc_prob_2cls = Input(_mc_prob_2cls_preds, randint(high=2, size=(NUM_BATCHES, BATCH_SIZE))) + +_mdmc_prob_many_dims_preds = rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM, EXTRA_DIM) +_mdmc_prob_many_dims_preds /= _mdmc_prob_many_dims_preds.sum(dim=2, keepdim=True) +_mdmc_prob_many_dims = Input( + _mdmc_prob_many_dims_preds, + randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM, EXTRA_DIM)), +) + +_mdmc_prob_2cls_preds = rand(NUM_BATCHES, BATCH_SIZE, 2, EXTRA_DIM) +_mdmc_prob_2cls_preds /= _mdmc_prob_2cls_preds.sum(dim=2, keepdim=True) +_mdmc_prob_2cls = Input(_mdmc_prob_2cls_preds, randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM))) + +# Some utils +T = torch.Tensor + + +def _idn(x): + return x + + +def _usq(x): + return x.unsqueeze(-1) + + +def _thrs(x): + return x >= THRESHOLD + + +def _rshp1(x): + return x.reshape(x.shape[0], -1) + + +def _rshp2(x): + return x.reshape(x.shape[0], x.shape[1], -1) + + +def _onehot(x): + return to_onehot(x, NUM_CLASSES) + + +def _onehot2(x): + return to_onehot(x, 2) + + +def _top1(x): + return select_topk(x, 1) + + +def _top2(x): + return select_topk(x, 2) + + +# To avoid ugly black line wrapping +def _ml_preds_tr(x): + return _rshp1(_thrs(x)) + + +def _onehot_rshp1(x): + return _onehot(_rshp1(x)) + + +def _onehot2_rshp1(x): + return _onehot2(_rshp1(x)) + + +def _top1_rshp2(x): + return _top1(_rshp2(x)) + + +def _top2_rshp2(x): + return _top2(_rshp2(x)) + + +def _probs_to_mc_preds_tr(x): + return _onehot2(_thrs(x)) + + +def _mlmd_prob_to_mc_preds_tr(x): + return _onehot2(_rshp1(_thrs(x))) + + +######################## +# Test correct inputs +######################## + + +@pytest.mark.parametrize( + "inputs, num_classes, is_multiclass, top_k, exp_mode, post_preds, post_target", + [ + ############################# + # Test usual expected cases + (_bin, None, False, None, "multi-class", _usq, _usq), + (_bin, 1, False, None, "multi-class", _usq, _usq), + (_bin_prob, None, None, None, "binary", lambda x: _usq(_thrs(x)), _usq), + (_ml_prob, None, None, None, "multi-label", _thrs, _idn), + (_ml, None, False, None, "multi-dim multi-class", _idn, _idn), + (_ml_prob, None, None, None, "multi-label", _ml_preds_tr, _rshp1), + (_mlmd, None, False, None, "multi-dim multi-class", _rshp1, _rshp1), + (_mc, NUM_CLASSES, None, None, "multi-class", _onehot, _onehot), + (_mc_prob, None, None, None, "multi-class", _top1, _onehot), + (_mc_prob, None, None, 2, "multi-class", _top2, _onehot), + (_mdmc, NUM_CLASSES, None, None, "multi-dim multi-class", _onehot, _onehot), + (_mdmc_prob, None, None, None, "multi-dim multi-class", _top1_rshp2, _onehot), + (_mdmc_prob, None, None, 2, "multi-dim multi-class", _top2_rshp2, _onehot), + (_mdmc_prob_many_dims, None, None, None, "multi-dim multi-class", _top1_rshp2, _onehot_rshp1), + (_mdmc_prob_many_dims, None, None, 2, "multi-dim multi-class", _top2_rshp2, _onehot_rshp1), + ########################### + # Test some special cases + # Binary as multiclass + (_bin, None, None, None, "multi-class", _onehot2, _onehot2), + # Binary probs as multiclass + (_bin_prob, None, True, None, "binary", _probs_to_mc_preds_tr, _onehot2), + # Multilabel as multiclass + (_ml, None, True, None, "multi-dim multi-class", _onehot2, _onehot2), + # Multilabel probs as multiclass + (_ml_prob, None, True, None, "multi-label", _probs_to_mc_preds_tr, _onehot2), + # Multidim multilabel as multiclass + (_mlmd, None, True, None, "multi-dim multi-class", _onehot2_rshp1, _onehot2_rshp1), + # Multidim multilabel probs as multiclass + (_mlmd_prob, None, True, None, "multi-label", _mlmd_prob_to_mc_preds_tr, _onehot2_rshp1), + # Multiclass prob with 2 classes as binary + (_mc_prob_2cls, None, False, None, "multi-class", lambda x: _top1(x)[:, [1]], _usq), + # Multi-dim multi-class with 2 classes as multi-label + (_mdmc_prob_2cls, None, False, None, "multi-dim multi-class", lambda x: _top1(x)[:, 1], _idn), + ], +) +def test_usual_cases(inputs, num_classes, is_multiclass, top_k, exp_mode, post_preds, post_target): + preds_out, target_out, mode = _input_format_classification( + preds=inputs.preds[0], + target=inputs.target[0], + threshold=THRESHOLD, + num_classes=num_classes, + is_multiclass=is_multiclass, + top_k=top_k, + ) + + assert mode == exp_mode + assert torch.equal(preds_out, post_preds(inputs.preds[0]).int()) + assert torch.equal(target_out, post_target(inputs.target[0]).int()) + + # Test that things work when batch_size = 1 + preds_out, target_out, mode = _input_format_classification( + preds=inputs.preds[0][[0], ...], + target=inputs.target[0][[0], ...], + threshold=THRESHOLD, + num_classes=num_classes, + is_multiclass=is_multiclass, + top_k=top_k, + ) + + assert mode == exp_mode + assert torch.equal(preds_out, post_preds(inputs.preds[0][[0], ...]).int()) + assert torch.equal(target_out, post_target(inputs.target[0][[0], ...]).int()) + + +# Test that threshold is correctly applied +def test_threshold(): + target = T([1, 1, 1]).int() + preds_probs = T([0.5 - 1e-5, 0.5, 0.5 + 1e-5]) + + preds_probs_out, _, _ = _input_format_classification(preds_probs, target, threshold=0.5) + + assert torch.equal(torch.tensor([0, 1, 1], dtype=torch.int), preds_probs_out.squeeze().int()) + + +######################################################################## +# Test incorrect inputs +######################################################################## + + +def test_incorrect_threshold(): + with pytest.raises(ValueError): + _input_format_classification(preds=rand(size=(7,)), target=randint(high=2, size=(7,)), threshold=1.5) + + +@pytest.mark.parametrize( + "preds, target, num_classes, is_multiclass", + [ + # Target not integer + (randint(high=2, size=(7,)), randint(high=2, size=(7,)).float(), None, None), + # Target negative + (randint(high=2, size=(7,)), -randint(high=2, size=(7,)), None, None), + # Preds negative integers + (-randint(high=2, size=(7,)), randint(high=2, size=(7,)), None, None), + # Negative probabilities + (-rand(size=(7,)), randint(high=2, size=(7,)), None, None), + # is_multiclass=False and target > 1 + (rand(size=(7,)), randint(low=2, high=4, size=(7,)), None, False), + # is_multiclass=False and preds integers with > 1 + (randint(low=2, high=4, size=(7,)), randint(high=2, size=(7,)), None, False), + # Wrong batch size + (randint(high=2, size=(8,)), randint(high=2, size=(7,)), None, None), + # Completely wrong shape + (randint(high=2, size=(7,)), randint(high=2, size=(7, 4)), None, None), + # Same #dims, different shape + (randint(high=2, size=(7, 3)), randint(high=2, size=(7, 4)), None, None), + # Same shape and preds floats, target not binary + (rand(size=(7, 3)), randint(low=2, high=4, size=(7, 3)), None, None), + # #dims in preds = 1 + #dims in target, C shape not second or last + (rand(size=(7, 3, 4, 3)), randint(high=4, size=(7, 3, 3)), None, None), + # #dims in preds = 1 + #dims in target, preds not float + (randint(high=2, size=(7, 3, 3, 4)), randint(high=4, size=(7, 3, 3)), None, None), + # is_multiclass=False, with C dimension > 2 + (_mc_prob.preds[0], randint(high=2, size=(BATCH_SIZE,)), None, False), + # Probs of multiclass preds do not sum up to 1 + (rand(size=(7, 3, 5)), randint(high=2, size=(7, 5)), None, None), + # Max target larger or equal to C dimension + (_mc_prob.preds[0], randint(low=NUM_CLASSES + 1, high=100, size=(BATCH_SIZE,)), None, None), + # C dimension not equal to num_classes + (_mc_prob.preds[0], _mc_prob.target[0], NUM_CLASSES + 1, None), + # Max target larger than num_classes (with #dim preds = 1 + #dims target) + (_mc_prob.preds[0], randint(low=NUM_CLASSES + 1, high=100, size=(BATCH_SIZE, NUM_CLASSES)), 4, None), + # Max target larger than num_classes (with #dim preds = #dims target) + (randint(high=4, size=(7, 3)), randint(low=5, high=7, size=(7, 3)), 4, None), + # Max preds larger than num_classes (with #dim preds = #dims target) + (randint(low=5, high=7, size=(7, 3)), randint(high=4, size=(7, 3)), 4, None), + # Num_classes=1, but is_multiclass not false + (randint(high=2, size=(7,)), randint(high=2, size=(7,)), 1, None), + # is_multiclass=False, but implied class dimension (for multi-label, from shape) != num_classes + (randint(high=2, size=(7, 3, 3)), randint(high=2, size=(7, 3, 3)), 4, False), + # Multilabel input with implied class dimension != num_classes + (rand(size=(7, 3, 3)), randint(high=2, size=(7, 3, 3)), 4, False), + # Multilabel input with is_multiclass=True, but num_classes != 2 (or None) + (rand(size=(7, 3)), randint(high=2, size=(7, 3)), 4, True), + # Binary input, num_classes > 2 + (rand(size=(7,)), randint(high=2, size=(7,)), 4, None), + # Binary input, num_classes == 2 and is_multiclass not True + (rand(size=(7,)), randint(high=2, size=(7,)), 2, None), + (rand(size=(7,)), randint(high=2, size=(7,)), 2, False), + # Binary input, num_classes == 1 and is_multiclass=True + (rand(size=(7,)), randint(high=2, size=(7,)), 1, True), + ], +) +def test_incorrect_inputs(preds, target, num_classes, is_multiclass): + with pytest.raises(ValueError): + _input_format_classification( + preds=preds, target=target, threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass + ) + + +@pytest.mark.parametrize( + "preds, target, num_classes, is_multiclass, top_k", + [ + # Topk set with non (md)mc prob data + (_bin.preds[0], _bin.target[0], None, None, 2), + (_bin_prob.preds[0], _bin_prob.target[0], None, None, 2), + (_mc.preds[0], _mc.target[0], None, None, 2), + (_ml.preds[0], _ml.target[0], None, None, 2), + (_mlmd.preds[0], _mlmd.target[0], None, None, 2), + (_ml_prob.preds[0], _ml_prob.target[0], None, None, 2), + (_mlmd_prob.preds[0], _mlmd_prob.target[0], None, None, 2), + (_mdmc.preds[0], _mdmc.target[0], None, None, 2), + # top_k =2 with 2 classes, is_multiclass=False + (_mc_prob_2cls.preds[0], _mc_prob_2cls.target[0], None, False, 2), + # top_k = number of classes (C dimension) + (_mc_prob.preds[0], _mc_prob.target[0], None, None, NUM_CLASSES), + ], +) +def test_incorrect_inputs_topk(preds, target, num_classes, is_multiclass, top_k): + with pytest.raises(ValueError): + _input_format_classification( + preds=preds, + target=target, + threshold=THRESHOLD, + num_classes=num_classes, + is_multiclass=is_multiclass, + top_k=top_k, + ) diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py index 5c00384da1e14..c607a466b2068 100644 --- a/tests/metrics/utils.py +++ b/tests/metrics/utils.py @@ -21,10 +21,10 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ - os.environ["MASTER_ADDR"] = 'localhost' - os.environ['MASTER_PORT'] = '8088' + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "8088" - if torch.distributed.is_available() and sys.platform not in ('win32', 'cygwin'): + if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) @@ -67,23 +67,23 @@ def _class_test( check_batch: bool = True, atol: float = 1e-8, ): - """ Utility function doing the actual comparison between lightning class metric - and reference metric. - - Args: - rank: rank of current process - worldsize: number of processes - preds: torch tensor with predictions - target: torch tensor with targets - metric_class: lightning metric class that should be tested - sk_metric: callable function that is used for comparison - dist_sync_on_step: bool, if true will synchronize metric state across - processes at each ``forward()`` - metric_args: dict with additional arguments used for class initialization - check_dist_sync_on_step: bool, if true will check if the metric is also correctly - calculated per batch per device (and not just at the end) - check_batch: bool, if true will check if the metric is also correctly - calculated across devices for each batch (and not just at the end) + """Utility function doing the actual comparison between lightning class metric + and reference metric. + + Args: + rank: rank of current process + worldsize: number of processes + preds: torch tensor with predictions + target: torch tensor with targets + metric_class: lightning metric class that should be tested + sk_metric: callable function that is used for comparison + dist_sync_on_step: bool, if true will synchronize metric state across + processes at each ``forward()`` + metric_args: dict with additional arguments used for class initialization + check_dist_sync_on_step: bool, if true will check if the metric is also correctly + calculated per batch per device (and not just at the end) + check_batch: bool, if true will check if the metric is also correctly + calculated across devices for each batch (and not just at the end) """ # Instanciate lightning metric metric = metric_class(compute_on_step=True, dist_sync_on_step=dist_sync_on_step, **metric_args) @@ -127,17 +127,17 @@ def _functional_test( metric_functional: Callable, sk_metric: Callable, metric_args: dict = {}, - atol: float = 1e-8 + atol: float = 1e-8, ): - """ Utility function doing the actual comparison between lightning functional metric - and reference metric. - - Args: - preds: torch tensor with predictions - target: torch tensor with targets - metric_functional: lightning metric functional that should be tested - sk_metric: callable function that is used for comparison - metric_args: dict with additional arguments used for class initialization + """Utility function doing the actual comparison between lightning functional metric + and reference metric. + + Args: + preds: torch tensor with predictions + target: torch tensor with targets + metric_functional: lightning metric functional that should be tested + sk_metric: callable function that is used for comparison + metric_args: dict with additional arguments used for class initialization """ metric = partial(metric_functional, **metric_args) @@ -150,22 +150,23 @@ def _functional_test( class MetricTester: - """ Class used for efficiently run alot of parametrized tests in ddp mode. - Makes sure that ddp is only setup once and that pool of processes are - used for all tests. + """Class used for efficiently run alot of parametrized tests in ddp mode. + Makes sure that ddp is only setup once and that pool of processes are + used for all tests. - All tests should subclass from this and implement a new method called - `test_metric_name` - where the method `self.run_metric_test` is called inside. + All tests should subclass from this and implement a new method called + `test_metric_name` + where the method `self.run_metric_test` is called inside. """ + atol = 1e-8 def setup_class(self): - """ Setup the metric class. This will spawn the pool of workers that are - used for metric testing and setup_ddp + """Setup the metric class. This will spawn the pool of workers that are + used for metric testing and setup_ddp """ try: - set_start_method('spawn') + set_start_method("spawn") except RuntimeError: pass self.poolSize = NUM_PROCESSES @@ -183,24 +184,26 @@ def run_functional_metric_test( target: torch.Tensor, metric_functional: Callable, sk_metric: Callable, - metric_args: dict = {} + metric_args: dict = {}, ): - """ Main method that should be used for testing functions. Call this inside - testing method - - Args: - preds: torch tensor with predictions - target: torch tensor with targets - metric_functional: lightning metric class that should be tested - sk_metric: callable function that is used for comparison - metric_args: dict with additional arguments used for class initialization + """Main method that should be used for testing functions. Call this inside + testing method + + Args: + preds: torch tensor with predictions + target: torch tensor with targets + metric_functional: lightning metric class that should be tested + sk_metric: callable function that is used for comparison + metric_args: dict with additional arguments used for class initialization """ - _functional_test(preds=preds, - target=target, - metric_functional=metric_functional, - sk_metric=sk_metric, - metric_args=metric_args, - atol=self.atol) + _functional_test( + preds=preds, + target=target, + metric_functional=metric_functional, + sk_metric=sk_metric, + metric_args=metric_args, + atol=self.atol, + ) def run_class_metric_test( self, @@ -214,22 +217,22 @@ def run_class_metric_test( check_dist_sync_on_step: bool = True, check_batch: bool = True, ): - """ Main method that should be used for testing class. Call this inside testing - methods. - - Args: - ddp: bool, if running in ddp mode or not - preds: torch tensor with predictions - target: torch tensor with targets - metric_class: lightning metric class that should be tested - sk_metric: callable function that is used for comparison - dist_sync_on_step: bool, if true will synchronize metric state across - processes at each ``forward()`` - metric_args: dict with additional arguments used for class initialization - check_dist_sync_on_step: bool, if true will check if the metric is also correctly - calculated per batch per device (and not just at the end) - check_batch: bool, if true will check if the metric is also correctly - calculated across devices for each batch (and not just at the end) + """Main method that should be used for testing class. Call this inside testing + methods. + + Args: + ddp: bool, if running in ddp mode or not + preds: torch tensor with predictions + target: torch tensor with targets + metric_class: lightning metric class that should be tested + sk_metric: callable function that is used for comparison + dist_sync_on_step: bool, if true will synchronize metric state across + processes at each ``forward()`` + metric_args: dict with additional arguments used for class initialization + check_dist_sync_on_step: bool, if true will check if the metric is also correctly + calculated per batch per device (and not just at the end) + check_batch: bool, if true will check if the metric is also correctly + calculated across devices for each batch (and not just at the end) """ if ddp: if sys.platform == "win32":