Update metrics to use Enum
- change data type string to DataType enum
- add AverageMethod enum
- add enum equality test
- update test_inputs.py to use the defined enum type
yuntai committed Jan 28, 2021
1 parent 817a41d commit 211dfea
Showing 3 changed files with 107 additions and 62 deletions.
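
The core of the change is swapping free-form strings for string-backed enum members that still compare equal to plain strings, so existing string-based call sites keep working. Below is a minimal self-contained sketch of that pattern; `StrEnum` is a hypothetical stand-in for `pytorch_lightning.utilities.LightningEnum`, whose implementation is not part of this diff:

```python
# Minimal sketch of the pattern this commit adopts. `StrEnum` is a
# hypothetical stand-in for pytorch_lightning.utilities.LightningEnum
# (whose implementation is not shown in this diff): a str-backed Enum
# whose members compare equal, case-insensitively, to plain strings.
from enum import Enum


class StrEnum(str, Enum):

    def __eq__(self, other):
        # Compare by value, case-insensitively, against strings and enums.
        other = other.value if isinstance(other, Enum) else str(other)
        return self.value.lower() == other.lower()

    def __hash__(self):
        # Defining __eq__ would otherwise make members unhashable.
        return hash(self.value.lower())


class DataType(StrEnum):
    BINARY = "binary"
    MULTILABEL = "multi-label"


assert DataType.BINARY == "binary"   # enum member equals its string value
assert DataType.BINARY == "BINARY"   # comparison ignores case
assert DataType.MULTILABEL != "binary"
```

Because the members are themselves strings, they can replace the old literals in comparisons without touching downstream callers.
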
72 changes: 48 additions & 24 deletions pytorch_lightning/metrics/classification/helpers.py
@@ -17,6 +17,30 @@
import torch

from pytorch_lightning.metrics.utils import select_topk, to_onehot
from pytorch_lightning.utilities import LightningEnum


class DataType(LightningEnum):
"""
Enum to represent data type
"""

BINARY = "binary"
MULTILABEL = "multi-label"
MULTICLASS = "multi-class"
MDMC = "multi-dim multi-class"


class AverageMethod(LightningEnum):
"""
Enum to represent average method
"""

MICRO = "micro"
MACRO = "macro"
WEIGHTED = "weighted"
NONE = "none"
SAMPLES = "samples"


def _basic_input_validation(preds: torch.Tensor, target: torch.Tensor, threshold: float, is_multiclass: bool):
@@ -50,15 +74,15 @@ def _basic_input_validation(preds: torch.Tensor, target: torch.Tensor, threshold
raise ValueError("If you set `is_multiclass=False` and `preds` are integers, then `preds` should not exceed 1.")


def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor) -> Tuple[str, int]:
def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor) -> Tuple[DataType, int]:
"""
This checks that the shape and type of inputs are consistent with
each other and fall into one of the allowed input types (see the
docstring of ``_input_format_classification``). It does
not check for consistency of the number of classes; other functions take
care of that.
It returns the name of the case in which the inputs fall, and the implied
It returns the ``DataType`` case the inputs fall into, and the implied
number of classes (from the ``C`` dim for multi-class data, or extra dim(s) for
multi-label data).
"""
@@ -78,13 +102,13 @@ def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor)

# Get the case
if preds.ndim == 1 and preds_float:
case = "binary"
case = DataType.BINARY
elif preds.ndim == 1 and not preds_float:
case = "multi-class"
case = DataType.MULTICLASS
elif preds.ndim > 1 and preds_float:
case = "multi-label"
case = DataType.MULTILABEL
else:
case = "multi-dim multi-class"
case = DataType.MDMC

implied_classes = preds[0].numel()

@@ -100,9 +124,9 @@ def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor)
implied_classes = preds.shape[1]

if preds.ndim == 2:
case = "multi-class"
case = DataType.MULTICLASS
else:
case = "multi-dim multi-class"
case = DataType.MDMC
else:
raise ValueError(
"Either `preds` and `target` both should have the (same) shape (N, ...), or `target` should be (N, ...)"
@@ -182,15 +206,15 @@ def _check_num_classes_ml(num_classes: int, is_multiclass: bool, implied_classes


def _check_top_k(top_k: int, case: str, implied_classes: int, is_multiclass: Optional[bool], preds_float: bool):
if case == "binary":
if case == DataType.BINARY:
raise ValueError("You can not use `top_k` parameter with binary data.")
if not isinstance(top_k, int) or top_k <= 0:
raise ValueError("The `top_k` has to be an integer larger than 0.")
if not preds_float:
raise ValueError("You have set `top_k`, but you do not have probability predictions.")
if is_multiclass is False:
raise ValueError("If you set `is_multiclass=False`, you can not set `top_k`.")
if case == "multi-label" and is_multiclass:
if case == DataType.MULTILABEL and is_multiclass:
raise ValueError(
"If you want to transform multi-label data to 2 class multi-dimensional"
"multi-class data using `is_multiclass=True`, you can not use `top_k`."
@@ -206,7 +230,7 @@ def _check_classification_inputs(
num_classes: Optional[int],
is_multiclass: bool,
top_k: Optional[int],
) -> str:
) -> DataType:
"""Performs error checking on inputs for classification.
This ensures that preds and target take one of the shape/type combinations that are
@@ -255,8 +279,8 @@
Return:
case: The case the inputs fall in, one of 'binary', 'multi-class', 'multi-label' or
'multi-dim multi-class'
case: The case the inputs fall in, one of DataType.BINARY, DataType.MULTICLASS, DataType.MULTILABEL or
DataType.MDMC
"""

# Basic validation (that does not need case/type information)
@@ -266,7 +290,7 @@
case, implied_classes = _check_shape_and_type_consistency(preds, target)

# For (multi-dim) multi-class case with prob preds, check that preds sum up to 1
if "multi-class" in case and preds.is_floating_point():
if case in (DataType.MULTICLASS, DataType.MDMC) and preds.is_floating_point():
if not torch.isclose(preds.sum(dim=1), torch.ones_like(preds.sum(dim=1))).all():
raise ValueError("Probabilities in `preds` must sum up to 1 accross the `C` dimension.")

@@ -284,11 +308,11 @@

# Check that num_classes is consistent
if num_classes:
if case == "binary":
if case == DataType.BINARY:
_check_num_classes_binary(num_classes, is_multiclass)
elif "multi-class" in case:
elif case in (DataType.MULTICLASS, DataType.MDMC):
_check_num_classes_mc(preds, target, num_classes, is_multiclass, implied_classes)
elif case == "multi-label":
elif case == DataType.MULTILABEL:
_check_num_classes_ml(num_classes, is_multiclass, implied_classes)

# Check that top_k is consistent
Expand All @@ -305,7 +329,7 @@ def _input_format_classification(
top_k: Optional[int] = None,
num_classes: Optional[int] = None,
is_multiclass: Optional[bool] = None,
) -> Tuple[torch.Tensor, torch.Tensor, str]:
) -> Tuple[torch.Tensor, torch.Tensor, DataType]:
"""Convert preds and target tensors into common format.
Preds and targets are supposed to fall into one of these categories (and are
@@ -383,8 +407,8 @@ def _input_format_classification(
Returns:
preds: binary tensor of shape ``(N, C)`` or ``(N, C, X)``
target: binary tensor of shape ``(N, C)`` or ``(N, C, X)``
case: The case the inputs fall in, one of ``'binary'``, ``'multi-class'``, ``'multi-label'`` or
``'multi-dim multi-class'``
case: The case the inputs fall in, one of DataType.BINARY, DataType.MULTICLASS, DataType.MULTILABEL or
DataType.MDMC
"""
# Remove excess dimensions
if preds.shape[0] == 1:
@@ -406,14 +430,14 @@
top_k=top_k,
)

if case in ["binary", "multi-label"] and not top_k:
if case in (DataType.BINARY, DataType.MULTILABEL) and not top_k:
preds = (preds >= threshold).int()
num_classes = num_classes if not is_multiclass else 2

if case == "multi-label" and top_k:
if case == DataType.MULTILABEL and top_k:
preds = select_topk(preds, top_k)

if "multi-class" in case or is_multiclass:
if case in (DataType.MULTICLASS, DataType.MDMC) or is_multiclass:
if preds.is_floating_point():
num_classes = preds.shape[1]
preds = select_topk(preds, top_k or 1)
Expand All @@ -426,7 +450,7 @@ def _input_format_classification(
if is_multiclass is False:
preds, target = preds[:, 1, ...], target[:, 1, ...]

if ("multi-class" in case and is_multiclass is not False) or is_multiclass:
if (case in (DataType.MULTICLASS, DataType.MDMC) and is_multiclass is not False) or is_multiclass:
target = target.reshape(target.shape[0], target.shape[1], -1)
preds = preds.reshape(preds.shape[0], preds.shape[1], -1)
else:
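
With the enum in place, `_input_format_classification` returns a `DataType` member as its third element. A usage sketch, assuming this version of the module is importable and that `threshold` keeps its usual default of 0.5; because the enum compares equal to its string value, the legacy comparison also still holds:

```python
# Usage sketch (assumes this pytorch_lightning version is installed).
import torch

from pytorch_lightning.metrics.classification.helpers import (
    DataType,
    _input_format_classification,
)

preds = torch.tensor([0.2, 0.7, 0.9])  # 1-d float preds -> the binary case
target = torch.tensor([0, 1, 1])

preds, target, case = _input_format_classification(preds, target, threshold=0.5)

assert case == DataType.BINARY  # new enum comparison
assert case == "binary"         # legacy string comparison still matches
```
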
20 changes: 20 additions & 0 deletions tests/metrics/classification/test_helpers.py
@@ -0,0 +1,20 @@
import pytest
from pytorch_lightning.metrics.classification.helpers import DataType


@pytest.mark.parametrize(
"s, data_type",
[
("binary", DataType.BINARY),
("multilabel", DataType.MULTILABEL),
("multiclass", DataType.MULTICLASS),
("mdmc", DataType.MDMC),
],
)
def test_data_type_equality(s, data_type):
assert s == data_type
assert data_type == s
assert s.upper() == data_type
assert data_type == s.upper()
assert DataType.from_str(s) == data_type
assert DataType.from_str(s.upper()) == data_type
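
A note on the `from_str` assertions above: `"mdmc"` resolves to `DataType.MDMC` even though that member's value is `"multi-dim multi-class"`, so the lookup evidently matches member names, not values. A hypothetical helper mirroring that inferred behavior:

```python
# Hypothetical mirror of the name-based lookup that LightningEnum.from_str
# appears to perform (inferred from the test, not from its source).
from enum import Enum
from typing import Optional, Type


def from_str(cls: Type[Enum], value: str) -> Optional[Enum]:
    for name, member in cls.__members__.items():
        if name.lower() == value.lower():  # match on the member *name*
            return member
    return None

# e.g. from_str(DataType, "MDMC") would return DataType.MDMC
```
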
77 changes: 39 additions & 38 deletions tests/metrics/classification/test_inputs.py
@@ -2,20 +2,22 @@
import torch
from torch import rand, randint

from pytorch_lightning.metrics.classification.helpers import _input_format_classification
from pytorch_lightning.metrics.utils import select_topk, to_onehot
from tests.metrics.classification.inputs import _binary_inputs as _bin
from tests.metrics.classification.inputs import _binary_prob_inputs as _bin_prob
from tests.metrics.classification.inputs import _multiclass_inputs as _mc
from tests.metrics.classification.inputs import _multiclass_prob_inputs as _mc_prob
from tests.metrics.classification.inputs import _multidim_multiclass_inputs as _mdmc
from tests.metrics.classification.inputs import _multidim_multiclass_prob_inputs as _mdmc_prob
from tests.metrics.classification.inputs import _multilabel_inputs as _ml
from tests.metrics.classification.inputs import _multilabel_multidim_inputs as _mlmd
from tests.metrics.classification.inputs import _multilabel_multidim_prob_inputs as _mlmd_prob
from tests.metrics.classification.inputs import _multilabel_prob_inputs as _ml_prob
from tests.metrics.classification.inputs import Input
from tests.metrics.utils import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, NUM_CLASSES, THRESHOLD
from pytorch_lightning.metrics.utils import to_onehot, select_topk
from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
from tests.metrics.classification.inputs import (
Input,
_binary_inputs as _bin,
_binary_prob_inputs as _bin_prob,
_multiclass_inputs as _mc,
_multiclass_prob_inputs as _mc_prob,
_multidim_multiclass_inputs as _mdmc,
_multidim_multiclass_prob_inputs as _mdmc_prob,
_multilabel_inputs as _ml,
_multilabel_prob_inputs as _ml_prob,
_multilabel_multidim_inputs as _mlmd,
_multilabel_multidim_prob_inputs as _mlmd_prob,
)
from tests.metrics.utils import NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM, THRESHOLD

torch.manual_seed(42)

@@ -116,42 +118,41 @@ def _mlmd_prob_to_mc_preds_tr(x):
[
#############################
# Test usual expected cases
(_bin, None, False, None, "multi-class", _usq, _usq),
(_bin, 1, False, None, "multi-class", _usq, _usq),
(_bin_prob, None, None, None, "binary", lambda x: _usq(_thrs(x)), _usq),
(_ml_prob, None, None, None, "multi-label", _thrs, _idn),
(_ml, None, False, None, "multi-dim multi-class", _idn, _idn),
(_ml_prob, None, None, None, "multi-label", _ml_preds_tr, _rshp1),
(_ml_prob, None, None, 2, "multi-label", _top2, _rshp1),
(_mlmd, None, False, None, "multi-dim multi-class", _rshp1, _rshp1),
(_mc, NUM_CLASSES, None, None, "multi-class", _onehot, _onehot),
(_mc_prob, None, None, None, "multi-class", _top1, _onehot),
(_mc_prob, None, None, 2, "multi-class", _top2, _onehot),
(_mdmc, NUM_CLASSES, None, None, "multi-dim multi-class", _onehot, _onehot),
(_mdmc_prob, None, None, None, "multi-dim multi-class", _top1_rshp2, _onehot),
(_mdmc_prob, None, None, 2, "multi-dim multi-class", _top2_rshp2, _onehot),
(_mdmc_prob_many_dims, None, None, None, "multi-dim multi-class", _top1_rshp2, _onehot_rshp1),
(_mdmc_prob_many_dims, None, None, 2, "multi-dim multi-class", _top2_rshp2, _onehot_rshp1),
(_bin, None, False, None, DataType.MULTICLASS, _usq, _usq),
(_bin, 1, False, None, DataType.MULTICLASS, _usq, _usq),
(_bin_prob, None, None, None, DataType.BINARY, lambda x: _usq(_thrs(x)), _usq),
(_ml_prob, None, None, None, DataType.MULTILABEL, _thrs, _idn),
(_ml, None, False, None, DataType.MDMC, _idn, _idn),
(_ml_prob, None, None, None, DataType.MULTILABEL, _ml_preds_tr, _rshp1),
(_mlmd, None, False, None, DataType.MDMC, _rshp1, _rshp1),
(_mc, NUM_CLASSES, None, None, DataType.MULTICLASS, _onehot, _onehot),
(_mc_prob, None, None, None, DataType.MULTICLASS, _top1, _onehot),
(_mc_prob, None, None, 2, DataType.MULTICLASS, _top2, _onehot),
(_mdmc, NUM_CLASSES, None, None, DataType.MDMC, _onehot, _onehot),
(_mdmc_prob, None, None, None, DataType.MDMC, _top1_rshp2, _onehot),
(_mdmc_prob, None, None, 2, DataType.MDMC, _top2_rshp2, _onehot),
(_mdmc_prob_many_dims, None, None, None, DataType.MDMC, _top1_rshp2, _onehot_rshp1),
(_mdmc_prob_many_dims, None, None, 2, DataType.MDMC, _top2_rshp2, _onehot_rshp1),
###########################
# Test some special cases
# Make sure that half precision works, i.e. is converted to full precision
(_ml_prob_half, None, None, None, "multi-label", lambda x: _ml_preds_tr(x.float()), _rshp1),
# Binary as multiclass
(_bin, None, None, None, "multi-class", _onehot2, _onehot2),
(_bin, None, None, None, DataType.MULTICLASS, _onehot2, _onehot2),
# Binary probs as multiclass
(_bin_prob, None, True, None, "binary", _probs_to_mc_preds_tr, _onehot2),
(_bin_prob, None, True, None, DataType.BINARY, _probs_to_mc_preds_tr, _onehot2),
# Multilabel as multiclass
(_ml, None, True, None, "multi-dim multi-class", _onehot2, _onehot2),
(_ml, None, True, None, DataType.MDMC, _onehot2, _onehot2),
# Multilabel probs as multiclass
(_ml_prob, None, True, None, "multi-label", _probs_to_mc_preds_tr, _onehot2),
(_ml_prob, None, True, None, DataType.MULTILABEL, _probs_to_mc_preds_tr, _onehot2),
# Multidim multilabel as multiclass
(_mlmd, None, True, None, "multi-dim multi-class", _onehot2_rshp1, _onehot2_rshp1),
(_mlmd, None, True, None, DataType.MDMC, _onehot2_rshp1, _onehot2_rshp1),
# Multidim multilabel probs as multiclass
(_mlmd_prob, None, True, None, "multi-label", _mlmd_prob_to_mc_preds_tr, _onehot2_rshp1),
(_mlmd_prob, None, True, None, DataType.MULTILABEL, _mlmd_prob_to_mc_preds_tr, _onehot2_rshp1),
# Multiclass prob with 2 classes as binary
(_mc_prob_2cls, None, False, None, "multi-class", lambda x: _top1(x)[:, [1]], _usq),
(_mc_prob_2cls, None, False, None, DataType.MULTICLASS, lambda x: _top1(x)[:, [1]], _usq),
# Multi-dim multi-class with 2 classes as multi-label
(_mdmc_prob_2cls, None, False, None, "multi-dim multi-class", lambda x: _top1(x)[:, 1], _idn),
(_mdmc_prob_2cls, None, False, None, DataType.MDMC, lambda x: _top1(x)[:, 1], _idn),
],
)
def test_usual_cases(inputs, num_classes, is_multiclass, top_k, exp_mode, post_preds, post_target):
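
Nothing in this diff consumes the new `AverageMethod` yet. A hedged sketch of the kind of dispatch it presumably enables; `reduce_scores` is hypothetical and not part of this commit or of pytorch_lightning:

```python
# Hypothetical consumer of AverageMethod (not part of this commit). Thanks
# to LightningEnum's string equality, legacy plain-string spellings such as
# "macro" keep selecting the same branch as the enum members.
import torch

from pytorch_lightning.metrics.classification.helpers import AverageMethod


def reduce_scores(class_scores: torch.Tensor, support: torch.Tensor, average: str) -> torch.Tensor:
    # MICRO is omitted: it needs raw TP/FP/FN counts, not per-class scores.
    if average == AverageMethod.MACRO:
        return class_scores.mean()  # unweighted mean over classes
    if average == AverageMethod.WEIGHTED:
        return (class_scores * support / support.sum()).sum()  # support-weighted
    if average == AverageMethod.NONE:
        return class_scores  # per-class scores, unreduced
    raise ValueError(f"Unsupported average: {average}")


scores = torch.tensor([0.50, 0.80])
support = torch.tensor([10.0, 30.0])
assert torch.isclose(reduce_scores(scores, support, "macro"), torch.tensor(0.65))
```
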