From f750888ff3c402d86e7618c0744992c30f3f50cd Mon Sep 17 00:00:00 2001 From: rusty1s Date: Fri, 25 Mar 2022 07:06:09 +0100 Subject: [PATCH 1/4] remove get_num_classes --- tests/test_utilities.py | 14 +-------- .../functional/classification/jaccard.py | 2 -- torchmetrics/utilities/data.py | 31 ------------------- 3 files changed, 1 insertion(+), 46 deletions(-) diff --git a/tests/test_utilities.py b/tests/test_utilities.py index 67d948dc7ca..40dd37fe1e3 100644 --- a/tests/test_utilities.py +++ b/tests/test_utilities.py @@ -16,7 +16,7 @@ from torch import tensor from torchmetrics.utilities import rank_zero_debug, rank_zero_info, rank_zero_warn -from torchmetrics.utilities.data import _bincount, _flatten, _flatten_dict, get_num_classes, to_categorical, to_onehot +from torchmetrics.utilities.data import _bincount, _flatten, _flatten_dict, to_categorical, to_onehot from torchmetrics.utilities.distributed import class_reduce, reduce @@ -92,18 +92,6 @@ def test_to_categorical(): assert torch.allclose(result, expected.to(result.dtype)) -@pytest.mark.parametrize( - ["preds", "target", "num_classes", "expected_num_classes"], - [ - (torch.rand(32, 10, 28, 28), torch.randint(10, (32, 28, 28)), 10, 10), - (torch.rand(32, 10, 28, 28), torch.randint(10, (32, 28, 28)), None, 10), - (torch.rand(32, 28, 28), torch.randint(10, (32, 28, 28)), None, 10), - ], -) -def test_get_num_classes(preds, target, num_classes, expected_num_classes): - assert get_num_classes(preds, target, num_classes) == expected_num_classes - - def test_flatten_list(): """Check that _flatten utility function works as expected.""" inp = [[1, 2, 3], [4, 5], [6]] diff --git a/torchmetrics/functional/classification/jaccard.py b/torchmetrics/functional/classification/jaccard.py index e9a44be64f3..0e5f7dc51fc 100644 --- a/torchmetrics/functional/classification/jaccard.py +++ b/torchmetrics/functional/classification/jaccard.py @@ -18,7 +18,6 @@ from typing_extensions import Literal from torchmetrics.functional.classification.confusion_matrix import _confusion_matrix_update -from torchmetrics.utilities.data import get_num_classes from torchmetrics.utilities.distributed import reduce @@ -129,6 +128,5 @@ def jaccard_index( tensor(0.9660) """ - num_classes = get_num_classes(preds=preds, target=target, num_classes=num_classes) confmat = _confusion_matrix_update(preds, target, num_classes, threshold) return _jaccard_from_confmat(confmat, num_classes, ignore_index, absent_score, reduction) diff --git a/torchmetrics/utilities/data.py b/torchmetrics/utilities/data.py index cb4c9c8a609..3351cc82d49 100644 --- a/torchmetrics/utilities/data.py +++ b/torchmetrics/utilities/data.py @@ -145,37 +145,6 @@ def to_categorical(x: Tensor, argmax_dim: int = 1) -> Tensor: return torch.argmax(x, dim=argmax_dim) -def get_num_classes( - preds: Tensor, - target: Tensor, - num_classes: Optional[int] = None, -) -> int: - """Calculates the number of classes for a given prediction and target tensor. - - Args: - preds: predicted values - target: true labels - num_classes: number of classes if known - - Return: - An integer that represents the number of classes. - """ - num_target_classes = int(target.max().detach().item() + 1) - num_pred_classes = int(preds.max().detach().item() + 1) - num_all_classes = max(num_target_classes, num_pred_classes) - - if num_classes is None: - num_classes = num_all_classes - elif num_classes != num_all_classes: - rank_zero_warn( - f"You have set {num_classes} number of classes which is" - f" different from predicted ({num_pred_classes}) and" - f" target ({num_target_classes}) number of classes", - RuntimeWarning, - ) - return num_classes - - def apply_to_collection( data: Any, dtype: Union[type, tuple], From 5e6df2ab0fde318f4ada9c9e76a151d6a3e24e28 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Fri, 25 Mar 2022 07:11:14 +0100 Subject: [PATCH 2/4] linting --- torchmetrics/utilities/data.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torchmetrics/utilities/data.py b/torchmetrics/utilities/data.py index 3351cc82d49..6495d66d779 100644 --- a/torchmetrics/utilities/data.py +++ b/torchmetrics/utilities/data.py @@ -16,8 +16,6 @@ import torch from torch import Tensor, tensor -from torchmetrics.utilities.prints import rank_zero_warn - METRIC_EPS = 1e-6 From 1fa30f1c7ceb7e1e102fd16a83f5e78e182723ad Mon Sep 17 00:00:00 2001 From: rusty1s Date: Fri, 25 Mar 2022 07:22:01 +0100 Subject: [PATCH 3/4] add note to changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2843270fd5b..f7605e0d3e0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -46,7 +46,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed -- Made `num_classes` in `jaccard_index` a required argument ([#853](https://github.com/PyTorchLightning/metrics/pull/853)) +- Made `num_classes` in `jaccard_index` a required argument ([#853](https://github.com/PyTorchLightning/metrics/pull/853), [#914](https://github.com/PyTorchLightning/metrics/pull/914)) - Added normalizer, tokenizer to ROUGE metric ([#838](https://github.com/PyTorchLightning/metrics/pull/838)) From cf17f9a74b32b1f475e7e7d23d963ed831a7a128 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Fri, 25 Mar 2022 08:09:48 +0100 Subject: [PATCH 4/4] remove test --- tests/classification/test_jaccard.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/tests/classification/test_jaccard.py b/tests/classification/test_jaccard.py index 2b60b5f3075..e9b8deb438d 100644 --- a/tests/classification/test_jaccard.py +++ b/tests/classification/test_jaccard.py @@ -234,14 +234,3 @@ def test_jaccard_ignore_index(pred, target, ignore_index, num_classes, reduction reduction=reduction, ) assert torch.allclose(jaccard_val, tensor(expected).to(jaccard_val)) - - -def test_warning_on_difference_in_number_of_classes(): - """Test that warning is thrown if the detected number of classes are different from the the specified number of - classes.""" - preds = torch.randint(3, (10,)) - target = torch.randint(3, (10,)) - with pytest.warns( - RuntimeWarning, - ): - jaccard_index(preds, target, num_classes=4)