Update metrics to use Enum (#5657) #5689

Merged
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `DataType`, `AverageMethod` and `MDMCAverageMethod` enum in metrics ([#5657](https://github.com/PyTorchLightning/pytorch-lightning/pull/5689))


- Added support for summarized model total params size in megabytes ([#5590](https://github.com/PyTorchLightning/pytorch-lightning/pull/5590))


71 changes: 52 additions & 19 deletions pytorch_lightning/metrics/classification/helpers.py
@@ -17,6 +17,39 @@
import torch

from pytorch_lightning.metrics.utils import select_topk, to_onehot
from pytorch_lightning.utilities import LightningEnum


class DataType(LightningEnum):
"""
Enum to represent data type
"""

BINARY = "binary"
MULTILABEL = "multi-label"
MULTICLASS = "multi-class"
MULTIDIM_MULTICLASS = "multi-dim multi-class"


class AverageMethod(LightningEnum):
"""
Enum to represent average method
"""

MICRO = "micro"
MACRO = "macro"
WEIGHTED = "weighted"
NONE = "none"
SAMPLES = "samples"


class MDMCAverageMethod(LightningEnum):
"""
Enum to represent multi-dim multi-class average method
"""

GLOBAL = "global"
SAMPLEWISE = "samplewise"


def _basic_input_validation(preds: torch.Tensor, target: torch.Tensor, threshold: float, is_multiclass: bool):
@@ -78,13 +111,13 @@ def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor)

# Get the case
if preds.ndim == 1 and preds_float:
case = "binary"
case = DataType.BINARY
elif preds.ndim == 1 and not preds_float:
case = "multi-class"
case = DataType.MULTICLASS
elif preds.ndim > 1 and preds_float:
case = "multi-label"
case = DataType.MULTILABEL
else:
case = "multi-dim multi-class"
case = DataType.MULTIDIM_MULTICLASS

implied_classes = preds[0].numel()

@@ -100,9 +133,9 @@ def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor)
implied_classes = preds.shape[1]

if preds.ndim == 2:
case = "multi-class"
case = DataType.MULTICLASS
else:
case = "multi-dim multi-class"
case = DataType.MULTIDIM_MULTICLASS
else:
raise ValueError(
"Either `preds` and `target` both should have the (same) shape (N, ...), or `target` should be (N, ...)"
@@ -182,15 +215,15 @@ def _check_num_classes_ml(num_classes: int, is_multiclass: bool, implied_classes


def _check_top_k(top_k: int, case: str, implied_classes: int, is_multiclass: Optional[bool], preds_float: bool):
if case == "binary":
if case == DataType.BINARY:
raise ValueError("You can not use `top_k` parameter with binary data.")
if not isinstance(top_k, int) or top_k <= 0:
raise ValueError("The `top_k` has to be an integer larger than 0.")
if not preds_float:
raise ValueError("You have set `top_k`, but you do not have probability predictions.")
if is_multiclass is False:
raise ValueError("If you set `is_multiclass=False`, you can not set `top_k`.")
if case == "multi-label" and is_multiclass:
if case == DataType.MULTILABEL and is_multiclass:
raise ValueError(
"If you want to transform multi-label data to 2 class multi-dimensional"
"multi-class data using `is_multiclass=True`, you can not use `top_k`."
@@ -266,7 +299,7 @@ def _check_classification_inputs(
case, implied_classes = _check_shape_and_type_consistency(preds, target)

# For (multi-dim) multi-class case with prob preds, check that preds sum up to 1
if "multi-class" in case and preds.is_floating_point():
if case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) and preds.is_floating_point():
if not torch.isclose(preds.sum(dim=1), torch.ones_like(preds.sum(dim=1))).all():
raise ValueError("Probabilities in `preds` must sum up to 1 accross the `C` dimension.")

@@ -284,11 +317,11 @@ def _check_classification_inputs(

# Check that num_classes is consistent
if num_classes:
if case == "binary":
if case == DataType.BINARY:
_check_num_classes_binary(num_classes, is_multiclass)
elif "multi-class" in case:
elif case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS):
_check_num_classes_mc(preds, target, num_classes, is_multiclass, implied_classes)
elif case == "multi-label":
elif case == DataType.MULTILABEL:
_check_num_classes_ml(num_classes, is_multiclass, implied_classes)

# Check that top_k is consistent
@@ -406,14 +439,14 @@ def _input_format_classification(
top_k=top_k,
)

if case in ["binary", "multi-label"] and not top_k:
if case in (DataType.BINARY, DataType.MULTILABEL) and not top_k:
preds = (preds >= threshold).int()
num_classes = num_classes if not is_multiclass else 2

if case == "multi-label" and top_k:
if case == DataType.MULTILABEL and top_k:
preds = select_topk(preds, top_k)

if "multi-class" in case or is_multiclass:
if case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) or is_multiclass:
if preds.is_floating_point():
num_classes = preds.shape[1]
preds = select_topk(preds, top_k or 1)
@@ -426,7 +459,7 @@ def _input_format_classification(
if is_multiclass is False:
preds, target = preds[:, 1, ...], target[:, 1, ...]

if ("multi-class" in case and is_multiclass is not False) or is_multiclass:
if (case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) and is_multiclass is not False) or is_multiclass:
target = target.reshape(target.shape[0], target.shape[1], -1)
preds = preds.reshape(preds.shape[0], preds.shape[1], -1)
else:
@@ -486,19 +519,19 @@ def _reduce_stat_scores(
denominator = torch.where(zero_div_mask | ignore_mask, torch.tensor(1.0, device=denominator.device), denominator)
weights = torch.where(ignore_mask, torch.tensor(0.0, device=weights.device), weights)

if average not in ["micro", "none", None]:
if average not in (AverageMethod.MICRO, AverageMethod.NONE, None):
weights = weights / weights.sum(dim=-1, keepdim=True)

scores = weights * (numerator / denominator)

# This is in case where sum(weights) = 0, which happens if we ignore the only present class with average='weighted'
scores = torch.where(torch.isnan(scores), torch.tensor(float(zero_division), device=scores.device), scores)

if mdmc_average == "samplewise":
if mdmc_average == MDMCAverageMethod.SAMPLEWISE:
scores = scores.mean(dim=0)
ignore_mask = ignore_mask.sum(dim=0).bool()

if average in ["none", None]:
if average in (AverageMethod.NONE, None):
scores = torch.where(ignore_mask, torch.tensor(np.nan, device=scores.device), scores)
else:
scores = scores.sum()
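Note: the whole refactor relies on `LightningEnum` members comparing equal to plain strings, so existing call sites that pass `average="micro"` or check `mode == "binary"` keep working unchanged. Below is a minimal, self-contained sketch of that behavior; the real definition lives in `pytorch_lightning.utilities`, and the case-insensitive `__eq__` here is an illustrative assumption, not the library's exact code.

from enum import Enum


# Stand-in for pytorch_lightning.utilities.LightningEnum, assumed to be a
# str-backed Enum whose __eq__ also accepts plain strings (case-insensitively).
class LightningEnum(str, Enum):

    def __eq__(self, other: object) -> bool:
        other = other.value if isinstance(other, Enum) else str(other)
        return self.value.lower() == other.lower()

    def __hash__(self) -> int:
        # Defining __eq__ suppresses the inherited __hash__, so restore one.
        return hash(self.value.lower())


class DataType(LightningEnum):
    BINARY = "binary"
    MULTILABEL = "multi-label"


# Old string-based call sites still compare equal to the new enum members:
assert DataType.BINARY == "binary"
assert DataType.MULTILABEL == "multi-label"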
12 changes: 6 additions & 6 deletions pytorch_lightning/metrics/functional/accuracy.py
@@ -15,7 +15,7 @@

import torch

from pytorch_lightning.metrics.classification.helpers import _input_format_classification
from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType


def _accuracy_update(
@@ -24,19 +24,19 @@ def _accuracy_update(

preds, target, mode = _input_format_classification(preds, target, threshold=threshold, top_k=top_k)

if mode == "multi-label" and top_k:
if mode == DataType.MULTILABEL and top_k:
raise ValueError("You can not use the `top_k` parameter to calculate accuracy for multi-label inputs.")

if mode == "binary" or (mode == "multi-label" and subset_accuracy):
if mode == DataType.BINARY or (mode == DataType.MULTILABEL and subset_accuracy):
correct = (preds == target).all(dim=1).sum()
total = torch.tensor(target.shape[0], device=target.device)
elif mode == "multi-label" and not subset_accuracy:
elif mode == DataType.MULTILABEL and not subset_accuracy:
correct = (preds == target).sum()
total = torch.tensor(target.numel(), device=target.device)
elif mode == "multi-class" or (mode == "multi-dim multi-class" and not subset_accuracy):
elif mode == DataType.MULTICLASS or (mode == DataType.MULTIDIM_MULTICLASS and not subset_accuracy):
correct = (preds * target).sum()
total = target.sum()
elif mode == "multi-dim multi-class" and subset_accuracy:
elif mode == DataType.MULTIDIM_MULTICLASS and subset_accuracy:
sample_correct = (preds * target).sum(dim=(1, 2))
correct = (sample_correct == target.shape[2]).sum()
total = torch.tensor(target.shape[0], device=target.device)
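Note (illustration, not part of the diff): in the `DataType.MULTILABEL and subset_accuracy` branch above, a sample only counts as correct when every one of its labels matches. With the (N, C) zero/one tensors that `_input_format_classification` returns, a quick sketch:

import torch

preds = torch.tensor([[1, 0, 1], [0, 1, 1]])
target = torch.tensor([[1, 0, 1], [0, 1, 0]])

correct = (preds == target).all(dim=1).sum()  # only sample 0 matches on all labels -> 1
total = torch.tensor(preds.shape[0])          # 2 samples
print((correct / total).item())               # 0.5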
4 changes: 2 additions & 2 deletions pytorch_lightning/metrics/functional/auroc.py
@@ -16,7 +16,7 @@

import torch

from pytorch_lightning.metrics.classification.helpers import _input_format_classification
from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
from pytorch_lightning.metrics.functional.auc import auc
from pytorch_lightning.metrics.functional.roc import roc
from pytorch_lightning.utilities import LightningEnum
@@ -96,7 +96,7 @@ def _auroc_compute(
elif average == AverageMethods.MACRO:
return torch.mean(torch.stack(auc_scores))
elif average == AverageMethods.WEIGHTED:
if mode == 'multi-label':
if mode == DataType.MULTILABEL:
support = torch.sum(target, dim=0)
else:
support = torch.bincount(target.flatten(), minlength=num_classes)
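Note: under `AverageMethods.WEIGHTED`, the per-class AUROC scores are averaged with class support as the weights; for label-encoded (non-multilabel) targets the support comes from `torch.bincount`, as in the branch above. A small sketch with made-up per-class scores:

import torch

auc_scores = torch.tensor([0.9, 0.6, 0.8])  # hypothetical per-class AUROCs
target = torch.tensor([0, 0, 1, 2, 2, 2])   # label-encoded multi-class target
support = torch.bincount(target.flatten(), minlength=3)  # tensor([2, 1, 3])
weights = support.float() / support.sum()
print((auc_scores * weights).sum().item())  # 0.9*2/6 + 0.6*1/6 + 0.8*3/6 ~ 0.8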
4 changes: 2 additions & 2 deletions pytorch_lightning/metrics/functional/confusion_matrix.py
@@ -15,7 +15,7 @@

import torch

from pytorch_lightning.metrics.classification.helpers import _input_format_classification
from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
from pytorch_lightning.utilities import rank_zero_warn


@@ -24,7 +24,7 @@ def _confusion_matrix_update(preds: torch.Tensor,
num_classes: int,
threshold: float = 0.5) -> torch.Tensor:
preds, target, mode = _input_format_classification(preds, target, threshold)
if mode not in ('binary', 'multi-label'):
if mode not in (DataType.BINARY, DataType.MULTILABEL):
preds = preds.argmax(dim=1)
target = target.argmax(dim=1)
unique_mapping = (target.view(-1) * num_classes + preds.view(-1)).to(torch.long)
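Note: the diff is cut off just after `unique_mapping`, but assuming the usual flat-index construction (each `(target, pred)` pair maps to `target * num_classes + pred`), a single `bincount` plus a reshape yields the full confusion matrix:

import torch

num_classes = 3
target = torch.tensor([0, 1, 2, 2])
preds = torch.tensor([0, 2, 2, 1])

# Every (target, pred) pair becomes a unique flat index in [0, num_classes ** 2).
unique_mapping = (target.view(-1) * num_classes + preds.view(-1)).to(torch.long)
confmat = torch.bincount(unique_mapping, minlength=num_classes ** 2)
print(confmat.reshape(num_classes, num_classes))
# tensor([[1, 0, 0],
#         [0, 0, 1],
#         [0, 1, 1]])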
8 changes: 4 additions & 4 deletions tests/metrics/classification/test_accuracy.py
@@ -6,7 +6,7 @@
from sklearn.metrics import accuracy_score as sk_accuracy

from pytorch_lightning.metrics import Accuracy
from pytorch_lightning.metrics.classification.helpers import _input_format_classification
from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
from pytorch_lightning.metrics.functional import accuracy
from tests.metrics.classification.inputs import (
_binary_inputs,
@@ -29,12 +29,12 @@ def _sk_accuracy(preds, target, subset_accuracy):
sk_preds, sk_target, mode = _input_format_classification(preds, target, threshold=THRESHOLD)
sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy()

if mode == "multi-dim multi-class" and not subset_accuracy:
if mode == DataType.MULTIDIM_MULTICLASS and not subset_accuracy:
sk_preds, sk_target = np.transpose(sk_preds, (0, 2, 1)), np.transpose(sk_target, (0, 2, 1))
sk_preds, sk_target = sk_preds.reshape(-1, sk_preds.shape[2]), sk_target.reshape(-1, sk_target.shape[2])
elif mode == mode == "multi-dim multi-class" and subset_accuracy:
elif mode == DataType.MULTIDIM_MULTICLASS and subset_accuracy:
return np.all(sk_preds == sk_target, axis=(1, 2)).mean()
elif mode == "multi-label" and not subset_accuracy:
elif mode == DataType.MULTILABEL and not subset_accuracy:
sk_preds, sk_target = sk_preds.reshape(-1), sk_target.reshape(-1)

return sk_accuracy(y_true=sk_target, y_pred=sk_preds)
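Note: in the non-subset multi-dim multi-class path above, the transpose/reshape flattens a (N, C, X) tensor into (N*X, C) so that sklearn sees one row per sample-position. A shape-only sketch:

import numpy as np

N, C, X = 2, 3, 4  # samples, classes, extra dimension
sk_preds = np.zeros((N, C, X))
sk_preds = np.transpose(sk_preds, (0, 2, 1))        # (N, X, C)
sk_preds = sk_preds.reshape(-1, sk_preds.shape[2])  # (N*X, C)
print(sk_preds.shape)                               # (8, 3)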
58 changes: 31 additions & 27 deletions tests/metrics/classification/test_inputs.py
@@ -2,7 +2,7 @@
import torch
from torch import rand, randint

from pytorch_lightning.metrics.classification.helpers import _input_format_classification
from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
from pytorch_lightning.metrics.utils import select_topk, to_onehot
from tests.metrics.classification.inputs import _binary_inputs as _bin
from tests.metrics.classification.inputs import _binary_prob_inputs as _bin_prob
@@ -155,32 +155,36 @@ def _mlmd_prob_to_mc_preds_tr(x):
],
)
def test_usual_cases(inputs, num_classes, is_multiclass, top_k, exp_mode, post_preds, post_target):
preds_out, target_out, mode = _input_format_classification(
preds=inputs.preds[0],
target=inputs.target[0],
threshold=THRESHOLD,
num_classes=num_classes,
is_multiclass=is_multiclass,
top_k=top_k,
)

assert mode == exp_mode
assert torch.equal(preds_out, post_preds(inputs.preds[0]).int())
assert torch.equal(target_out, post_target(inputs.target[0]).int())

# Test that things work when batch_size = 1
preds_out, target_out, mode = _input_format_classification(
preds=inputs.preds[0][[0], ...],
target=inputs.target[0][[0], ...],
threshold=THRESHOLD,
num_classes=num_classes,
is_multiclass=is_multiclass,
top_k=top_k,
)

assert mode == exp_mode
assert torch.equal(preds_out, post_preds(inputs.preds[0][[0], ...]).int())
assert torch.equal(target_out, post_target(inputs.target[0][[0], ...]).int())
def __get_data_type_enum(str_exp_mode):
return next(DataType[n] for n in dir(DataType) if DataType[n] == str_exp_mode)

for exp_mode in (exp_mode, __get_data_type_enum(exp_mode)):
preds_out, target_out, mode = _input_format_classification(
preds=inputs.preds[0],
target=inputs.target[0],
threshold=THRESHOLD,
num_classes=num_classes,
is_multiclass=is_multiclass,
top_k=top_k,
)

assert mode == exp_mode
assert torch.equal(preds_out, post_preds(inputs.preds[0]).int())
assert torch.equal(target_out, post_target(inputs.target[0]).int())

# Test that things work when batch_size = 1
preds_out, target_out, mode = _input_format_classification(
preds=inputs.preds[0][[0], ...],
target=inputs.target[0][[0], ...],
threshold=THRESHOLD,
num_classes=num_classes,
is_multiclass=is_multiclass,
top_k=top_k,
)

assert mode == exp_mode
assert torch.equal(preds_out, post_preds(inputs.preds[0][[0], ...]).int())
assert torch.equal(target_out, post_target(inputs.target[0][[0], ...]).int())


# Test that threshold is correctly applied
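Note: the loop added to `test_usual_cases` runs every case twice, so the `mode` returned by `_input_format_classification` is checked against both the raw string and its `DataType` counterpart. Since `LightningEnum` is str-backed, the `__get_data_type_enum` helper amounts to a lookup by value; a shorter equivalent (assuming the exact member values defined in this PR) would be:

from pytorch_lightning.metrics.classification.helpers import DataType

mode = DataType("multi-label")  # plain Enum lookup by value
assert mode == DataType.MULTILABEL == "multi-label"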