Commit: [Metrics] MetricCollection (#4318)
* docs + precision + recall + f_beta + refactor

Co-authored-by: Teddy Koker <[email protected]>

* rebase

Co-authored-by: Teddy Koker <[email protected]>

* fixes

Co-authored-by: Teddy Koker <[email protected]>

* added missing file

* docs

* docs

* extra import

* add metric collection

* add docs + integration with log_dict

* add test

* update

* update

* more test

* more test

* pep8

* fix doctest

* pep8

* add clone method

* add clone method

* merge-2

* changelog

* kwargs filtering and tests

* pep8

* fix test

* update docs

* Update docs/source/metrics.rst

Co-authored-by: Roger Shieh <[email protected]>

* fix docs

* fix tests

* Apply suggestions from code review

Co-authored-by: Jirka Borovec <[email protected]>

* fix docs

* fix doctest

* fix doctest

* fix doctest

* fix doctest

Co-authored-by: ananyahjha93 <[email protected]>
Co-authored-by: Teddy Koker <[email protected]>
Co-authored-by: Nicki Skafte <[email protected]>
Co-authored-by: Roger Shieh <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
6 people authored Jan 8, 2021
1 parent 06f3609 commit 06668c0
Showing 6 changed files with 384 additions and 14 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -33,6 +33,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `image_gradients` functional metric to compute the image gradients of a given input image. ([#5056](https://github.com/PyTorchLightning/pytorch-lightning/pull/5056))


- Added `MetricCollection` ([#4318](https://github.com/PyTorchLightning/pytorch-lightning/pull/4318))


- Added `.clone()` method to metrics ([#4318](https://github.com/PyTorchLightning/pytorch-lightning/pull/4318))


### Changed

- Changed `automatic casting` for LoggerConnector `metrics` ([#5218](https://github.com/PyTorchLightning/pytorch-lightning/pull/5218))
77 changes: 75 additions & 2 deletions docs/source/metrics.rst
@@ -81,6 +81,7 @@ If ``on_epoch`` is True, the logger automatically logs the end of epoch metric value
self.log('valid_acc', self.valid_acc, on_step=True, on_epoch=True)
.. note::

If using metrics in data parallel mode (dp), the metric update/logging should be done
in the ``<mode>_step_end`` method (where ``<mode>`` is either ``training``, ``validation``
or ``test``). This is because metric states would otherwise be destroyed after each forward pass,
@@ -99,7 +100,6 @@ If ``on_epoch`` is True, the logger automatically logs the end of epoch metric value
self.metric(outputs['preds'], outputs['target'])
self.log('metric', self.metric)
This metrics API is independent of PyTorch Lightning. Metrics can directly be used in PyTorch as shown in the example:

.. code-block:: python
@@ -131,7 +131,17 @@ This metrics API is independent of PyTorch Lightning. Metrics can directly be used
Metrics contain internal states that keep track of the data seen so far.
Do not mix metric states across training, validation and testing.
It is highly recommended to re-initialize the metric per mode as
shown in the examples above. To easily initialize the same metric multiple
times, the ``.clone()`` method can be used:

.. testcode::

    def __init__(self):
        ...
        metric = pl.metrics.Accuracy()
        self.train_acc = metric.clone()
        self.val_acc = metric.clone()
        self.test_acc = metric.clone()
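
Since ``.clone()`` returns a deep copy of the metric, each mode gets its own
independent state while sharing the same configuration.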

.. note::

@@ -240,6 +250,69 @@ In practice this means that:
val = metric(pred, target) # this value can be backpropagated
val = metric.compute() # this value cannot be backpropagated
****************
MetricCollection
****************

In many cases it is beneficial to evaluate the model output with multiple metrics.
In such cases, the ``MetricCollection`` class comes in handy. It accepts a sequence
of metrics and wraps these into a single callable metric class with the same
interface as any other metric.

Example:

.. testcode::

    from pytorch_lightning.metrics import MetricCollection, Accuracy, Precision, Recall

    target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2])
    preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2])
    metric_collection = MetricCollection([
        Accuracy(),
        Precision(num_classes=3, average='macro'),
        Recall(num_classes=3, average='macro')
    ])
    print(metric_collection(preds, target))

.. testoutput::
    :options: +NORMALIZE_WHITESPACE

    {'Accuracy': tensor(0.1250),
     'Precision': tensor(0.0667),
     'Recall': tensor(0.1111)}
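
Since the collection has the same interface as a single metric, the usual stateful
``update``/``compute`` pattern also works on the whole collection at once. A minimal
sketch, reusing the tensors defined above:

.. code-block:: python

    metric_collection.reset()  # clear any previously accumulated state
    for p, t in zip(preds.chunk(2), target.chunk(2)):
        metric_collection.update(p, t)  # accumulate state without computing per-step values
    print(metric_collection.compute())  # same dict of results, over everything seen since reset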

Similarly, it can also reduce the amount of code required to log multiple metrics
inside your LightningModule:

.. code-block:: python

    def __init__(self):
        ...
        metrics = pl.metrics.MetricCollection(...)
        self.train_metrics = metrics.clone()
        self.valid_metrics = metrics.clone()

    def training_step(self, batch, batch_idx):
        logits = self(x)
        ...
        self.train_metrics(logits, y)
        # use log_dict instead of log
        self.log_dict(self.train_metrics, on_step=True, on_epoch=False, prefix='train')

    def validation_step(self, batch, batch_idx):
        logits = self(x)
        ...
        self.valid_metrics(logits, y)
        # use log_dict instead of log
        self.log_dict(self.valid_metrics, on_step=True, on_epoch=True, prefix='val')
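
Passing the metric objects themselves to ``log_dict`` (rather than pre-computed
values) lets Lightning log the end-of-epoch value of each metric automatically,
just as when logging a single metric with ``log``.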
.. note::

By default, ``MetricCollection`` assumes that all the metrics in the collection
have the same call signature. If this is not the case, input that should be
given to different metrics can be given as keyword arguments to the collection.
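
For example, a keyword argument that only some metrics accept is forwarded just to
those metrics. A minimal sketch, where ``WeightedAccuracy`` is a hypothetical custom
metric (not part of the library) whose ``update`` takes an extra ``weights`` argument:

.. code-block:: python

    import torch
    from pytorch_lightning.metrics import Metric, MetricCollection, Accuracy

    class WeightedAccuracy(Metric):
        def __init__(self):
            super().__init__()
            self.add_state('correct', default=torch.tensor(0.), dist_reduce_fx='sum')
            self.add_state('total', default=torch.tensor(0.), dist_reduce_fx='sum')

        def update(self, preds, target, weights):
            self.correct += (weights * (preds == target)).sum()
            self.total += weights.sum()

        def compute(self):
            return self.correct / self.total

    collection = MetricCollection([Accuracy(), WeightedAccuracy()])
    # ``weights`` is routed only to ``WeightedAccuracy``; ``Accuracy`` never sees it,
    # since ``weights`` does not appear in its ``update`` signature
    collection(preds, target, weights=torch.ones_like(preds, dtype=torch.float))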

.. autoclass:: pytorch_lightning.metrics.MetricCollection
:noindex:

**********
Metric API
2 changes: 1 addition & 1 deletion pytorch_lightning/metrics/__init__.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.metrics.metric import Metric # noqa: F401
from pytorch_lightning.metrics.metric import Metric, MetricCollection # noqa: F401

from pytorch_lightning.metrics.classification import ( # noqa: F401
Accuracy,
117 changes: 112 additions & 5 deletions pytorch_lightning/metrics/metric.py
@@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import inspect
from abc import ABC, abstractmethod
from collections.abc import Sequence
from copy import deepcopy
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from torch import nn
@@ -57,6 +58,7 @@ class Metric(nn.Module, ABC):
Callback that performs the allgather operation on the metric state. When `None`, DDP
will be used to perform the allgather. default: None
"""

def __init__(
self,
compute_on_step: bool = True,
@@ -72,6 +74,7 @@ def __init__(
self.dist_sync_fn = dist_sync_fn
self._to_sync = True

# cache the ``update`` signature; ``MetricCollection`` uses it to filter keyword arguments
self._update_signature = inspect.signature(self.update)
self.update = self._wrap_update(self.update)
self.compute = self._wrap_compute(self.compute)
self._computed = None
@@ -120,7 +123,7 @@ def add_state(
"""
if (
not isinstance(default, torch.Tensor)
and not isinstance(default, list)  # noqa: W503
or (isinstance(default, list) and len(default) != 0) # noqa: W503
):
raise ValueError(
@@ -208,9 +211,11 @@ def wrapped_func(*args, **kwargs):
return self._computed

dist_sync_fn = self.dist_sync_fn
if (
    dist_sync_fn is None
    and torch.distributed.is_available()
    and torch.distributed.is_initialized()
):
# User provided a bool, so we assume DDP if available
dist_sync_fn = gather_all_tensors

@@ -250,6 +255,10 @@ def reset(self):
else:
setattr(self, attr, deepcopy(default))

def clone(self):
""" Make a copy of the metric """
return deepcopy(self)

def __getstate__(self):
# ignore update and compute functions for pickling
return {k: v for k, v in self.__dict__.items() if k not in ["update", "compute"]}
@@ -292,3 +301,101 @@ def state_dict(self, *args, **kwargs):
current_val = getattr(self, key)
state_dict.update({key: current_val})
return state_dict


class MetricCollection(nn.ModuleDict):
"""
MetricCollection class can be used to chain metrics that have the same
call pattern into one single class.
Args:
metrics: One of the following
* list or tuple: if metrics are passed in as a list, will use the
metrics class name as key for output dict. Therefore, two metrics
of the same class cannot be chained this way.
* dict: if metrics are passed in as a dict, will use each key in the
dict as key for output dict. Use this format if you want to chain
together multiple of the same metric with different parameters.
Example (input as list):
>>> from pytorch_lightning.metrics import MetricCollection, Accuracy, Precision, Recall
>>> target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2])
>>> preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2])
>>> metrics = MetricCollection([Accuracy(),
... Precision(num_classes=3, average='macro'),
... Recall(num_classes=3, average='macro')])
>>> metrics(preds, target)
{'Accuracy': tensor(0.1250), 'Precision': tensor(0.0667), 'Recall': tensor(0.1111)}
Example (input as dict):
>>> metrics = MetricCollection({'micro_recall': Recall(num_classes=3, average='micro'),
... 'macro_recall': Recall(num_classes=3, average='macro')})
>>> metrics(preds, target)
{'micro_recall': tensor(0.1250), 'macro_recall': tensor(0.1111)}
"""
def __init__(self, metrics: Union[List[Metric], Tuple[Metric], Dict[str, Metric]]):
super().__init__()
if isinstance(metrics, dict):
# Check all values are metrics
for name, metric in metrics.items():
if not isinstance(metric, Metric):
raise ValueError(f'Value {metric} belonging to key {name}'
' is not an instance of `pl.metrics.Metric`')
self[name] = metric
elif isinstance(metrics, (tuple, list)):
for metric in metrics:
if not isinstance(metric, Metric):
raise ValueError(f'Input {metric} to `MetricCollection` is not an instance'
' of `pl.metrics.Metric`')
name = metric.__class__.__name__
if name in self:
raise ValueError(f'Encountered two metrics both named {name}')
self[name] = metric
else:
raise ValueError('Unknown input to MetricCollection.')

def _filter_kwargs(self, metric: Metric, **kwargs):
    """ filter kwargs such that they match the update signature of the metric """
    # e.g. for an ``update(preds, target)`` signature, an extra ``weights`` kwarg is dropped
    return {k: v for k, v in kwargs.items() if k in metric._update_signature.parameters.keys()}

def forward(self, *args, **kwargs) -> Dict[str, Any]: # pylint: disable=E0202
"""
Iteratively call forward for each metric. Positional arguments (args) will
be passed to every metric in the collection, while keyword arguments (kwargs)
will be filtered based on the signature of the individual metric.
"""
return {k: m(*args, **self._filter_kwargs(m, **kwargs)) for k, m in self.items()}

def update(self, *args, **kwargs): # pylint: disable=E0202
"""
Iteratively call update for each metric. Positional arguments (args) will
be passed to every metric in the collection, while keyword arguments (kwargs)
will be filtered based on the signature of the individual metric.
"""
for _, m in self.items():
m_kwargs = self._filter_kwargs(m, **kwargs)
m.update(*args, **m_kwargs)

def compute(self) -> Dict[str, Any]:
return {k: m.compute() for k, m in self.items()}

def reset(self):
""" Iteratively call reset for each metric """
for _, m in self.items():
m.reset()

def clone(self):
""" Make a copy of the metric collection """
return deepcopy(self)

def persistent(self, mode: bool = True):
""" Method for post-init to change if metric states should be saved to
its state_dict
"""
for _, m in self.items():
m.persistent(mode)
