Skip to content

Commit

Permalink
Update metrics to use Enum (Lightning-AI#5657)
Browse files Browse the repository at this point in the history
- Add DataType, AverageMethod and MDMCAverageMethod
  • Loading branch information
yuntai committed Jan 31, 2021
1 parent 5d239cc commit 727e289
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 60 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- `DataType`, `AverageMethod` and `MDMCAverageMethod` enum ([#5657](https://github.com/PyTorchLightning/pytorch-lightning/pull/5689)


- Added support for summarized model total params size in megabytes ([#5590](https://github.com/PyTorchLightning/pytorch-lightning/pull/5590))


Expand Down
71 changes: 52 additions & 19 deletions pytorch_lightning/metrics/classification/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,39 @@
import torch

from pytorch_lightning.metrics.utils import select_topk, to_onehot
from pytorch_lightning.utilities import LightningEnum


class DataType(LightningEnum):
"""
Enum to represent data type
"""

BINARY = "binary"
MULTILABEL = "multi-label"
MULTICLASS = "multi-class"
MULTIDIM_MULTICLASS = "multi-dim multi-class"


class AverageMethod(LightningEnum):
"""
Enum to represent average method
"""

MICRO = "micro"
MACRO = "macro"
WEIGHTED = "weighted"
NONE = "none"
SAMPLES = "samples"


class MDMCAverageMethod(LightningEnum):
"""
Enum to represent multi-dim multi-class average method
"""

GLOBAL = "global"
SAMPLEWISE = "samplewise"


def _basic_input_validation(preds: torch.Tensor, target: torch.Tensor, threshold: float, is_multiclass: bool):
Expand Down Expand Up @@ -78,13 +111,13 @@ def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor)

# Get the case
if preds.ndim == 1 and preds_float:
case = "binary"
case = DataType.BINARY
elif preds.ndim == 1 and not preds_float:
case = "multi-class"
case = DataType.MULTICLASS
elif preds.ndim > 1 and preds_float:
case = "multi-label"
case = DataType.MULTILABEL
else:
case = "multi-dim multi-class"
case = DataType.MULTIDIM_MULTICLASS

implied_classes = preds[0].numel()

Expand All @@ -100,9 +133,9 @@ def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor)
implied_classes = preds.shape[1]

if preds.ndim == 2:
case = "multi-class"
case = DataType.MULTICLASS
else:
case = "multi-dim multi-class"
case = DataType.MULTIDIM_MULTICLASS
else:
raise ValueError(
"Either `preds` and `target` both should have the (same) shape (N, ...), or `target` should be (N, ...)"
Expand Down Expand Up @@ -182,15 +215,15 @@ def _check_num_classes_ml(num_classes: int, is_multiclass: bool, implied_classes


def _check_top_k(top_k: int, case: str, implied_classes: int, is_multiclass: Optional[bool], preds_float: bool):
if case == "binary":
if case == DataType.BINARY:
raise ValueError("You can not use `top_k` parameter with binary data.")
if not isinstance(top_k, int) or top_k <= 0:
raise ValueError("The `top_k` has to be an integer larger than 0.")
if not preds_float:
raise ValueError("You have set `top_k`, but you do not have probability predictions.")
if is_multiclass is False:
raise ValueError("If you set `is_multiclass=False`, you can not set `top_k`.")
if case == "multi-label" and is_multiclass:
if case == DataType.MULTILABEL and is_multiclass:
raise ValueError(
"If you want to transform multi-label data to 2 class multi-dimensional"
"multi-class data using `is_multiclass=True`, you can not use `top_k`."
Expand Down Expand Up @@ -266,7 +299,7 @@ def _check_classification_inputs(
case, implied_classes = _check_shape_and_type_consistency(preds, target)

# For (multi-dim) multi-class case with prob preds, check that preds sum up to 1
if "multi-class" in case and preds.is_floating_point():
if case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) and preds.is_floating_point():
if not torch.isclose(preds.sum(dim=1), torch.ones_like(preds.sum(dim=1))).all():
raise ValueError("Probabilities in `preds` must sum up to 1 accross the `C` dimension.")

Expand All @@ -284,11 +317,11 @@ def _check_classification_inputs(

# Check that num_classes is consistent
if num_classes:
if case == "binary":
if case == DataType.BINARY:
_check_num_classes_binary(num_classes, is_multiclass)
elif "multi-class" in case:
elif case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS):
_check_num_classes_mc(preds, target, num_classes, is_multiclass, implied_classes)
elif case == "multi-label":
elif case.MULTILABEL:
_check_num_classes_ml(num_classes, is_multiclass, implied_classes)

# Check that top_k is consistent
Expand Down Expand Up @@ -406,14 +439,14 @@ def _input_format_classification(
top_k=top_k,
)

if case in ["binary", "multi-label"] and not top_k:
if case in (DataType.BINARY, DataType.MULTILABEL) and not top_k:
preds = (preds >= threshold).int()
num_classes = num_classes if not is_multiclass else 2

if case == "multi-label" and top_k:
if case == DataType.MULTILABEL and top_k:
preds = select_topk(preds, top_k)

if "multi-class" in case or is_multiclass:
if case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) or is_multiclass:
if preds.is_floating_point():
num_classes = preds.shape[1]
preds = select_topk(preds, top_k or 1)
Expand All @@ -426,7 +459,7 @@ def _input_format_classification(
if is_multiclass is False:
preds, target = preds[:, 1, ...], target[:, 1, ...]

if ("multi-class" in case and is_multiclass is not False) or is_multiclass:
if (case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) and is_multiclass is not False) or is_multiclass:
target = target.reshape(target.shape[0], target.shape[1], -1)
preds = preds.reshape(preds.shape[0], preds.shape[1], -1)
else:
Expand Down Expand Up @@ -486,19 +519,19 @@ def _reduce_stat_scores(
denominator = torch.where(zero_div_mask | ignore_mask, torch.tensor(1.0, device=denominator.device), denominator)
weights = torch.where(ignore_mask, torch.tensor(0.0, device=weights.device), weights)

if average not in ["micro", "none", None]:
if average not in (AverageMethod.MICRO, AverageMethod.NONE, None):
weights = weights / weights.sum(dim=-1, keepdim=True)

scores = weights * (numerator / denominator)

# This is in case where sum(weights) = 0, which happens if we ignore the only present class with average='weighted'
scores = torch.where(torch.isnan(scores), torch.tensor(float(zero_division), device=scores.device), scores)

if mdmc_average == "samplewise":
if mdmc_average == MDMCAverageMethod.SAMPLEWISE:
scores = scores.mean(dim=0)
ignore_mask = ignore_mask.sum(dim=0).bool()

if average in ["none", None]:
if average in (AverageMethod.NONE, None):
scores = torch.where(ignore_mask, torch.tensor(np.nan, device=scores.device), scores)
else:
scores = scores.sum()
Expand Down
12 changes: 6 additions & 6 deletions pytorch_lightning/metrics/functional/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import torch

from pytorch_lightning.metrics.classification.helpers import _input_format_classification
from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType


def _accuracy_update(
Expand All @@ -24,19 +24,19 @@ def _accuracy_update(

preds, target, mode = _input_format_classification(preds, target, threshold=threshold, top_k=top_k)

if mode == "multi-label" and top_k:
if mode == DataType.MULTILABEL and top_k:
raise ValueError("You can not use the `top_k` parameter to calculate accuracy for multi-label inputs.")

if mode == "binary" or (mode == "multi-label" and subset_accuracy):
if mode == DataType.BINARY or (mode == DataType.MULTILABEL and subset_accuracy):
correct = (preds == target).all(dim=1).sum()
total = torch.tensor(target.shape[0], device=target.device)
elif mode == "multi-label" and not subset_accuracy:
elif mode == DataType.MULTILABEL and not subset_accuracy:
correct = (preds == target).sum()
total = torch.tensor(target.numel(), device=target.device)
elif mode == "multi-class" or (mode == "multi-dim multi-class" and not subset_accuracy):
elif mode == DataType.MULTICLASS or (mode == DataType.MULTIDIM_MULTICLASS and not subset_accuracy):
correct = (preds * target).sum()
total = target.sum()
elif mode == "multi-dim multi-class" and subset_accuracy:
elif mode == DataType.MULTIDIM_MULTICLASS and subset_accuracy:
sample_correct = (preds * target).sum(dim=(1, 2))
correct = (sample_correct == target.shape[2]).sum()
total = torch.tensor(target.shape[0], device=target.device)
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/metrics/functional/auroc.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import torch

from pytorch_lightning.metrics.classification.helpers import _input_format_classification
from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
from pytorch_lightning.metrics.functional.auc import auc
from pytorch_lightning.metrics.functional.roc import roc
from pytorch_lightning.utilities import LightningEnum
Expand Down Expand Up @@ -96,7 +96,7 @@ def _auroc_compute(
elif average == AverageMethods.MACRO:
return torch.mean(torch.stack(auc_scores))
elif average == AverageMethods.WEIGHTED:
if mode == 'multi-label':
if mode == DataType.MULTILABEL:
support = torch.sum(target, dim=0)
else:
support = torch.bincount(target.flatten(), minlength=num_classes)
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/metrics/functional/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import torch

from pytorch_lightning.metrics.classification.helpers import _input_format_classification
from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
from pytorch_lightning.utilities import rank_zero_warn


Expand All @@ -24,7 +24,7 @@ def _confusion_matrix_update(preds: torch.Tensor,
num_classes: int,
threshold: float = 0.5) -> torch.Tensor:
preds, target, mode = _input_format_classification(preds, target, threshold)
if mode not in ('binary', 'multi-label'):
if mode not in (DataType.BINARY, DataType.MULTILABEL):
preds = preds.argmax(dim=1)
target = target.argmax(dim=1)
unique_mapping = (target.view(-1) * num_classes + preds.view(-1)).to(torch.long)
Expand Down
8 changes: 4 additions & 4 deletions tests/metrics/classification/test_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from sklearn.metrics import accuracy_score as sk_accuracy

from pytorch_lightning.metrics import Accuracy
from pytorch_lightning.metrics.classification.helpers import _input_format_classification
from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
from pytorch_lightning.metrics.functional import accuracy
from tests.metrics.classification.inputs import (
_binary_inputs,
Expand All @@ -29,12 +29,12 @@ def _sk_accuracy(preds, target, subset_accuracy):
sk_preds, sk_target, mode = _input_format_classification(preds, target, threshold=THRESHOLD)
sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy()

if mode == "multi-dim multi-class" and not subset_accuracy:
if mode == DataType.MULTIDIM_MULTICLASS and not subset_accuracy:
sk_preds, sk_target = np.transpose(sk_preds, (0, 2, 1)), np.transpose(sk_target, (0, 2, 1))
sk_preds, sk_target = sk_preds.reshape(-1, sk_preds.shape[2]), sk_target.reshape(-1, sk_target.shape[2])
elif mode == mode == "multi-dim multi-class" and subset_accuracy:
elif mode == DataType.MULTIDIM_MULTICLASS and subset_accuracy:
return np.all(sk_preds == sk_target, axis=(1, 2)).mean()
elif mode == "multi-label" and not subset_accuracy:
elif mode == DataType.MULTILABEL and not subset_accuracy:
sk_preds, sk_target = sk_preds.reshape(-1), sk_target.reshape(-1)

return sk_accuracy(y_true=sk_target, y_pred=sk_preds)
Expand Down
58 changes: 31 additions & 27 deletions tests/metrics/classification/test_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch
from torch import rand, randint

from pytorch_lightning.metrics.classification.helpers import _input_format_classification
from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
from pytorch_lightning.metrics.utils import select_topk, to_onehot
from tests.metrics.classification.inputs import _binary_inputs as _bin
from tests.metrics.classification.inputs import _binary_prob_inputs as _bin_prob
Expand Down Expand Up @@ -155,32 +155,36 @@ def _mlmd_prob_to_mc_preds_tr(x):
],
)
def test_usual_cases(inputs, num_classes, is_multiclass, top_k, exp_mode, post_preds, post_target):
preds_out, target_out, mode = _input_format_classification(
preds=inputs.preds[0],
target=inputs.target[0],
threshold=THRESHOLD,
num_classes=num_classes,
is_multiclass=is_multiclass,
top_k=top_k,
)

assert mode == exp_mode
assert torch.equal(preds_out, post_preds(inputs.preds[0]).int())
assert torch.equal(target_out, post_target(inputs.target[0]).int())

# Test that things work when batch_size = 1
preds_out, target_out, mode = _input_format_classification(
preds=inputs.preds[0][[0], ...],
target=inputs.target[0][[0], ...],
threshold=THRESHOLD,
num_classes=num_classes,
is_multiclass=is_multiclass,
top_k=top_k,
)

assert mode == exp_mode
assert torch.equal(preds_out, post_preds(inputs.preds[0][[0], ...]).int())
assert torch.equal(target_out, post_target(inputs.target[0][[0], ...]).int())
def __get_data_type_enum(str_exp_mode):
return next(DataType[n] for n in dir(DataType) if DataType[n] == str_exp_mode)

for exp_mode in (exp_mode, __get_data_type_enum(exp_mode)):
preds_out, target_out, mode = _input_format_classification(
preds=inputs.preds[0],
target=inputs.target[0],
threshold=THRESHOLD,
num_classes=num_classes,
is_multiclass=is_multiclass,
top_k=top_k,
)

assert mode == exp_mode
assert torch.equal(preds_out, post_preds(inputs.preds[0]).int())
assert torch.equal(target_out, post_target(inputs.target[0]).int())

# Test that things work when batch_size = 1
preds_out, target_out, mode = _input_format_classification(
preds=inputs.preds[0][[0], ...],
target=inputs.target[0][[0], ...],
threshold=THRESHOLD,
num_classes=num_classes,
is_multiclass=is_multiclass,
top_k=top_k,
)

assert mode == exp_mode
assert torch.equal(preds_out, post_preds(inputs.preds[0][[0], ...]).int())
assert torch.equal(target_out, post_target(inputs.target[0][[0], ...]).int())


# Test that threshold is correctly applied
Expand Down

0 comments on commit 727e289

Please sign in to comment.