Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove get_num_classes in jaccard_index #914

Merged
merged 11 commits into from
Mar 31, 2022
14 changes: 1 addition & 13 deletions tests/test_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from torch import tensor

from torchmetrics.utilities import rank_zero_debug, rank_zero_info, rank_zero_warn
from torchmetrics.utilities.data import _bincount, _flatten, _flatten_dict, get_num_classes, to_categorical, to_onehot
from torchmetrics.utilities.data import _bincount, _flatten, _flatten_dict, to_categorical, to_onehot
from torchmetrics.utilities.distributed import class_reduce, reduce


Expand Down Expand Up @@ -92,18 +92,6 @@ def test_to_categorical():
assert torch.allclose(result, expected.to(result.dtype))


@pytest.mark.parametrize(
["preds", "target", "num_classes", "expected_num_classes"],
[
(torch.rand(32, 10, 28, 28), torch.randint(10, (32, 28, 28)), 10, 10),
(torch.rand(32, 10, 28, 28), torch.randint(10, (32, 28, 28)), None, 10),
(torch.rand(32, 28, 28), torch.randint(10, (32, 28, 28)), None, 10),
],
)
def test_get_num_classes(preds, target, num_classes, expected_num_classes):
assert get_num_classes(preds, target, num_classes) == expected_num_classes


def test_flatten_list():
"""Check that _flatten utility function works as expected."""
inp = [[1, 2, 3], [4, 5], [6]]
Expand Down
2 changes: 0 additions & 2 deletions torchmetrics/functional/classification/jaccard.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from typing_extensions import Literal

from torchmetrics.functional.classification.confusion_matrix import _confusion_matrix_update
from torchmetrics.utilities.data import get_num_classes
from torchmetrics.utilities.distributed import reduce


Expand Down Expand Up @@ -129,6 +128,5 @@ def jaccard_index(
tensor(0.9660)
"""

num_classes = get_num_classes(preds=preds, target=target, num_classes=num_classes)
confmat = _confusion_matrix_update(preds, target, num_classes, threshold)
return _jaccard_from_confmat(confmat, num_classes, ignore_index, absent_score, reduction)
33 changes: 0 additions & 33 deletions torchmetrics/utilities/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
import torch
from torch import Tensor, tensor

from torchmetrics.utilities.prints import rank_zero_warn

METRIC_EPS = 1e-6


Expand Down Expand Up @@ -145,37 +143,6 @@ def to_categorical(x: Tensor, argmax_dim: int = 1) -> Tensor:
return torch.argmax(x, dim=argmax_dim)


def get_num_classes(
preds: Tensor,
target: Tensor,
num_classes: Optional[int] = None,
) -> int:
"""Calculates the number of classes for a given prediction and target tensor.

Args:
preds: predicted values
target: true labels
num_classes: number of classes if known

Return:
An integer that represents the number of classes.
"""
num_target_classes = int(target.max().detach().item() + 1)
num_pred_classes = int(preds.max().detach().item() + 1)
num_all_classes = max(num_target_classes, num_pred_classes)

if num_classes is None:
num_classes = num_all_classes
elif num_classes != num_all_classes:
rank_zero_warn(
f"You have set {num_classes} number of classes which is"
f" different from predicted ({num_pred_classes}) and"
f" target ({num_target_classes}) number of classes",
RuntimeWarning,
)
return num_classes


def apply_to_collection(
data: Any,
dtype: Union[type, tuple],
Expand Down