Skip to content

Commit

Permalink
retain type of _modules when renaming keys (#2793)
Browse files Browse the repository at this point in the history
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Jirka B <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Nicki Skafte <[email protected]>
(cherry picked from commit be080dd)
  • Loading branch information
bfolie authored and Borda committed Oct 22, 2024
1 parent 1591bd9 commit b1a2075
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 18 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

-
- Changing `_modules` dict type in Pytorch 2.5 preventing to fail collections metrics ([#2793](https://github.com/Lightning-AI/torchmetrics/pull/2793))


---
Expand Down
15 changes: 8 additions & 7 deletions src/torchmetrics/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# this is just a bypass for this module name collision with built-in one
from collections import OrderedDict
from copy import deepcopy
from typing import Any, Dict, Hashable, Iterable, Iterator, List, Optional, Sequence, Tuple, Union
from typing import Any, Dict, Hashable, Iterable, Iterator, List, Mapping, Optional, Sequence, Tuple, Union

import torch
from torch import Tensor
Expand Down Expand Up @@ -499,11 +499,12 @@ def _set_name(self, base: str) -> str:
name = base if self.prefix is None else self.prefix + base
return name if self.postfix is None else name + self.postfix

def _to_renamed_ordered_dict(self) -> OrderedDict:
od = OrderedDict()
def _to_renamed_dict(self) -> Mapping[str, Metric]:
# self._modules changed from OrderedDict to dict as of PyTorch 2.5.0
dict_modules = OrderedDict() if isinstance(self._modules, OrderedDict) else {}
for k, v in self._modules.items():
od[self._set_name(k)] = v
return od
dict_modules[self._set_name(k)] = v
return dict_modules

def __iter__(self) -> Iterator[Hashable]:
"""Return an iterator over the keys of the MetricDict."""
Expand All @@ -519,7 +520,7 @@ def keys(self, keep_base: bool = False) -> Iterable[Hashable]:
"""
if keep_base:
return self._modules.keys()
return self._to_renamed_ordered_dict().keys()
return self._to_renamed_dict().keys()

def items(self, keep_base: bool = False, copy_state: bool = True) -> Iterable[Tuple[str, Metric]]:
r"""Return an iterable of the ModuleDict key/value pairs.
Expand All @@ -533,7 +534,7 @@ def items(self, keep_base: bool = False, copy_state: bool = True) -> Iterable[Tu
self._compute_groups_create_state_ref(copy_state)
if keep_base:
return self._modules.items()
return self._to_renamed_ordered_dict().items()
return self._to_renamed_dict().items()

def values(self, copy_state: bool = True) -> Iterable[Metric]:
"""Return an iterable of the ModuleDict values.
Expand Down
1 change: 1 addition & 0 deletions src/torchmetrics/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
_TORCH_GREATER_EQUAL_2_0 = RequirementCache("torch>=2.0.0")
_TORCH_GREATER_EQUAL_2_1 = RequirementCache("torch>=2.1.0")
_TORCH_GREATER_EQUAL_2_2 = RequirementCache("torch>=2.2.0")
_TORCH_GREATER_EQUAL_2_5 = RequirementCache("torch>=2.5.0")
_TORCHMETRICS_GREATER_EQUAL_1_6 = RequirementCache("torchmetrics>=1.7.0")

_NLTK_AVAILABLE = RequirementCache("nltk")
Expand Down
25 changes: 15 additions & 10 deletions tests/unittests/wrappers/test_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from torchmetrics import MetricCollection
from torchmetrics.classification import BinaryAccuracy, BinaryF1Score
from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_5
from torchmetrics.wrappers import MultitaskWrapper

from unittests import BATCH_SIZE, NUM_BATCHES
Expand Down Expand Up @@ -90,33 +91,37 @@ def test_error_on_wrong_keys():
"Classification": BinaryAccuracy(),
})

order_dict = "" if _TORCH_GREATER_EQUAL_2_5 else "o"

with pytest.raises(
ValueError,
match=re.escape(
"Expected arguments `task_preds` and `task_targets` to have the same keys as the wrapped `task_metrics`. "
"Found task_preds.keys() = dict_keys(['Classification']), task_targets.keys() = "
"dict_keys(['Classification', 'Regression']) and self.task_metrics.keys() = "
"odict_keys(['Classification', 'Regression'])"
"Expected arguments `task_preds` and `task_targets` to have the same keys as the wrapped `task_metrics`."
" Found task_preds.keys() = dict_keys(['Classification']),"
" task_targets.keys() = dict_keys(['Classification', 'Regression'])"
f" and self.task_metrics.keys() = {order_dict}dict_keys(['Classification', 'Regression'])"
),
):
multitask_metrics.update(wrong_key_preds, _multitask_targets)

with pytest.raises(
ValueError,
match=re.escape(
"Expected arguments `task_preds` and `task_targets` to have the same keys as the wrapped `task_metrics`. "
"Found task_preds.keys() = dict_keys(['Classification', 'Regression']), task_targets.keys() = "
"dict_keys(['Classification']) and self.task_metrics.keys() = odict_keys(['Classification', 'Regression'])"
"Expected arguments `task_preds` and `task_targets` to have the same keys as the wrapped `task_metrics`."
" Found task_preds.keys() = dict_keys(['Classification', 'Regression']),"
" task_targets.keys() = dict_keys(['Classification'])"
f" and self.task_metrics.keys() = {order_dict}dict_keys(['Classification', 'Regression'])"
),
):
multitask_metrics.update(_multitask_preds, wrong_key_targets)

with pytest.raises(
ValueError,
match=re.escape(
"Expected arguments `task_preds` and `task_targets` to have the same keys as the wrapped `task_metrics`. "
"Found task_preds.keys() = dict_keys(['Classification', 'Regression']), task_targets.keys() = "
"dict_keys(['Classification', 'Regression']) and self.task_metrics.keys() = odict_keys(['Classification'])"
"Expected arguments `task_preds` and `task_targets` to have the same keys as the wrapped `task_metrics`."
" Found task_preds.keys() = dict_keys(['Classification', 'Regression']),"
" task_targets.keys() = dict_keys(['Classification', 'Regression'])"
f" and self.task_metrics.keys() = {order_dict}dict_keys(['Classification'])"
),
):
wrong_key_multitask_metrics.update(_multitask_preds, _multitask_targets)
Expand Down

0 comments on commit b1a2075

Please sign in to comment.