From 0947c4483adbc1e600841e7c5f057f9df7d96698 Mon Sep 17 00:00:00 2001 From: Yuntai Kyong Date: Thu, 28 Jan 2021 09:50:43 +0100 Subject: [PATCH] Update metrics to use Enum - change data type string to DataType enum - add enum equality test - update test_inputs.py to use the defined enum type update doc --- .../metrics/classification/helpers.py | 60 +++++++++------ tests/metrics/classification/test_helpers.py | 20 +++++ tests/metrics/classification/test_inputs.py | 77 ++++++++++--------- 3 files changed, 95 insertions(+), 62 deletions(-) create mode 100644 tests/metrics/classification/test_helpers.py diff --git a/pytorch_lightning/metrics/classification/helpers.py b/pytorch_lightning/metrics/classification/helpers.py index 64cd3389e83113..e77b251b355e36 100644 --- a/pytorch_lightning/metrics/classification/helpers.py +++ b/pytorch_lightning/metrics/classification/helpers.py @@ -17,6 +17,18 @@ import torch from pytorch_lightning.metrics.utils import select_topk, to_onehot +from pytorch_lightning.utilities import LightningEnum + + +class DataType(LightningEnum): + """ + Enum to represent data type + """ + + BINARY = "binary" + MULTILABEL = "multilabel" + MULTICLASS = "multiclass" + MDMC = "mdmc" # multi-dim multi-class def _basic_input_validation(preds: torch.Tensor, target: torch.Tensor, threshold: float, is_multiclass: bool): @@ -50,7 +62,7 @@ def _basic_input_validation(preds: torch.Tensor, target: torch.Tensor, threshold raise ValueError("If you set `is_multiclass=False` and `preds` are integers, then `preds` should not exceed 1.") -def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor) -> Tuple[str, int]: +def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor) -> Tuple[DataType, int]: """ This checks that the shape and type of inputs are consistent with each other and fall into one of the allowed input types (see the @@ -58,7 +70,7 @@ def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor) not check for consistency of number of classes, other functions take care of that. - It returns the name of the case in which the inputs fall, and the implied + It returns a DataType enum in which the inputs fall, and the implied number of classes (from the ``C`` dim for multi-class data, or extra dim(s) for multi-label data). """ @@ -78,13 +90,13 @@ def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor) # Get the case if preds.ndim == 1 and preds_float: - case = "binary" + case = DataType.BINARY elif preds.ndim == 1 and not preds_float: - case = "multi-class" + case = DataType.MULTICLASS elif preds.ndim > 1 and preds_float: - case = "multi-label" + case = DataType.MULTILABEL else: - case = "multi-dim multi-class" + case = DataType.MDMC implied_classes = preds[0].numel() @@ -100,9 +112,9 @@ def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor) implied_classes = preds.shape[1] if preds.ndim == 2: - case = "multi-class" + case = DataType.MULTICLASS else: - case = "multi-dim multi-class" + case = DataType.MDMC else: raise ValueError( "Either `preds` and `target` both should have the (same) shape (N, ...), or `target` should be (N, ...)" @@ -182,7 +194,7 @@ def _check_num_classes_ml(num_classes: int, is_multiclass: bool, implied_classes def _check_top_k(top_k: int, case: str, implied_classes: int, is_multiclass: Optional[bool], preds_float: bool): - if case == "binary": + if case == DataType.BINARY: raise ValueError("You can not use `top_k` parameter with binary data.") if not isinstance(top_k, int) or top_k <= 0: raise ValueError("The `top_k` has to be an integer larger than 0.") @@ -190,7 +202,7 @@ def _check_top_k(top_k: int, case: str, implied_classes: int, is_multiclass: Opt raise ValueError("You have set `top_k`, but you do not have probability predictions.") if is_multiclass is False: raise ValueError("If you set `is_multiclass=False`, you can not set `top_k`.") - if case == "multi-label" and is_multiclass: + if case == DataType.MULTICLASS and is_multiclass: raise ValueError( "If you want to transform multi-label data to 2 class multi-dimensional" "multi-class data using `is_multiclass=True`, you can not use `top_k`." @@ -206,7 +218,7 @@ def _check_classification_inputs( num_classes: Optional[int], is_multiclass: bool, top_k: Optional[int], -) -> str: +) -> DataType: """Performs error checking on inputs for classification. This ensures that preds and target take one of the shape/type combinations that are @@ -255,8 +267,8 @@ def _check_classification_inputs( Return: - case: The case the inputs fall in, one of 'binary', 'multi-class', 'multi-label' or - 'multi-dim multi-class' + case: The case the inputs fall in, one of DataType.BINARY, DataType.MULTICLASS, DataType.MULTILABEL or + DataType.MDMC """ # Baisc validation (that does not need case/type information) @@ -266,7 +278,7 @@ def _check_classification_inputs( case, implied_classes = _check_shape_and_type_consistency(preds, target) # For (multi-dim) multi-class case with prob preds, check that preds sum up to 1 - if "multi-class" in case and preds.is_floating_point(): + if case in (DataType.MULTICLASS, DataType.MDMC) and preds.is_floating_point(): if not torch.isclose(preds.sum(dim=1), torch.ones_like(preds.sum(dim=1))).all(): raise ValueError("Probabilities in `preds` must sum up to 1 accross the `C` dimension.") @@ -284,11 +296,11 @@ def _check_classification_inputs( # Check that num_classes is consistent if num_classes: - if case == "binary": + if case == DataType.BINARY: _check_num_classes_binary(num_classes, is_multiclass) - elif "multi-class" in case: + elif case in (DataType.MULTICLASS, DataType.MDMC): _check_num_classes_mc(preds, target, num_classes, is_multiclass, implied_classes) - elif case == "multi-label": + elif case.MULTILABEL: _check_num_classes_ml(num_classes, is_multiclass, implied_classes) # Check that top_k is consistent @@ -305,7 +317,7 @@ def _input_format_classification( top_k: Optional[int] = None, num_classes: Optional[int] = None, is_multiclass: Optional[bool] = None, -) -> Tuple[torch.Tensor, torch.Tensor, str]: +) -> Tuple[torch.Tensor, torch.Tensor, DataType]: """Convert preds and target tensors into common format. Preds and targets are supposed to fall into one of these categories (and are @@ -383,8 +395,8 @@ def _input_format_classification( Returns: preds: binary tensor of shape ``(N, C)`` or ``(N, C, X)`` target: binary tensor of shape ``(N, C)`` or ``(N, C, X)`` - case: The case the inputs fall in, one of ``'binary'``, ``'multi-class'``, ``'multi-label'`` or - ``'multi-dim multi-class'`` + case: The case the inputs fall in, one of DataType.BINARY, DataType.MULTICLASS, DataType.MULTILABEL or + DataType.MDMC """ # Remove excess dimensions if preds.shape[0] == 1: @@ -406,14 +418,14 @@ def _input_format_classification( top_k=top_k, ) - if case in ["binary", "multi-label"] and not top_k: + if case in (DataType.BINARY, DataType.MULTILABEL) and not top_k: preds = (preds >= threshold).int() num_classes = num_classes if not is_multiclass else 2 - if case == "multi-label" and top_k: + if case == DataType.MULTICLASS and top_k: preds = select_topk(preds, top_k) - if "multi-class" in case or is_multiclass: + if case in (DataType.MULTICLASS, DataType.MDMC) or is_multiclass: if preds.is_floating_point(): num_classes = preds.shape[1] preds = select_topk(preds, top_k or 1) @@ -426,7 +438,7 @@ def _input_format_classification( if is_multiclass is False: preds, target = preds[:, 1, ...], target[:, 1, ...] - if ("multi-class" in case and is_multiclass is not False) or is_multiclass: + if (case in (DataType.MULTICLASS, DataType.MDMC) and is_multiclass is not False) or is_multiclass: target = target.reshape(target.shape[0], target.shape[1], -1) preds = preds.reshape(preds.shape[0], preds.shape[1], -1) else: diff --git a/tests/metrics/classification/test_helpers.py b/tests/metrics/classification/test_helpers.py new file mode 100644 index 00000000000000..3a94d8eb4da479 --- /dev/null +++ b/tests/metrics/classification/test_helpers.py @@ -0,0 +1,20 @@ +import pytest +from pytorch_lightning.metrics.classification.helpers import DataType + + +@pytest.mark.parametrize( + "s, data_type", + [ + ("binary", DataType.BINARY), + ("multilabel", DataType.MULTILABEL), + ("multiclass", DataType.MULTICLASS), + ("mdmc", DataType.MDMC), + ], +) +def test_data_type_equality(s, data_type): + assert s == data_type + assert data_type == s + assert s.upper() == data_type + assert data_type == s.upper() + assert DataType.from_str(s) == data_type + assert DataType.from_str(s.upper()) == data_type diff --git a/tests/metrics/classification/test_inputs.py b/tests/metrics/classification/test_inputs.py index 8c8c6b9033cc3e..fa190afa9000cd 100644 --- a/tests/metrics/classification/test_inputs.py +++ b/tests/metrics/classification/test_inputs.py @@ -2,20 +2,22 @@ import torch from torch import rand, randint -from pytorch_lightning.metrics.classification.helpers import _input_format_classification -from pytorch_lightning.metrics.utils import select_topk, to_onehot -from tests.metrics.classification.inputs import _binary_inputs as _bin -from tests.metrics.classification.inputs import _binary_prob_inputs as _bin_prob -from tests.metrics.classification.inputs import _multiclass_inputs as _mc -from tests.metrics.classification.inputs import _multiclass_prob_inputs as _mc_prob -from tests.metrics.classification.inputs import _multidim_multiclass_inputs as _mdmc -from tests.metrics.classification.inputs import _multidim_multiclass_prob_inputs as _mdmc_prob -from tests.metrics.classification.inputs import _multilabel_inputs as _ml -from tests.metrics.classification.inputs import _multilabel_multidim_inputs as _mlmd -from tests.metrics.classification.inputs import _multilabel_multidim_prob_inputs as _mlmd_prob -from tests.metrics.classification.inputs import _multilabel_prob_inputs as _ml_prob -from tests.metrics.classification.inputs import Input -from tests.metrics.utils import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, NUM_CLASSES, THRESHOLD +from pytorch_lightning.metrics.utils import to_onehot, select_topk +from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType +from tests.metrics.classification.inputs import ( + Input, + _binary_inputs as _bin, + _binary_prob_inputs as _bin_prob, + _multiclass_inputs as _mc, + _multiclass_prob_inputs as _mc_prob, + _multidim_multiclass_inputs as _mdmc, + _multidim_multiclass_prob_inputs as _mdmc_prob, + _multilabel_inputs as _ml, + _multilabel_prob_inputs as _ml_prob, + _multilabel_multidim_inputs as _mlmd, + _multilabel_multidim_prob_inputs as _mlmd_prob, +) +from tests.metrics.utils import NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM, THRESHOLD torch.manual_seed(42) @@ -116,42 +118,41 @@ def _mlmd_prob_to_mc_preds_tr(x): [ ############################# # Test usual expected cases - (_bin, None, False, None, "multi-class", _usq, _usq), - (_bin, 1, False, None, "multi-class", _usq, _usq), - (_bin_prob, None, None, None, "binary", lambda x: _usq(_thrs(x)), _usq), - (_ml_prob, None, None, None, "multi-label", _thrs, _idn), - (_ml, None, False, None, "multi-dim multi-class", _idn, _idn), - (_ml_prob, None, None, None, "multi-label", _ml_preds_tr, _rshp1), - (_ml_prob, None, None, 2, "multi-label", _top2, _rshp1), - (_mlmd, None, False, None, "multi-dim multi-class", _rshp1, _rshp1), - (_mc, NUM_CLASSES, None, None, "multi-class", _onehot, _onehot), - (_mc_prob, None, None, None, "multi-class", _top1, _onehot), - (_mc_prob, None, None, 2, "multi-class", _top2, _onehot), - (_mdmc, NUM_CLASSES, None, None, "multi-dim multi-class", _onehot, _onehot), - (_mdmc_prob, None, None, None, "multi-dim multi-class", _top1_rshp2, _onehot), - (_mdmc_prob, None, None, 2, "multi-dim multi-class", _top2_rshp2, _onehot), - (_mdmc_prob_many_dims, None, None, None, "multi-dim multi-class", _top1_rshp2, _onehot_rshp1), - (_mdmc_prob_many_dims, None, None, 2, "multi-dim multi-class", _top2_rshp2, _onehot_rshp1), + (_bin, None, False, None, DataType.MULTICLASS, _usq, _usq), + (_bin, 1, False, None, DataType.MULTICLASS, _usq, _usq), + (_bin_prob, None, None, None, DataType.BINARY, lambda x: _usq(_thrs(x)), _usq), + (_ml_prob, None, None, None, DataType.MULTILABEL, _thrs, _idn), + (_ml, None, False, None, DataType.MDMC, _idn, _idn), + (_ml_prob, None, None, None, DataType.MULTILABEL, _ml_preds_tr, _rshp1), + (_mlmd, None, False, None, DataType.MDMC, _rshp1, _rshp1), + (_mc, NUM_CLASSES, None, None, DataType.MULTICLASS, _onehot, _onehot), + (_mc_prob, None, None, None, DataType.MULTICLASS, _top1, _onehot), + (_mc_prob, None, None, 2, DataType.MULTICLASS, _top2, _onehot), + (_mdmc, NUM_CLASSES, None, None, DataType.MDMC, _onehot, _onehot), + (_mdmc_prob, None, None, None, DataType.MDMC, _top1_rshp2, _onehot), + (_mdmc_prob, None, None, 2, DataType.MDMC, _top2_rshp2, _onehot), + (_mdmc_prob_many_dims, None, None, None, DataType.MDMC, _top1_rshp2, _onehot_rshp1), + (_mdmc_prob_many_dims, None, None, 2, DataType.MDMC, _top2_rshp2, _onehot_rshp1), ########################### # Test some special cases # Make sure that half precision works, i.e. is converted to full precision (_ml_prob_half, None, None, None, "multi-label", lambda x: _ml_preds_tr(x.float()), _rshp1), # Binary as multiclass - (_bin, None, None, None, "multi-class", _onehot2, _onehot2), + (_bin, None, None, None, DataType.MULTICLASS, _onehot2, _onehot2), # Binary probs as multiclass - (_bin_prob, None, True, None, "binary", _probs_to_mc_preds_tr, _onehot2), + (_bin_prob, None, True, None, DataType.BINARY, _probs_to_mc_preds_tr, _onehot2), # Multilabel as multiclass - (_ml, None, True, None, "multi-dim multi-class", _onehot2, _onehot2), + (_ml, None, True, None, DataType.MDMC, _onehot2, _onehot2), # Multilabel probs as multiclass - (_ml_prob, None, True, None, "multi-label", _probs_to_mc_preds_tr, _onehot2), + (_ml_prob, None, True, None, DataType.MULTILABEL, _probs_to_mc_preds_tr, _onehot2), # Multidim multilabel as multiclass - (_mlmd, None, True, None, "multi-dim multi-class", _onehot2_rshp1, _onehot2_rshp1), + (_mlmd, None, True, None, DataType.MDMC, _onehot2_rshp1, _onehot2_rshp1), # Multidim multilabel probs as multiclass - (_mlmd_prob, None, True, None, "multi-label", _mlmd_prob_to_mc_preds_tr, _onehot2_rshp1), + (_mlmd_prob, None, True, None, DataType.MULTILABEL, _mlmd_prob_to_mc_preds_tr, _onehot2_rshp1), # Multiclass prob with 2 classes as binary - (_mc_prob_2cls, None, False, None, "multi-class", lambda x: _top1(x)[:, [1]], _usq), + (_mc_prob_2cls, None, False, None, DataType.MULTICLASS, lambda x: _top1(x)[:, [1]], _usq), # Multi-dim multi-class with 2 classes as multi-label - (_mdmc_prob_2cls, None, False, None, "multi-dim multi-class", lambda x: _top1(x)[:, 1], _idn), + (_mdmc_prob_2cls, None, False, None, DataType.MDMC, lambda x: _top1(x)[:, 1], _idn), ], ) def test_usual_cases(inputs, num_classes, is_multiclass, top_k, exp_mode, post_preds, post_target):