Skip to content

Commit

Permalink
[bugfix] MetricCollection should return metrics with prefix on items(…
Browse files Browse the repository at this point in the history
…), keys() (#209)

* update

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update print

* resolve flake8

* update

* CI: update pre-commit (#207)

* update pre-commit

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Apply suggestions from code review

Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <[email protected]>

* fix setup imports (#208)

* fix setup imports

* pkg

* chlog

* format

* update

* update

* check type

* update on comments

* update on comments

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Nicki Skafte <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: jirka <[email protected]>
  • Loading branch information
6 people authored May 4, 2021
1 parent 7a9e50a commit 67ce961
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 6 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

- `MetricCollection` should return metrics with prefix on `items()`, `keys()` ([#209](https://github.com/PyTorchLightning/metrics/pull/209))


- Calling `compute` before `update` will now give an warning ([#164](https://github.com/PyTorchLightning/metrics/pull/164))


Expand Down
47 changes: 47 additions & 0 deletions tests/bases/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,13 +168,60 @@ def test_metric_collection_prefix_postfix_args(prefix, postfix):
for name in names:
assert f"new_prefix_{name}" in out, 'prefix argument not working as intended with clone method'

for k, _ in new_metric_collection.items():
assert 'new_prefix_' in k

for k in new_metric_collection.keys():
assert 'new_prefix_' in k

for k, _ in new_metric_collection.items(keep_base=True):
assert 'new_prefix_' not in k

for k in new_metric_collection.keys(keep_base=True):
assert 'new_prefix_' not in k

assert type(new_metric_collection.keys(keep_base=True)) == type(new_metric_collection.keys(keep_base=False)) # noqa E721
assert type(new_metric_collection.items(keep_base=True)) == type(new_metric_collection.items(keep_base=False)) # noqa E721

new_metric_collection = new_metric_collection.clone(postfix='_new_postfix')
out = new_metric_collection(5)
names = [n[:-len(postfix)] if postfix is not None else n for n in names] # strip away old postfix
for name in names:
assert f"new_prefix_{name}_new_postfix" in out, 'postfix argument not working as intended with clone method'


def test_metric_collection_repr():
"""
Test MetricCollection
"""

class A(DummyMetricSum):
pass

class B(DummyMetricDiff):
pass

m1 = A()
m2 = B()
metric_collection = MetricCollection([m1, m2], prefix=None, postfix=None)

expected = "MetricCollection(\n (A): A()\n (B): B()\n)"
assert metric_collection.__repr__() == expected

metric_collection = MetricCollection([m1, m2], prefix="a", postfix=None)

expected = 'MetricCollection(\n (A): A()\n (B): B(),\n prefix=a\n)'
assert metric_collection.__repr__() == expected

metric_collection = MetricCollection([m1, m2], prefix=None, postfix="a")
expected = 'MetricCollection(\n (A): A()\n (B): B(),\n postfix=a\n)'
assert metric_collection.__repr__() == expected

metric_collection = MetricCollection([m1, m2], prefix="a", postfix="b")
expected = 'MetricCollection(\n (A): A()\n (B): B(),\n prefix=a,\n postfix=b\n)'
assert metric_collection.__repr__() == expected


def test_metric_collection_same_order():
m1 = DummyMetricSum()
m2 = DummyMetricDiff()
Expand Down
45 changes: 39 additions & 6 deletions torchmetrics/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict
from copy import deepcopy
from typing import Any, Dict, Optional, Sequence, Union
from typing import Any, Dict, Iterable, Optional, Sequence, Tuple, Union

from torch import nn

Expand Down Expand Up @@ -144,24 +145,24 @@ def forward(self, *args, **kwargs) -> Dict[str, Any]: # pylint: disable=E0202
be passed to every metric in the collection, while keyword arguments (kwargs)
will be filtered based on the signature of the individual metric.
"""
return {self._set_name(k): m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items()}
return {k: m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items()}

def update(self, *args, **kwargs): # pylint: disable=E0202
"""
Iteratively call update for each metric. Positional arguments (args) will
be passed to every metric in the collection, while keyword arguments (kwargs)
will be filtered based on the signature of the individual metric.
"""
for _, m in self.items():
for _, m in self.items(keep_base=True):
m_kwargs = m._filter_kwargs(**kwargs)
m.update(*args, **m_kwargs)

def compute(self) -> Dict[str, Any]:
return {self._set_name(k): m.compute() for k, m in self.items()}
return {k: m.compute() for k, m in self.items()}

def reset(self) -> None:
""" Iteratively call reset for each metric """
for _, m in self.items():
for _, m in self.items(keep_base=True):
m.reset()

def clone(self, prefix: Optional[str] = None, postfix: Optional[str] = None) -> 'MetricCollection':
Expand All @@ -182,16 +183,48 @@ def persistent(self, mode: bool = True) -> None:
"""Method for post-init to change if metric states should be saved to
its state_dict
"""
for _, m in self.items():
for _, m in self.items(keep_base=True):
m.persistent(mode)

def _set_name(self, base: str) -> str:
name = base if self.prefix is None else self.prefix + base
name = name if self.postfix is None else name + self.postfix
return name

def _to_renamed_ordered_dict(self) -> OrderedDict:
od = OrderedDict()
for k, v in self._modules.items():
od[self._set_name(k)] = v
return od

def keys(self, keep_base: bool = False):
r"""Return an iterable of the ModuleDict key.
Args:
keep_base: Whether to add prefix/postfix on the items collection.
"""
if keep_base:
return self._modules.keys()
return self._to_renamed_ordered_dict().keys()

def items(self, keep_base: bool = False) -> Iterable[Tuple[str, nn.Module]]:
r"""Return an iterable of the ModuleDict key/value pairs.
Args:
keep_base: Whether to add prefix/postfix on the items collection.
"""
if keep_base:
return self._modules.items()
return self._to_renamed_ordered_dict().items()

@staticmethod
def _check_arg(arg: Optional[str], name: str) -> Optional[str]:
if arg is None or isinstance(arg, str):
return arg
raise ValueError(f'Expected input `{name}` to be a string, but got {type(arg)}')

def __repr__(self) -> Optional[str]:
repr = super().__repr__()[:-2]
if self.prefix:
repr += f",\n prefix={self.prefix}{',' if self.postfix else ''}"
if self.postfix:
repr += f"{',' if not self.prefix else ''}\n postfix={self.postfix}"
return repr + "\n)"

0 comments on commit 67ce961

Please sign in to comment.