added micro average option for torch metrics #874

Merged: 23 commits merged into master from add_micro_average_IOU on May 25, 2022

Commits (23)
a6cc8aa  ENH: added micro average option for torch metrics (razmikmelikbekyan, Mar 6, 2022)
6fc448b  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Mar 6, 2022)
18bd44c  Merge branch 'master' into add_micro_average_IOU (Borda, Mar 20, 2022)
b9746c9  Merge branch 'master' into add_micro_average_IOU (razmikmelikbekyan, Mar 21, 2022)
abf5a24  Merge branch 'master' into add_micro_average_IOU (Borda, Mar 24, 2022)
b23bf99  Merge branch 'master' into add_micro_average_IOU (Borda, Mar 31, 2022)
a576630  Merge branch 'master' into add_micro_average_IOU (Borda, Apr 11, 2022)
d45d8bd  Merge branch 'master' into add_micro_average_IOU (justusschock, May 5, 2022)
0d38a8c  Merge branch 'master' into add_micro_average_IOU (SkafteNicki, May 10, 2022)
68c21fd  Merge branch 'master' into add_micro_average_IOU (razmikmelikbekyan, May 12, 2022)
52dbc66  Merge branch 'master' into add_micro_average_IOU (Borda, May 23, 2022)
2d7d669  Merge branch 'master' into add_micro_average_IOU (SkafteNicki, May 24, 2022)
c093184  Merge branch 'master' into add_micro_average_IOU (Borda, May 25, 2022)
e1ba857  fix flake (SkafteNicki, May 25, 2022)
dec7db6  remove reduction (SkafteNicki, May 25, 2022)
b9345ee  changelog (SkafteNicki, May 25, 2022)
d88500c  fix docs (SkafteNicki, May 25, 2022)
68cd827  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 25, 2022)
bad4272  flake8 (SkafteNicki, May 25, 2022)
3b4bf33  if/return (Borda, May 25, 2022)
344c10c  fix integer division (SkafteNicki, May 25, 2022)
1008e88  Merge branch 'master' into add_micro_average_IOU (Borda, May 25, 2022)
5435cec  Merge branch 'add_micro_average_IOU' of https://github.com/razmikmeli… (SkafteNicki, May 25, 2022)

2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -32,7 +32,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

-
- Renamed `reduction` argument to `average` in Jaccard score and added additional options ([#874](https://github.com/PyTorchLightning/metrics/pull/874))


### Deprecated
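In practice, the rename means call sites that previously passed `reduction` now pass `average`. A minimal sketch (not taken from the diff), assuming the mapping used in the updated tests, where `reduction="elementwise_mean"` corresponds to the new default `average="macro"` and `reduction="none"` to `average=None`:

```python
import torch
from torchmetrics.functional import jaccard_index

preds = torch.tensor([0, 0, 1, 1])
target = torch.tensor([0, 1, 0, 1])

# Before this PR:
#   jaccard_index(preds, target, num_classes=2, reduction="elementwise_mean")
# After this PR (average defaults to "macro"):
score = jaccard_index(preds, target, num_classes=2, average="macro")
```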
43 changes: 23 additions & 20 deletions tests/classification/test_jaccard.py
@@ -87,7 +87,7 @@ def _sk_jaccard_multidim_multiclass(preds, target, average=None):
return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average)


@pytest.mark.parametrize("reduction", ["elementwise_mean", "none"])
@pytest.mark.parametrize("average", [None, "macro", "micro", "weighted"])
@pytest.mark.parametrize(
"preds, target, sk_metric, num_classes",
[
@@ -104,60 +104,61 @@ def _sk_jaccard_multidim_multiclass(preds, target, average=None):
class TestJaccardIndex(MetricTester):
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_jaccard(self, reduction, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step):
average = "macro" if reduction == "elementwise_mean" else None # convert tags
def test_jaccard(self, average, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step):
# average = "macro" if reduction == "elementwise_mean" else None # convert tags
self.run_class_metric_test(
ddp=ddp,
preds=preds,
target=target,
metric_class=JaccardIndex,
sk_metric=partial(sk_metric, average=average),
dist_sync_on_step=dist_sync_on_step,
metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "reduction": reduction},
metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "average": average},
)

def test_jaccard_functional(self, reduction, preds, target, sk_metric, num_classes):
average = "macro" if reduction == "elementwise_mean" else None # convert tags
def test_jaccard_functional(self, average, preds, target, sk_metric, num_classes):
# average = "macro" if reduction == "elementwise_mean" else None # convert tags
self.run_functional_metric_test(
preds,
target,
metric_functional=jaccard_index,
sk_metric=partial(sk_metric, average=average),
metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "reduction": reduction},
metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "average": average},
)

def test_jaccard_differentiability(self, reduction, preds, target, sk_metric, num_classes):
def test_jaccard_differentiability(self, average, preds, target, sk_metric, num_classes):
self.run_differentiability_test(
preds=preds,
target=target,
metric_module=JaccardIndex,
metric_functional=jaccard_index,
metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "reduction": reduction},
metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "average": average},
)


@pytest.mark.parametrize(
["half_ones", "reduction", "ignore_index", "expected"],
["half_ones", "average", "ignore_index", "expected"],
[
(False, "none", None, Tensor([1, 1, 1])),
(False, "elementwise_mean", None, Tensor([1])),
(False, "macro", None, Tensor([1])),
(False, "none", 0, Tensor([1, 1])),
(True, "none", None, Tensor([0.5, 0.5, 0.5])),
(True, "elementwise_mean", None, Tensor([0.5])),
(True, "macro", None, Tensor([0.5])),
(True, "none", 0, Tensor([2 / 3, 1 / 2])),
],
)
def test_jaccard(half_ones, reduction, ignore_index, expected):
def test_jaccard(half_ones, average, ignore_index, expected):
preds = (torch.arange(120) % 3).view(-1, 1)
target = (torch.arange(120) % 3).view(-1, 1)
if half_ones:
preds[:60] = 1
jaccard_val = jaccard_index(
preds=preds,
target=target,
average=average,
num_classes=3,
ignore_index=ignore_index,
reduction=reduction,
# reduction=reduction,
)
assert torch.allclose(jaccard_val, expected, atol=1e-9)

@@ -199,18 +200,19 @@ def test_jaccard_absent_score(pred, target, ignore_index, absent_score, num_clas
jaccard_val = jaccard_index(
preds=tensor(pred),
target=tensor(target),
average=None,
ignore_index=ignore_index,
absent_score=absent_score,
num_classes=num_classes,
reduction="none",
# reduction="none",
)
assert torch.allclose(jaccard_val, tensor(expected).to(jaccard_val))


# example data taken from
# https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/tests/test_ranking.py
@pytest.mark.parametrize(
["pred", "target", "ignore_index", "num_classes", "reduction", "expected"],
["pred", "target", "ignore_index", "num_classes", "average", "expected"],
[
# Ignoring an index outside of [0, num_classes-1] should have no effect.
([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], None, 3, "none", [1, 1 / 2, 2 / 3]),
@@ -221,16 +223,17 @@ def test_jaccard_absent_score(pred, target, ignore_index, absent_score, num_clas
([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 1, 3, "none", [1, 2 / 3]),
([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 2, 3, "none", [1, 1]),
# When reducing to mean or sum, the ignored index does not contribute to the output.
([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, "elementwise_mean", [7 / 12]),
([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, "sum", [7 / 6]),
([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, "macro", [7 / 12]),
# ([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, "sum", [7 / 6]),
],
)
def test_jaccard_ignore_index(pred, target, ignore_index, num_classes, reduction, expected):
def test_jaccard_ignore_index(pred, target, ignore_index, num_classes, average, expected):
jaccard_val = jaccard_index(
preds=tensor(pred),
target=tensor(target),
average=average,
ignore_index=ignore_index,
num_classes=num_classes,
reduction=reduction,
# reduction=reduction,
)
assert torch.allclose(jaccard_val, tensor(expected).to(jaccard_val))
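As a quick sanity check on the `7 / 12` expectation in the parametrization above (a sketch, not part of the PR): the per-class IoUs for `pred=[0, 1, 1, 2, 2]` against `target=[0, 1, 2, 2, 2]` are 1, 1/2 and 2/3, and with `ignore_index=0` the macro average keeps only classes 1 and 2, giving (1/2 + 2/3) / 2 = 7/12:

```python
import torch
from torchmetrics.functional import jaccard_index

pred = torch.tensor([0, 1, 1, 2, 2])
target = torch.tensor([0, 1, 2, 2, 2])

# macro average over the non-ignored classes: (1/2 + 2/3) / 2 = 7/12
score = jaccard_index(pred, target, num_classes=3, average="macro", ignore_index=0)
assert torch.allclose(score, torch.tensor(7 / 12))
```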
29 changes: 19 additions & 10 deletions torchmetrics/classification/jaccard.py
@@ -15,7 +15,6 @@

import torch
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.classification.confusion_matrix import ConfusionMatrix
from torchmetrics.functional.classification.jaccard import _jaccard_from_confmat
@@ -45,6 +44,18 @@ class JaccardIndex(ConfusionMatrix):

Args:
num_classes: Number of classes in the dataset.
average:
Defines the reduction that is applied. Should be one of the following:

- ``'macro'`` [default]: Calculate the metric for each class separately, and average the
metrics across classes (with equal weights for each class).
- ``'micro'``: Calculate the metric globally, across all samples and classes.
- ``'weighted'``: Calculate the metric for each class separately, and average the
metrics across classes, weighting each class by its support (``tp + fn``).
- ``'none'`` or ``None``: Calculate the metric for each class separately, and return
the metric for every class. Note that if a given class doesn't occur in the
`preds` or `target`, the value for the class will be ``nan``.

ignore_index: optional int specifying a target class to ignore. If given, this class index does not contribute
to the returned score, regardless of reduction method. Has no effect if given an int that is not in the
range [0, num_classes-1]. By default, no index is ignored, and all classes are used.
@@ -53,12 +64,6 @@ class JaccardIndex(ConfusionMatrix):
[0, 0] for ``preds``, and [0, 2] for ``target``, then class 1 would be assigned the `absent_score`.
threshold: Threshold value for binary or multi-label probabilities.
multilabel: determines if data is multilabel or not.
reduction: a method to reduce metric score over labels:

- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied

kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Example:
@@ -78,11 +83,11 @@ class JaccardIndex(ConfusionMatrix):
def __init__(
self,
num_classes: int,
average: Optional[str] = "macro",
ignore_index: Optional[int] = None,
absent_score: float = 0.0,
threshold: float = 0.5,
multilabel: bool = False,
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
**kwargs: Dict[str, Any],
) -> None:
super().__init__(
@@ -92,12 +97,16 @@ def __init__(
multilabel=multilabel,
**kwargs,
)
self.reduction = reduction
self.average = average
self.ignore_index = ignore_index
self.absent_score = absent_score

def compute(self) -> Tensor:
"""Computes intersection over union (IoU)"""
return _jaccard_from_confmat(
self.confmat, self.num_classes, self.ignore_index, self.absent_score, self.reduction
self.confmat,
self.num_classes,
self.average,
self.ignore_index,
self.absent_score,
)
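A brief usage sketch of the module interface with the new `average` argument (the example tensors and the resulting values are illustrative and not part of the diff):

```python
import torch
from torchmetrics import JaccardIndex

preds = torch.tensor([0, 1, 1, 2, 2])
target = torch.tensor([0, 1, 2, 2, 2])

jaccard_macro = JaccardIndex(num_classes=3, average="macro")  # default behaviour
jaccard_per_class = JaccardIndex(num_classes=3, average=None)

print(jaccard_macro(preds, target))      # mean of [1, 1/2, 2/3], roughly 0.7222
print(jaccard_per_class(preds, target))  # tensor([1.0000, 0.5000, 0.6667])
```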
101 changes: 68 additions & 33 deletions torchmetrics/functional/classification/jaccard.py
@@ -15,65 +15,90 @@

import torch
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.classification.confusion_matrix import _confusion_matrix_update
from torchmetrics.utilities.distributed import reduce


def _jaccard_from_confmat(
confmat: Tensor,
num_classes: int,
average: Optional[str] = "macro",
ignore_index: Optional[int] = None,
absent_score: float = 0.0,
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
) -> Tensor:
"""Computes the intersection over union from confusion matrix.

Args:
confmat: Confusion matrix without normalization
num_classes: Number of classes for a given prediction and target tensor
average:
Defines the reduction that is applied. Should be one of the following:

- ``'macro'`` [default]: Calculate the metric for each class separately, and average the
metrics across classes (with equal weights for each class).
- ``'micro'``: Calculate the metric globally, across all samples and classes.
- ``'weighted'``: Calculate the metric for each class separately, and average the
metrics across classes, weighting each class by its support (``tp + fn``).
- ``'none'`` or ``None``: Calculate the metric for each class separately, and return
the metric for every class. Note that if a given class doesn't occur in the
`preds` or `target`, the value for the class will be ``nan``.

ignore_index: optional int specifying a target class to ignore. If given, this class index does not contribute
to the returned score, regardless of reduction method.
absent_score: score to use for an individual class, if no instances of the class index were present in ``preds``
AND no instances of the class index were present in ``target``.
reduction: a method to reduce metric score over labels.

- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'`` or ``None``: no reduction will be applied
absent_score: score to use for an individual class, if no instances of the class index were present in `pred`
AND no instances of the class index were present in `target`.
"""
allowed_average = ["micro", "macro", "weighted", "none", None]
if average not in allowed_average:
raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")

# Remove the ignored class index from the scores.
if ignore_index is not None and 0 <= ignore_index < num_classes:
confmat[ignore_index] = 0.0

intersection = torch.diag(confmat)
union = confmat.sum(0) + confmat.sum(1) - intersection

# If this class is absent in both target AND pred (union == 0), then use the absent_score for this class.
scores = intersection.float() / union.float()
scores[union == 0] = absent_score

if ignore_index is not None and 0 <= ignore_index < num_classes:
scores = torch.cat(
[
scores[:ignore_index],
scores[ignore_index + 1 :],
]
if average == "none" or average is None:
intersection = torch.diag(confmat)
union = confmat.sum(0) + confmat.sum(1) - intersection

# If this class is absent in both target AND pred (union == 0), then use the absent_score for this class.
scores = intersection.float() / union.float()
scores[union == 0] = absent_score

if ignore_index is not None and 0 <= ignore_index < num_classes:
scores = torch.cat(
[
scores[:ignore_index],
scores[ignore_index + 1 :],
]
)
return scores

if average == "macro":
scores = _jaccard_from_confmat(
confmat, num_classes, average="none", ignore_index=ignore_index, absent_score=absent_score
)
return torch.mean(scores)

return reduce(scores, reduction=reduction)
if average == "micro":
intersection = torch.sum(torch.diag(confmat))
union = torch.sum(torch.sum(confmat, dim=1) + torch.sum(confmat, dim=0) - torch.diag(confmat))
return intersection.float() / union.float()

weights = torch.sum(confmat, dim=1).float() / torch.sum(confmat).float()
scores = _jaccard_from_confmat(
confmat, num_classes, average="none", ignore_index=ignore_index, absent_score=absent_score
)
return torch.sum(weights * scores)


def jaccard_index(
preds: Tensor,
target: Tensor,
num_classes: int,
average: Optional[str] = "macro",
ignore_index: Optional[int] = None,
absent_score: float = 0.0,
threshold: float = 0.5,
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
) -> Tensor:
r"""Computes `Jaccard index`_

@@ -95,6 +120,18 @@ def jaccard_index(
preds: tensor containing predictions from model (probabilities, or labels) with shape ``[N, d1, d2, ...]``
target: tensor containing ground truth labels with shape ``[N, d1, d2, ...]``
num_classes: Specify the number of classes
average:
Defines the reduction that is applied. Should be one of the following:

- ``'macro'`` [default]: Calculate the metric for each class separately, and average the
metrics across classes (with equal weights for each class).
- ``'micro'``: Calculate the metric globally, across all samples and classes.
- ``'weighted'``: Calculate the metric for each class separately, and average the
metrics across classes, weighting each class by its support (``tp + fn``).
- ``'none'`` or ``None``: Calculate the metric for each class separately, and return
the metric for every class. Note that if a given class doesn't occur in the
`preds` or `target`, the value for the class will be ``nan``.

ignore_index: optional int specifying a target class to ignore. If given,
this class index does not contribute to the returned score, regardless
of reduction method. Has no effect if given an int that is not in the
@@ -106,15 +143,13 @@ def jaccard_index(
[0, 0] for ``preds``, and [0, 2] for ``target``, then class 1 would be
assigned the `absent_score`.
threshold: Threshold value for binary or multi-label probabilities.
reduction: a method to reduce metric score over labels.

- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'`` or ``None``: no reduction will be applied

Return:
IoU score: Tensor containing single value if reduction is
'elementwise_mean', or number of classes if reduction is 'none'
The shape of the returned tensor depends on the ``average`` parameter

- If ``average in ['micro', 'macro', 'weighted']``, a one-element tensor will be returned
- If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number
of classes

Example:
>>> from torchmetrics.functional import jaccard_index
@@ -126,4 +161,4 @@
"""

confmat = _confusion_matrix_update(preds, target, num_classes, threshold)
return _jaccard_from_confmat(confmat, num_classes, ignore_index, absent_score, reduction)
return _jaccard_from_confmat(confmat, num_classes, average, ignore_index, absent_score)
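To see how the new branches differ in practice, a short sketch comparing the three averaging modes on one toy example (the values are worked out from the confusion matrix `[[1, 0, 0], [0, 1, 0], [0, 1, 2]]`; the snippet is illustrative, not part of the PR):

```python
import torch
from torchmetrics.functional import jaccard_index

preds = torch.tensor([0, 1, 1, 2, 2])
target = torch.tensor([0, 1, 2, 2, 2])

# micro: pooled intersection / pooled union = (1 + 1 + 2) / (1 + 2 + 3) = 4/6
micro = jaccard_index(preds, target, num_classes=3, average="micro")

# macro: mean of the per-class IoUs [1, 1/2, 2/3] = 13/18
macro = jaccard_index(preds, target, num_classes=3, average="macro")

# weighted: per-class IoUs weighted by target support [1/5, 1/5, 3/5] = 7/10
weighted = jaccard_index(preds, target, num_classes=3, average="weighted")

print(micro, macro, weighted)  # ~0.6667, ~0.7222, 0.7000
```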