Add prefix arg to metric collection (#70)
* prefix arg

* prefix arg

* Apply suggestions from code review

* chlog

* add types

* fix doctest

Co-authored-by: Jirka Borovec <[email protected]>
3 people authored Mar 14, 2021
1 parent 0b9553d commit c32d3e5
Showing 4 changed files with 84 additions and 30 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ### Added

+- Added prefix arg to metric collection ([#70](https://github.com/PyTorchLightning/metrics/pull/70))
+

 ### Changed
44 changes: 24 additions & 20 deletions docs/source/pages/overview.rst
@@ -195,27 +195,31 @@ Example:
 Similarly it can also reduce the amount of code required to log multiple metrics
 inside your LightningModule

-.. code-block:: python
+.. testcode::

-    def __init__(self):
-        ...
-        metrics = pl.metrics.MetricCollection(...)
-        self.train_metrics = metrics.clone()
-        self.valid_metrics = metrics.clone()
-
-    def training_step(self, batch, batch_idx):
-        logits = self(x)
-        ...
-        self.train_metrics(logits, y)
-        # use log_dict instead of log
-        self.log_dict(self.train_metrics, on_step=True, on_epoch=False, prefix='train')
-
-    def validation_step(self, batch, batch_idx):
-        logits = self(x)
-        ...
-        self.valid_metrics(logits, y)
-        # use log_dict instead of log
-        self.log_dict(self.valid_metrics, on_step=True, on_epoch=True, prefix='val')
+    from torchmetrics import Accuracy, MetricCollection, Precision, Recall
+
+    class MyModule():
+        def __init__(self):
+            metrics = MetricCollection(Accuracy(), Precision(), Recall())
+            self.train_metrics = metrics.clone(prefix='train_')
+            self.valid_metrics = metrics.clone(prefix='val_')
+
+        def training_step(self, batch, batch_idx):
+            logits = self(x)
+            # ...
+            output = self.train_metrics(logits, y)
+            # use log_dict instead of log
+            # metrics are logged with keys: train_Accuracy, train_Precision and train_Recall
+            self.log_dict(output)
+
+        def validation_step(self, batch, batch_idx):
+            logits = self(x)
+            # ...
+            output = self.valid_metrics(logits, y)
+            # use log_dict instead of log
+            # metrics are logged with keys: val_Accuracy, val_Precision and val_Recall
+            self.log_dict(output)

 .. note::
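An editorial aside on the documented pattern: the snippet above passes the metrics positionally (`MetricCollection(Accuracy(), Precision(), Recall())`), but the signature added in this commit takes a single list/tuple/dict plus `prefix`, so the hedged sketch below wraps them in a list. The model shape, loss, and metric constructor arguments are illustrative assumptions, not part of the commit:

    import torch
    import pytorch_lightning as pl
    from torchmetrics import Accuracy, MetricCollection, Precision, Recall

    class LitClassifier(pl.LightningModule):  # hypothetical 3-class model, for illustration only
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(16, 3)
            metrics = MetricCollection([
                Accuracy(),
                Precision(num_classes=3, average='macro'),
                Recall(num_classes=3, average='macro'),
            ])
            # one clone per stage; the prefix keeps the logged keys distinct
            self.train_metrics = metrics.clone(prefix='train_')
            self.valid_metrics = metrics.clone(prefix='val_')

        def forward(self, x):
            return self.layer(x)

        def training_step(self, batch, batch_idx):
            x, y = batch
            logits = self(x)
            preds = logits.argmax(dim=-1)  # class predictions; some versions also accept probabilities
            # logged keys: train_Accuracy, train_Precision, train_Recall
            self.log_dict(self.train_metrics(preds, y))
            return torch.nn.functional.cross_entropy(logits, y)

        def validation_step(self, batch, batch_idx):
            x, y = batch
            logits = self(x)
            preds = logits.argmax(dim=-1)
            # logged keys: val_Accuracy, val_Precision, val_Recall
            self.log_dict(self.valid_metrics(preds, y))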
25 changes: 25 additions & 0 deletions tests/bases/test_collections.py
@@ -130,3 +130,28 @@ def test_metric_collection_args_kwargs(tmpdir):
     _ = metric_collection(x=10, y=20)
     assert metric_collection['DummyMetricSum'].x == 10
     assert metric_collection['DummyMetricDiff'].x == -20
+
+
+def test_metric_collection_prefix_arg(tmpdir):
+    """ Test that the prefix arg alters the keywords in the output"""
+    m1 = DummyMetricSum()
+    m2 = DummyMetricDiff()
+    names = ['DummyMetricSum', 'DummyMetricDiff']
+
+    metric_collection = MetricCollection([m1, m2], prefix='prefix_')
+
+    # test forward
+    out = metric_collection(5)
+    for name in names:
+        assert f"prefix_{name}" in out, 'prefix argument not working as intended with forward method'
+
+    # test compute
+    out = metric_collection.compute()
+    for name in names:
+        assert f"prefix_{name}" in out, 'prefix argument not working as intended with compute method'
+
+    # test clone
+    new_metric_collection = metric_collection.clone(prefix='new_prefix_')
+    out = new_metric_collection(5)
+    for name in names:
+        assert f"new_prefix_{name}" in out, 'prefix argument not working as intended with clone method'
43 changes: 33 additions & 10 deletions torchmetrics/collections.py
@@ -13,7 +13,7 @@
 # limitations under the License.

 from copy import deepcopy
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

 from torch import nn
@@ -36,6 +36,8 @@ class MetricCollection(nn.ModuleDict):
             dict as key for output dict. Use this format if you want to chain
             together multiple of the same metric with different parameters.
+        prefix: a string to append in front of the keys of the output dict
+
     Example (input as list):
         >>> import torch
         >>> from torchmetrics import MetricCollection, Accuracy, Precision, Recall
@@ -58,8 +60,11 @@ class MetricCollection(nn.ModuleDict):
         >>> metrics.persistent()
     """

-    def __init__(self, metrics: Union[List[Metric], Tuple[Metric], Dict[str, Metric]]):
+    def __init__(
+        self,
+        metrics: Union[List[Metric], Tuple[Metric], Dict[str, Metric]],
+        prefix: Optional[str] = None
+    ):
         super().__init__()
         if isinstance(metrics, dict):
             # Check all values are metrics
@@ -84,13 +89,15 @@ def __init__(self, metrics: Union[List[Metric], Tuple[Metric], Dict[str, Metric]
         else:
             raise ValueError("Unknown input to MetricCollection.")

+        self.prefix = self._check_prefix_arg(prefix)
+
     def forward(self, *args, **kwargs) -> Dict[str, Any]:  # pylint: disable=E0202
         """
         Iteratively call forward for each metric. Positional arguments (args) will
         be passed to every metric in the collection, while keyword arguments (kwargs)
         will be filtered based on the signature of the individual metric.
         """
-        return {k: m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items()}
+        return {self._set_prefix(k): m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items()}

     def update(self, *args, **kwargs):  # pylint: disable=E0202
         """
@@ -103,20 +110,36 @@ def update(self, *args, **kwargs):  # pylint: disable=E0202
             m.update(*args, **m_kwargs)

     def compute(self) -> Dict[str, Any]:
-        return {k: m.compute() for k, m in self.items()}
+        return {self._set_prefix(k): m.compute() for k, m in self.items()}

-    def reset(self):
+    def reset(self) -> None:
         """ Iteratively call reset for each metric """
         for _, m in self.items():
             m.reset()

-    def clone(self):
-        """ Make a copy of the metric collection """
-        return deepcopy(self)
+    def clone(self, prefix: Optional[str] = None) -> 'MetricCollection':
+        """ Make a copy of the metric collection
+
+        Args:
+            prefix: a string to append in front of the metric keys
+        """
+        mc = deepcopy(self)
+        mc.prefix = self._check_prefix_arg(prefix)
+        return mc

-    def persistent(self, mode: bool = True):
+    def persistent(self, mode: bool = True) -> None:
         """Method for post-init to change if metric states should be saved to
         its state_dict
         """
         for _, m in self.items():
             m.persistent(mode)
+
+    def _set_prefix(self, k: str) -> str:
+        return k if self.prefix is None else self.prefix + k
+
+    def _check_prefix_arg(self, prefix: str) -> Optional[str]:
+        if prefix is not None:
+            if isinstance(prefix, str):
+                return prefix
+            else:
+                raise ValueError('Expected input `prefix` to be a string')
+        return None
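Finally, a hedged sketch of the clone path and the prefix validation added above (a plain `Accuracy()` metric is assumed to be constructible, as in this release; printed values follow from the inputs):

    import torch
    from torchmetrics import Accuracy, MetricCollection

    collection = MetricCollection([Accuracy()], prefix='train_')
    val_collection = collection.clone(prefix='val_')  # deep copy with a new key prefix

    preds, target = torch.tensor([0, 1, 1]), torch.tensor([0, 1, 0])
    print(collection(preds, target))      # {'train_Accuracy': tensor(0.6667)}
    print(val_collection(preds, target))  # {'val_Accuracy': tensor(0.6667)}

    try:
        collection.clone(prefix=123)  # non-string prefix is rejected by _check_prefix_arg
    except ValueError as err:
        print(err)  # Expected input `prefix` to be a string

Note that, as written, `clone()` without a `prefix` resets the copy's prefix to `None` rather than inheriting the original's.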
