Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove get_num_classes in jaccard_index #914

Merged
merged 11 commits into from
Mar 31, 2022
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

- Made `num_classes` in `jaccard_index` a required argument ([#853](https://github.com/PyTorchLightning/metrics/pull/853))
- Made `num_classes` in `jaccard_index` a required argument ([#853](https://github.com/PyTorchLightning/metrics/pull/853), [#914](https://github.com/PyTorchLightning/metrics/pull/914))


- Added normalizer, tokenizer to ROUGE metric ([#838](https://github.com/PyTorchLightning/metrics/pull/838))
Expand Down
11 changes: 0 additions & 11 deletions tests/classification/test_jaccard.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,14 +234,3 @@ def test_jaccard_ignore_index(pred, target, ignore_index, num_classes, reduction
reduction=reduction,
)
assert torch.allclose(jaccard_val, tensor(expected).to(jaccard_val))


def test_warning_on_difference_in_number_of_classes():
"""Test that warning is thrown if the detected number of classes are different from the the specified number of
classes."""
preds = torch.randint(3, (10,))
target = torch.randint(3, (10,))
with pytest.warns(
RuntimeWarning,
):
jaccard_index(preds, target, num_classes=4)
14 changes: 1 addition & 13 deletions tests/test_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from torch import tensor

from torchmetrics.utilities import rank_zero_debug, rank_zero_info, rank_zero_warn
from torchmetrics.utilities.data import _bincount, _flatten, _flatten_dict, get_num_classes, to_categorical, to_onehot
from torchmetrics.utilities.data import _bincount, _flatten, _flatten_dict, to_categorical, to_onehot
from torchmetrics.utilities.distributed import class_reduce, reduce


Expand Down Expand Up @@ -92,18 +92,6 @@ def test_to_categorical():
assert torch.allclose(result, expected.to(result.dtype))


@pytest.mark.parametrize(
["preds", "target", "num_classes", "expected_num_classes"],
[
(torch.rand(32, 10, 28, 28), torch.randint(10, (32, 28, 28)), 10, 10),
(torch.rand(32, 10, 28, 28), torch.randint(10, (32, 28, 28)), None, 10),
(torch.rand(32, 28, 28), torch.randint(10, (32, 28, 28)), None, 10),
],
)
def test_get_num_classes(preds, target, num_classes, expected_num_classes):
assert get_num_classes(preds, target, num_classes) == expected_num_classes


def test_flatten_list():
"""Check that _flatten utility function works as expected."""
inp = [[1, 2, 3], [4, 5], [6]]
Expand Down
2 changes: 0 additions & 2 deletions torchmetrics/functional/classification/jaccard.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from typing_extensions import Literal

from torchmetrics.functional.classification.confusion_matrix import _confusion_matrix_update
from torchmetrics.utilities.data import get_num_classes
from torchmetrics.utilities.distributed import reduce


Expand Down Expand Up @@ -129,6 +128,5 @@ def jaccard_index(
tensor(0.9660)
"""

num_classes = get_num_classes(preds=preds, target=target, num_classes=num_classes)
confmat = _confusion_matrix_update(preds, target, num_classes, threshold)
return _jaccard_from_confmat(confmat, num_classes, ignore_index, absent_score, reduction)
33 changes: 0 additions & 33 deletions torchmetrics/utilities/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
import torch
from torch import Tensor, tensor

from torchmetrics.utilities.prints import rank_zero_warn

METRIC_EPS = 1e-6


Expand Down Expand Up @@ -145,37 +143,6 @@ def to_categorical(x: Tensor, argmax_dim: int = 1) -> Tensor:
return torch.argmax(x, dim=argmax_dim)


def get_num_classes(
preds: Tensor,
target: Tensor,
num_classes: Optional[int] = None,
) -> int:
"""Calculates the number of classes for a given prediction and target tensor.

Args:
preds: predicted values
target: true labels
num_classes: number of classes if known

Return:
An integer that represents the number of classes.
"""
num_target_classes = int(target.max().detach().item() + 1)
num_pred_classes = int(preds.max().detach().item() + 1)
num_all_classes = max(num_target_classes, num_pred_classes)

if num_classes is None:
num_classes = num_all_classes
elif num_classes != num_all_classes:
rank_zero_warn(
f"You have set {num_classes} number of classes which is"
f" different from predicted ({num_pred_classes}) and"
f" target ({num_target_classes}) number of classes",
RuntimeWarning,
)
return num_classes


def apply_to_collection(
data: Any,
dtype: Union[type, tuple],
Expand Down