diff --git a/CHANGELOG.md b/CHANGELOG.md index 16aec20dee1..55742672139 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -49,7 +49,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed -- Made `num_classes` in `jaccard_index` a required argument ([#853](https://github.com/PyTorchLightning/metrics/pull/853)) +- Made `num_classes` in `jaccard_index` a required argument ([#853](https://github.com/PyTorchLightning/metrics/pull/853), [#914](https://github.com/PyTorchLightning/metrics/pull/914)) - Added normalizer, tokenizer to ROUGE metric ([#838](https://github.com/PyTorchLightning/metrics/pull/838)) diff --git a/tests/classification/test_jaccard.py b/tests/classification/test_jaccard.py index 2b60b5f3075..e9b8deb438d 100644 --- a/tests/classification/test_jaccard.py +++ b/tests/classification/test_jaccard.py @@ -234,14 +234,3 @@ def test_jaccard_ignore_index(pred, target, ignore_index, num_classes, reduction reduction=reduction, ) assert torch.allclose(jaccard_val, tensor(expected).to(jaccard_val)) - - -def test_warning_on_difference_in_number_of_classes(): - """Test that warning is thrown if the detected number of classes are different from the the specified number of - classes.""" - preds = torch.randint(3, (10,)) - target = torch.randint(3, (10,)) - with pytest.warns( - RuntimeWarning, - ): - jaccard_index(preds, target, num_classes=4) diff --git a/tests/test_utilities.py b/tests/test_utilities.py index 67d948dc7ca..40dd37fe1e3 100644 --- a/tests/test_utilities.py +++ b/tests/test_utilities.py @@ -16,7 +16,7 @@ from torch import tensor from torchmetrics.utilities import rank_zero_debug, rank_zero_info, rank_zero_warn -from torchmetrics.utilities.data import _bincount, _flatten, _flatten_dict, get_num_classes, to_categorical, to_onehot +from torchmetrics.utilities.data import _bincount, _flatten, _flatten_dict, to_categorical, to_onehot from torchmetrics.utilities.distributed import class_reduce, reduce @@ -92,18 +92,6 @@ def test_to_categorical(): assert torch.allclose(result, expected.to(result.dtype)) -@pytest.mark.parametrize( - ["preds", "target", "num_classes", "expected_num_classes"], - [ - (torch.rand(32, 10, 28, 28), torch.randint(10, (32, 28, 28)), 10, 10), - (torch.rand(32, 10, 28, 28), torch.randint(10, (32, 28, 28)), None, 10), - (torch.rand(32, 28, 28), torch.randint(10, (32, 28, 28)), None, 10), - ], -) -def test_get_num_classes(preds, target, num_classes, expected_num_classes): - assert get_num_classes(preds, target, num_classes) == expected_num_classes - - def test_flatten_list(): """Check that _flatten utility function works as expected.""" inp = [[1, 2, 3], [4, 5], [6]] diff --git a/torchmetrics/functional/classification/jaccard.py b/torchmetrics/functional/classification/jaccard.py index e9a44be64f3..0e5f7dc51fc 100644 --- a/torchmetrics/functional/classification/jaccard.py +++ b/torchmetrics/functional/classification/jaccard.py @@ -18,7 +18,6 @@ from typing_extensions import Literal from torchmetrics.functional.classification.confusion_matrix import _confusion_matrix_update -from torchmetrics.utilities.data import get_num_classes from torchmetrics.utilities.distributed import reduce @@ -129,6 +128,5 @@ def jaccard_index( tensor(0.9660) """ - num_classes = get_num_classes(preds=preds, target=target, num_classes=num_classes) confmat = _confusion_matrix_update(preds, target, num_classes, threshold) return _jaccard_from_confmat(confmat, num_classes, ignore_index, absent_score, reduction) diff --git a/torchmetrics/utilities/data.py b/torchmetrics/utilities/data.py index cb4c9c8a609..6495d66d779 100644 --- a/torchmetrics/utilities/data.py +++ b/torchmetrics/utilities/data.py @@ -16,8 +16,6 @@ import torch from torch import Tensor, tensor -from torchmetrics.utilities.prints import rank_zero_warn - METRIC_EPS = 1e-6 @@ -145,37 +143,6 @@ def to_categorical(x: Tensor, argmax_dim: int = 1) -> Tensor: return torch.argmax(x, dim=argmax_dim) -def get_num_classes( - preds: Tensor, - target: Tensor, - num_classes: Optional[int] = None, -) -> int: - """Calculates the number of classes for a given prediction and target tensor. - - Args: - preds: predicted values - target: true labels - num_classes: number of classes if known - - Return: - An integer that represents the number of classes. - """ - num_target_classes = int(target.max().detach().item() + 1) - num_pred_classes = int(preds.max().detach().item() + 1) - num_all_classes = max(num_target_classes, num_pred_classes) - - if num_classes is None: - num_classes = num_all_classes - elif num_classes != num_all_classes: - rank_zero_warn( - f"You have set {num_classes} number of classes which is" - f" different from predicted ({num_pred_classes}) and" - f" target ({num_target_classes}) number of classes", - RuntimeWarning, - ) - return num_classes - - def apply_to_collection( data: Any, dtype: Union[type, tuple],