Add AverageMethod and MDMCAverageMethod Enum (Lightning-AI#5657)
* remove AverageMethods enum from functional/auroc.py
yuntai committed Jan 28, 2021
1 parent e1c8f0f commit ce1e528
Showing 3 changed files with 49 additions and 33 deletions.
14 changes: 8 additions & 6 deletions pytorch_lightning/metrics/classification/auroc.py
@@ -19,6 +19,7 @@
from pytorch_lightning.metrics.functional.auroc import _auroc_compute, _auroc_update
from pytorch_lightning.metrics.metric import Metric
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.metrics.classification.helpers import AverageMethod


class AUROC(Metric):
@@ -47,10 +48,10 @@ class AUROC(Metric):
this argument should not be set as we iteratively change it in the
range [0,num_classes-1]
average:
- ``'macro'`` computes metric for each class and uniformly averages them
- ``'weighted'`` computes metric for each class and does a weighted-average,
- ``AverageMethod.MACRO`` computes metric for each class and uniformly averages them
- ``AverageMethod.WEIGHTED`` computes metric for each class and does a weighted-average,
where each class is weighted by their support (accounts for class imbalance)
- ``None`` computes and returns the metric per class
- ``AverageMethod.NONE`` or ``None`` computes and returns the metric per class
max_fpr:
If not ``None``, calculates standardized partial AUC over the
range [0, max_fpr]. Should be a float between 0 and 1.
@@ -86,11 +87,12 @@ class AUROC(Metric):
tensor(0.7778)
"""

def __init__(
self,
num_classes: Optional[int] = None,
pos_label: Optional[int] = None,
average: Optional[str] = 'macro',
average: Optional[AverageMethod] = AverageMethod.MACRO,
max_fpr: Optional[float] = None,
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
@@ -109,13 +111,13 @@ def __init__(
self.average = average
self.max_fpr = max_fpr

allowed_average = (None, 'macro', 'weighted')
allowed_average = (None, AverageMethod.MACRO, AverageMethod.WEIGHTED, AverageMethod.NONE)
if self.average not in allowed_average:
raise ValueError('Argument `average` expected to be one of the following:'
f' {allowed_average} but got {average}')

if self.max_fpr is not None:
if (not isinstance(max_fpr, float) and 0 < max_fpr <= 1):
if not isinstance(max_fpr, float) and 0 < max_fpr <= 1:
raise ValueError(f"`max_fpr` should be a float in range (0, 1], got: {max_fpr}")

if LooseVersion(torch.__version__) < LooseVersion('1.6.0'):
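
Usage note (an aside, not part of the diff): a minimal sketch of the updated class-based metric, assuming the usual import paths for this release and that the string-backed AverageMethod members stay interchangeable with the old literal strings.

import torch

from pytorch_lightning.metrics.classification import AUROC
from pytorch_lightning.metrics.classification.helpers import AverageMethod

# Multiclass probabilities for four samples over three classes (illustrative data).
preds = torch.tensor([
    [0.90, 0.05, 0.05],
    [0.05, 0.90, 0.05],
    [0.05, 0.05, 0.90],
    [0.85, 0.05, 0.10],
])
target = torch.tensor([0, 1, 1, 2])

# New-style argument: the enum member introduced by this commit.
auroc = AUROC(num_classes=3, average=AverageMethod.MACRO)
print(auroc(preds, target))

# Old-style argument: a plain string should still pass the `allowed_average`
# check, because the enum members are string-backed (see helpers.py below).
auroc_weighted = AUROC(num_classes=3, average="weighted")
print(auroc_weighted(preds, target))
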
48 changes: 35 additions & 13 deletions pytorch_lightning/metrics/classification/helpers.py
@@ -31,6 +31,27 @@ class DataType(LightningEnum):
MULTIDIM_MULTICLASS = "multidim_multiclass"


class AverageMethod(LightningEnum):
"""
Enum to represent average method
"""

MICRO = "micro"
MACRO = "macro"
WEIGHTED = "weighted"
NONE = "none"
SAMPLES = "samples"


class MDMCAverageMethod(LightningEnum):
"""
Enum to represent multi-dim multi-class average method
"""

GLOBAL = "global"
SAMPLEWISE = "samplewise"


def _basic_input_validation(preds: torch.Tensor, target: torch.Tensor, threshold: float, is_multiclass: bool):
"""
Perform basic validation of inputs that does not require deducing any information
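
Aside (not part of the diff): the enum-typed arguments stay backward compatible with the old string literals because LightningEnum is, by assumption here, a str-backed Enum, so members compare equal to their string values. A self-contained sketch with a stand-in base class:

from enum import Enum


class LightningEnum(str, Enum):
    """Stand-in for pytorch_lightning.utilities.LightningEnum (assumed to be a
    str-backed Enum, possibly with extra helpers in the real class)."""


class AverageMethod(LightningEnum):
    MICRO = "micro"
    MACRO = "macro"
    WEIGHTED = "weighted"
    NONE = "none"
    SAMPLES = "samples"


# str-backed members compare equal to plain strings, so validation such as
# `if self.average not in allowed_average` keeps accepting 'macro' and friends.
allowed_average = (None, AverageMethod.MACRO, AverageMethod.WEIGHTED, AverageMethod.NONE)
assert "macro" in allowed_average
assert AverageMethod.WEIGHTED == "weighted"
assert AverageMethod.NONE.value == "none"
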
@@ -455,14 +476,14 @@ def _input_format_classification(
def _reduce_stat_scores(
numerator: torch.Tensor,
denominator: torch.Tensor,
weights: Optional[torch.Tensor],
average: str,
mdmc_average: Optional[str],
weights: Optional[torch.Tensor] = None,
average: Optional[AverageMethod] = None,
mdmc_average: Optional[MDMCAverageMethod] = None,
zero_division: int = 0,
) -> torch.Tensor:
"""
Reduces scores of type ``numerator/denominator`` or
``weights * (numerator/denominator)``, if ``average='weighted'``.
``weights * (numerator/denominator)``, if ``average=AverageMethod.WEIGHTED``.
Args:
numerator: A tensor with numerator numbers.
@@ -472,16 +493,17 @@ def _reduce_stat_scores(
If the denominator is zero, then ``zero_division`` score will be
used for those elements.
weights:
A tensor of weights to be used if ``average='weighted'``.
A tensor of weights to be used if ``average=AverageMethod.WEIGHTED``.
average:
The method to average the scores. Should be one of ``'micro'``, ``'macro'``,
``'weighted'``, ``'none'``, ``None`` or ``'samples'``. The behavior
corresponds to `sklearn averaging methods <https://scikit-learn.org/stable/modules/\
The method to average the scores. Should be one of ``AverageMethod.MICRO``,
``AverageMethod.MACRO``, ``AverageMethod.WEIGHTED``, ``AverageMethod.NONE``,
``None`` or ``AverageMethod.SAMPLES``. The behavior corresponds to
`sklearn averaging methods <https://scikit-learn.org/stable/modules/\
model_evaluation.html#multiclass-and-multilabel-classification>`__.
mdmc_average:
The method to average the scores if inputs were multi-dimensional multi-class (MDMC).
Should be either ``'global'`` or ``'samplewise'``. If inputs were not
multi-dimensional multi-class, it should be ``None`` (default).
Should be either ``MDMCAverageMethod.GLOBAL`` or ``MDMCAverageMethod.SAMPLEWISE``. If
inputs were not multi-dimensional multi-class, it should be ``None`` (default).
zero_division:
The value to use for the score if denominator equals zero.
"""
@@ -498,19 +520,19 @@ def _reduce_stat_scores(
denominator = torch.where(zero_div_mask | ignore_mask, torch.tensor(1.0, device=denominator.device), denominator)
weights = torch.where(ignore_mask, torch.tensor(0.0, device=weights.device), weights)

if average not in ["micro", "none", None]:
if average not in (AverageMethod.MICRO, AverageMethod.NONE, None):
weights = weights / weights.sum(dim=-1, keepdim=True)

scores = weights * (numerator / denominator)

# This is in case where sum(weights) = 0, which happens if we ignore the only present class with average='weighted'
scores = torch.where(torch.isnan(scores), torch.tensor(float(zero_division), device=scores.device), scores)

if mdmc_average == "samplewise":
if mdmc_average == MDMCAverageMethod.SAMPLEWISE:
scores = scores.mean(dim=0)
ignore_mask = ignore_mask.sum(dim=0).bool()

if average in ["none", None]:
if average in (AverageMethod.NONE, None):
scores = torch.where(ignore_mask, torch.tensor(np.nan, device=scores.device), scores)
else:
scores = scores.sum()
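
Aside (not part of the diff): a rough, self-contained illustration of the two MDMCAverageMethod modes; this is not the library's exact code path, which splits the work between the stat-score computation and _reduce_stat_scores, only the idea behind the two reductions.

import torch

# Per-sample, per-class statistics for 2 samples x 3 classes, e.g. counts of
# correct predictions (numerator) out of totals (denominator). Illustrative only.
numerator = torch.tensor([[2.0, 0.0, 1.0], [1.0, 1.0, 0.0]])
denominator = torch.tensor([[4.0, 1.0, 1.0], [2.0, 2.0, 1.0]])

# MDMCAverageMethod.GLOBAL: fold the sample dimension into the class statistics
# first (sum the counts over samples), then macro-average over classes.
global_score = (numerator.sum(dim=0) / denominator.sum(dim=0)).mean()     # ~0.4444

# MDMCAverageMethod.SAMPLEWISE: macro-average within each sample, then average
# the per-sample scores (cf. `scores.mean(dim=0)` in the hunk above).
samplewise_score = (numerator / denominator).mean(dim=-1).mean(dim=0)     # ~0.4167
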
20 changes: 6 additions & 14 deletions pytorch_lightning/metrics/functional/auroc.py
@@ -16,17 +16,9 @@

import torch

from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
from pytorch_lightning.metrics.classification.helpers import _input_format_classification, AverageMethod, DataType
from pytorch_lightning.metrics.functional.auc import auc
from pytorch_lightning.metrics.functional.roc import roc
from pytorch_lightning.utilities import LightningEnum


class AverageMethods(LightningEnum):
""" Type of averages """
MACRO = 'macro'
WEIGHTED = 'weighted'
NONE = None


def _auroc_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, str]:
@@ -51,7 +43,7 @@ def _auroc_compute(
mode: DataType,
num_classes: Optional[int] = None,
pos_label: Optional[int] = None,
average: Optional[str] = 'macro',
average: Optional[AverageMethod] = AverageMethod.MACRO,
max_fpr: Optional[float] = None,
sample_weights: Optional[Sequence] = None,
) -> torch.Tensor:
@@ -91,18 +83,18 @@ def _auroc_compute(
auc_scores = [auc(x, y) for x, y in zip(fpr, tpr)]

# calculate average
if average == AverageMethods.NONE:
if average in (AverageMethod.NONE, None):
return auc_scores
elif average == AverageMethods.MACRO:
elif average == AverageMethod.MACRO:
return torch.mean(torch.stack(auc_scores))
elif average == AverageMethods.WEIGHTED:
elif average == AverageMethod.WEIGHTED:
if mode == DataType.MULTILABEL:
support = torch.sum(target, dim=0)
else:
support = torch.bincount(target.flatten(), minlength=num_classes)
return torch.sum(torch.stack(auc_scores) * support / support.sum())

allowed_average = [e.value for e in AverageMethods]
allowed_average = [e.value for e in AverageMethod]
raise ValueError(f"Argument `average` expected to be one of the following:"
f" {allowed_average} but got {average}")

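
Aside (not part of the diff): a small numeric sketch of how the macro and weighted branches above combine per-class AUC scores; the tensors are illustrative, not taken from the diff.

import torch

# Per-class AUROC scores, standing in for the `auc_scores` list built above.
auc_scores = [torch.tensor(0.90), torch.tensor(0.70), torch.tensor(0.80)]

# Multiclass targets; `support` counts how many samples belong to each class.
target = torch.tensor([0, 0, 1, 2, 2, 2])
support = torch.bincount(target.flatten(), minlength=3)   # tensor([2, 1, 3])

# AverageMethod.MACRO: uniform mean over classes.
macro = torch.mean(torch.stack(auc_scores))                               # 0.8000

# AverageMethod.WEIGHTED: support-weighted mean, mirroring the branch above.
weighted = torch.sum(torch.stack(auc_scores) * support / support.sum())   # ~0.8167
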
