From e8a0a5a02c1c7b39972a6e4e9b7d2f674c0a74a3 Mon Sep 17 00:00:00 2001
From: SkafteNicki
Date: Tue, 10 May 2022 14:06:01 +0200
Subject: [PATCH 1/2] allclose

---
 tests/bases/test_collections.py | 30 +++++++++++++++++++++++++++++-
 torchmetrics/collections.py     |  6 +++---
 torchmetrics/utilities/data.py  |  7 +++++++
 3 files changed, 39 insertions(+), 4 deletions(-)

diff --git a/tests/bases/test_collections.py b/tests/bases/test_collections.py
index 884b0fdc3d9..230c045e203 100644
--- a/tests/bases/test_collections.py
+++ b/tests/bases/test_collections.py
@@ -20,7 +20,17 @@

 from tests.helpers import seed_all
 from tests.helpers.testers import DummyMetricDiff, DummyMetricSum
-from torchmetrics import Accuracy, CohenKappa, ConfusionMatrix, F1Score, Metric, MetricCollection, Precision, Recall
+from torchmetrics import (
+    Accuracy,
+    CohenKappa,
+    ConfusionMatrix,
+    F1Score,
+    MatthewsCorrCoef,
+    Metric,
+    MetricCollection,
+    Precision,
+    Recall,
+)

 seed_all(42)

@@ -406,6 +416,24 @@ def test_compute_group_define_by_user():
     assert m.compute()


+def test_compute_on_different_dtype():
+    """Check that extraction of compute groups are robust towards difference in dtype."""
+    m = MetricCollection(
+        [
+            ConfusionMatrix(num_classes=3),
+            MatthewsCorrCoef(num_classes=3),
+        ]
+    )
+    assert not m._groups_checked
+    assert m.compute_groups == {0: ["ConfusionMatrix"], 1: ["MatthewsCorrCoef"]}
+    preds = torch.randn(10, 3).softmax(dim=-1)
+    target = torch.randint(3, (10,))
+    for _ in range(2):
+        m.update(preds, target)
+    assert m.compute_groups == {0: ["ConfusionMatrix", "MatthewsCorrCoef"]}
+    assert m.compute()
+
+
 def test_error_on_wrong_specified_compute_groups():
     """Test that error is raised if user mis-specify the compute groups."""
     with pytest.raises(ValueError, match="Input Accuracy in `compute_groups`.*"):
diff --git a/torchmetrics/collections.py b/torchmetrics/collections.py
index 5132c9b1230..ffbc60aa24f 100644
--- a/torchmetrics/collections.py
+++ b/torchmetrics/collections.py
@@ -20,7 +20,7 @@

 from torchmetrics.metric import Metric
 from torchmetrics.utilities import rank_zero_warn
-from torchmetrics.utilities.data import _flatten_dict
+from torchmetrics.utilities.data import _flatten_dict, allclose

 # this is just a bypass for this module name collision with build-in one
 from torchmetrics.utilities.imports import OrderedDict
@@ -231,10 +231,10 @@ def _equal_metric_states(metric1: Metric, metric2: Metric) -> bool:
                 return False

             if isinstance(state1, Tensor) and isinstance(state2, Tensor):
-                return state1.shape == state2.shape and torch.allclose(state1, state2)
+                return state1.shape == state2.shape and allclose(state1, state2)

             if isinstance(state1, list) and isinstance(state2, list):
-                return all(s1.shape == s2.shape and torch.allclose(s1, s2) for s1, s2 in zip(state1, state2))
+                return all(s1.shape == s2.shape and allclose(s1, s2) for s1, s2 in zip(state1, state2))

         return True

diff --git a/torchmetrics/utilities/data.py b/torchmetrics/utilities/data.py
index 5a991138105..fb2409ad4bb 100644
--- a/torchmetrics/utilities/data.py
+++ b/torchmetrics/utilities/data.py
@@ -248,3 +248,10 @@ def _bincount(x: Tensor, minlength: Optional[int] = None) -> Tensor:
         return output
     else:
         return torch.bincount(x, minlength=minlength)
+
+
+def allclose(tensor1: Tensor, tensor2: Tensor) -> bool:
+    """Wrapper of torch.allclose that is robust towards dtype difference."""
+    if tensor1.dtype != tensor2.dtype:
+        tensor2 = tensor2.to(dtype=tensor1.dtype)
+    return torch.allclose(tensor1, tensor2)
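Context for the change above (not part of the patch): `torch.allclose` refuses to compare tensors of different dtypes, so `_equal_metric_states` could fail while compute groups were being detected, e.g. when one metric accumulates an integer state and another a floating-point one. Below is a minimal, standalone sketch of the behaviour the new `allclose` wrapper provides; the example shapes and dtypes are illustrative assumptions.

import torch

def allclose(tensor1: torch.Tensor, tensor2: torch.Tensor) -> bool:
    """Same helper as added to torchmetrics/utilities/data.py: align dtypes, then delegate."""
    if tensor1.dtype != tensor2.dtype:
        tensor2 = tensor2.to(dtype=tensor1.dtype)
    return torch.allclose(tensor1, tensor2)

state1 = torch.zeros(3, 3)                    # e.g. a float32 metric state
state2 = torch.zeros(3, 3, dtype=torch.long)  # e.g. an int64 confusion-matrix state

try:
    torch.allclose(state1, state2)  # on most torch versions this raises a RuntimeError (dtype mismatch)
except RuntimeError as err:
    print(f"torch.allclose failed: {err}")

print(allclose(state1, state2))  # True: the second state is cast to the first state's dtype before comparing

Casting one state to the other's dtype is sufficient here because compute-group detection only needs to know whether the accumulated states are numerically identical, not to preserve either dtype.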
From e9e4d633f51200634dc30a79220f3d806baa8588 Mon Sep 17 00:00:00 2001
From: SkafteNicki
Date: Tue, 10 May 2022 14:10:56 +0200
Subject: [PATCH 2/2] changelog

---
 CHANGELOG.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index f4d6354c808..ac6e30b9f4e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -58,6 +58,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 - Fixed non-empty state dict for a few metrics ([#1012](https://github.com/PyTorchLightning/metrics/pull/1012))

+- Fixed bug when comparing states while finding compute groups ([#1022](https://github.com/PyTorchLightning/metrics/pull/1022))
+

 ## [0.8.2] - 2022-05-06