From be080dd22e77631290198f9f050526bd8bc83296 Mon Sep 17 00:00:00 2001
From: Brendan Folie
Date: Tue, 22 Oct 2024 02:42:24 -0700
Subject: [PATCH] retain type of `_modules` when renaming keys (#2793)

Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: Jirka B
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Nicki Skafte
---
 CHANGELOG.md                               |  2 +-
 src/torchmetrics/collections.py            | 15 +++++++------
 src/torchmetrics/utilities/imports.py      |  1 +
 tests/unittests/wrappers/test_multitask.py | 25 +++++++++++++---------
 4 files changed, 25 insertions(+), 18 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index e439d426ae2..8736f561615 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -27,7 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Fixed
 
--
+- Fixed `MetricCollection` failing when `_modules` changed from `OrderedDict` to plain `dict` in PyTorch 2.5 ([#2793](https://github.com/Lightning-AI/torchmetrics/pull/2793))
 
 
 ---

diff --git a/src/torchmetrics/collections.py b/src/torchmetrics/collections.py
index 85cfe251537..9fe0bb40761 100644
--- a/src/torchmetrics/collections.py
+++ b/src/torchmetrics/collections.py
@@ -14,7 +14,7 @@
 # this is just a bypass for this module name collision with built-in one
 from collections import OrderedDict
 from copy import deepcopy
-from typing import Any, Dict, Hashable, Iterable, Iterator, List, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, Hashable, Iterable, Iterator, List, Mapping, Optional, Sequence, Tuple, Union
 
 import torch
 from torch import Tensor
@@ -499,11 +499,12 @@ def _set_name(self, base: str) -> str:
         name = base if self.prefix is None else self.prefix + base
         return name if self.postfix is None else name + self.postfix
 
-    def _to_renamed_ordered_dict(self) -> OrderedDict:
-        od = OrderedDict()
+    def _to_renamed_dict(self) -> Mapping[str, Metric]:
+        # self._modules changed from OrderedDict to dict as of PyTorch 2.5.0
+        dict_modules = OrderedDict() if isinstance(self._modules, OrderedDict) else {}
         for k, v in self._modules.items():
-            od[self._set_name(k)] = v
-        return od
+            dict_modules[self._set_name(k)] = v
+        return dict_modules
 
     def __iter__(self) -> Iterator[Hashable]:
         """Return an iterator over the keys of the MetricDict."""
@@ -519,7 +520,7 @@ def keys(self, keep_base: bool = False) -> Iterable[Hashable]:
         """
         if keep_base:
             return self._modules.keys()
-        return self._to_renamed_ordered_dict().keys()
+        return self._to_renamed_dict().keys()
 
     def items(self, keep_base: bool = False, copy_state: bool = True) -> Iterable[Tuple[str, Metric]]:
         r"""Return an iterable of the ModuleDict key/value pairs.
@@ -533,7 +534,7 @@ def items(self, keep_base: bool = False, copy_state: bool = True) -> Iterable[Tu
             self._compute_groups_create_state_ref(copy_state)
         if keep_base:
             return self._modules.items()
-        return self._to_renamed_ordered_dict().items()
+        return self._to_renamed_dict().items()
 
     def values(self, copy_state: bool = True) -> Iterable[Metric]:
         """Return an iterable of the ModuleDict values.
diff --git a/src/torchmetrics/utilities/imports.py b/src/torchmetrics/utilities/imports.py
index 5f6fb7001a5..28bda373600 100644
--- a/src/torchmetrics/utilities/imports.py
+++ b/src/torchmetrics/utilities/imports.py
@@ -21,6 +21,7 @@
 _PYTHON_VERSION = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
 _TORCH_GREATER_EQUAL_2_1 = RequirementCache("torch>=2.1.0")
 _TORCH_GREATER_EQUAL_2_2 = RequirementCache("torch>=2.2.0")
+_TORCH_GREATER_EQUAL_2_5 = RequirementCache("torch>=2.5.0")
 _TORCHMETRICS_GREATER_EQUAL_1_6 = RequirementCache("torchmetrics>=1.7.0")
 
 _NLTK_AVAILABLE = RequirementCache("nltk")
diff --git a/tests/unittests/wrappers/test_multitask.py b/tests/unittests/wrappers/test_multitask.py
index 069a4472d64..fb3ae8987cc 100644
--- a/tests/unittests/wrappers/test_multitask.py
+++ b/tests/unittests/wrappers/test_multitask.py
@@ -19,6 +19,7 @@
 from torchmetrics import MetricCollection
 from torchmetrics.classification import BinaryAccuracy, BinaryF1Score
 from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError
+from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_5
 from torchmetrics.wrappers import MultitaskWrapper
 
 from unittests import BATCH_SIZE, NUM_BATCHES
@@ -90,13 +91,15 @@ def test_error_on_wrong_keys():
         "Classification": BinaryAccuracy(),
     })
 
+    order_dict = "" if _TORCH_GREATER_EQUAL_2_5 else "o"
+
     with pytest.raises(
         ValueError,
         match=re.escape(
-            "Expected arguments `task_preds` and `task_targets` to have the same keys as the wrapped `task_metrics`. "
-            "Found task_preds.keys() = dict_keys(['Classification']), task_targets.keys() = "
-            "dict_keys(['Classification', 'Regression']) and self.task_metrics.keys() = "
-            "odict_keys(['Classification', 'Regression'])"
+            "Expected arguments `task_preds` and `task_targets` to have the same keys as the wrapped `task_metrics`."
+            " Found task_preds.keys() = dict_keys(['Classification']),"
+            " task_targets.keys() = dict_keys(['Classification', 'Regression'])"
+            f" and self.task_metrics.keys() = {order_dict}dict_keys(['Classification', 'Regression'])"
         ),
     ):
         multitask_metrics.update(wrong_key_preds, _multitask_targets)
@@ -104,9 +107,10 @@
     with pytest.raises(
         ValueError,
         match=re.escape(
-            "Expected arguments `task_preds` and `task_targets` to have the same keys as the wrapped `task_metrics`. "
-            "Found task_preds.keys() = dict_keys(['Classification', 'Regression']), task_targets.keys() = "
-            "dict_keys(['Classification']) and self.task_metrics.keys() = odict_keys(['Classification', 'Regression'])"
+            "Expected arguments `task_preds` and `task_targets` to have the same keys as the wrapped `task_metrics`."
+            " Found task_preds.keys() = dict_keys(['Classification', 'Regression']),"
+            " task_targets.keys() = dict_keys(['Classification'])"
+            f" and self.task_metrics.keys() = {order_dict}dict_keys(['Classification', 'Regression'])"
         ),
     ):
         multitask_metrics.update(_multitask_preds, wrong_key_targets)
@@ -114,9 +118,10 @@
     with pytest.raises(
         ValueError,
         match=re.escape(
-            "Expected arguments `task_preds` and `task_targets` to have the same keys as the wrapped `task_metrics`. "
-            "Found task_preds.keys() = dict_keys(['Classification', 'Regression']), task_targets.keys() = "
-            "dict_keys(['Classification', 'Regression']) and self.task_metrics.keys() = odict_keys(['Classification'])"
+            "Expected arguments `task_preds` and `task_targets` to have the same keys as the wrapped `task_metrics`."
+ " Found task_preds.keys() = dict_keys(['Classification', 'Regression'])," + " task_targets.keys() = dict_keys(['Classification', 'Regression'])" + f" and self.task_metrics.keys() = {order_dict}dict_keys(['Classification'])" ), ): wrong_key_multitask_metrics.update(_multitask_preds, _multitask_targets)