[Metrics] MetricCollection #4318

Merged
dd9c584
docs + precision + recall + f_beta + refactor
ananyahjha93 Oct 10, 2020
1ef3ef2
rebase
ananyahjha93 Oct 10, 2020
f7c5c2d
fixes
ananyahjha93 Oct 10, 2020
f810486
added missing file
ananyahjha93 Oct 10, 2020
3ca8123
docs
ananyahjha93 Oct 10, 2020
344a518
docs
ananyahjha93 Oct 10, 2020
f2f7ec9
extra import
ananyahjha93 Oct 10, 2020
3b5c4ff
add metric collection
SkafteNicki Oct 12, 2020
2875077
add docs + integration with log_dict
SkafteNicki Oct 12, 2020
0c90921
add test
SkafteNicki Oct 12, 2020
274da81
merge
SkafteNicki Oct 14, 2020
57461c8
update
SkafteNicki Oct 14, 2020
710690f
update
SkafteNicki Oct 14, 2020
0ab8e2e
merge
SkafteNicki Oct 22, 2020
851ff4b
more test
SkafteNicki Oct 22, 2020
7c750c6
more test
SkafteNicki Oct 22, 2020
75c4a84
pep8
SkafteNicki Oct 22, 2020
6f0eb65
fix doctest
SkafteNicki Oct 23, 2020
decfc62
pep8
SkafteNicki Oct 23, 2020
a789139
add clone method
SkafteNicki Oct 25, 2020
63d8ddb
add clone method
SkafteNicki Oct 25, 2020
0090786
merge
Nov 13, 2020
bd9fcbc
merge-2
Nov 13, 2020
3aadcb9
changelog
Nov 13, 2020
b972274
kwargs filtering and tests
Nov 13, 2020
bdf4744
pep8
Nov 13, 2020
4397534
fix test
Nov 13, 2020
a5ee82b
merge
SkafteNicki Dec 5, 2020
aad47d4
Merge remote-tracking branch 'upstream/release/1.2-dev' into metrics/…
SkafteNicki Dec 16, 2020
6986376
update docs
SkafteNicki Dec 18, 2020
4b48b7e
Update docs/source/metrics.rst
SkafteNicki Dec 30, 2020
c30e192
Merge branch 'release/1.2-dev' into metrics/metric_collection
SkafteNicki Jan 6, 2021
915bc0f
fix docs
SkafteNicki Jan 6, 2021
6b04f69
fix tests
SkafteNicki Jan 6, 2021
42e7567
Merge remote-tracking branch 'upstream/release/1.2-dev' into metrics/…
SkafteNicki Jan 7, 2021
af6d4c6
Apply suggestions from code review
SkafteNicki Jan 7, 2021
e9a59ce
merge
SkafteNicki Jan 7, 2021
2f8d3eb
fix docs
SkafteNicki Jan 7, 2021
1db5f09
fix doctest
SkafteNicki Jan 7, 2021
f18fba9
fix doctest
SkafteNicki Jan 8, 2021
52c4fb9
fix doctest
SkafteNicki Jan 8, 2021
34d488a
fix doctest
SkafteNicki Jan 8, 2021
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -9,6 +9,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `MetricCollection` ([#4318](https://github.com/PyTorchLightning/pytorch-lightning/pull/4318))

- Added `.clone()` method to metrics ([#4318](https://github.com/PyTorchLightning/pytorch-lightning/pull/4318))


### Changed

71 changes: 69 additions & 2 deletions docs/source/metrics.rst
@@ -79,6 +79,7 @@ If ``on_epoch`` is True, the logger automatically logs the end of epoch metric v
self.log('valid_acc', self.valid_acc, on_step=True, on_epoch=True)

.. note::

If using metrics in data parallel mode (dp), the metric update/logging should be done
in the ``<mode>_step_end`` method (where ``<mode>`` is either ``training``, ``validation``
or ``test``). This is due to metric states else being destroyed after each forward pass,
@@ -97,7 +98,6 @@ If ``on_epoch`` is True, the logger automatically logs the end of epoch metric v
self.metric(outputs['preds'], outputs['target'])
self.log('metric', self.metric)


This metrics API is independent of PyTorch Lightning. Metrics can directly be used in PyTorch as shown in the example:

.. code-block:: python
@@ -129,7 +129,17 @@ This metrics API is independent of PyTorch Lightning. Metrics can directly be us
Metrics contain internal states that keep track of the data seen so far.
Do not mix metric states across training, validation and testing.
It is highly recommended to re-initialize the metric per mode as
shown in the examples above.
shown in the examples above. To easily initialize the same metric multiple
times, the ``.clone()`` method can be used:

.. code-block:: python

    def __init__(self):
        ...
        metric = pl.metrics.Accuracy()
        self.train_acc = metric.clone()
        self.val_acc = metric.clone()
        self.test_acc = metric.clone()

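The reason a single ``.clone()`` call per mode is enough is that, per the diff below, ``clone`` is simply a deep copy, so each copy carries fully independent state. A minimal plain-Python sketch of that idea, using a hypothetical ``ToyMetric`` stand-in rather than a real ``pl.metrics`` class:

```python
from copy import deepcopy


class ToyMetric:
    """Hypothetical stand-in for a metric with accumulated internal state."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def update(self, preds, target):
        # accumulate the number of correct predictions and the sample count
        self.correct += sum(int(p == t) for p, t in zip(preds, target))
        self.total += len(target)

    def compute(self):
        return self.correct / self.total

    def clone(self):
        # mirrors Metric.clone() added in this PR: a deepcopy,
        # so cloned metrics accumulate state independently
        return deepcopy(self)


metric = ToyMetric()
train_acc = metric.clone()
val_acc = metric.clone()

train_acc.update([1, 0, 1], [1, 1, 1])  # only train_acc sees these samples
val_acc.update([0, 0], [0, 0])          # only val_acc sees these samples
```

Because each clone is a deep copy, updating ``train_acc`` never touches ``val_acc`` or the original ``metric``.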
.. note::

@@ -188,6 +198,63 @@ In practice this means that:
val = metric(pred, target) # this value can be backpropagated
val = metric.compute() # this value cannot be backpropagated

****************
MetricCollection
****************

In many cases it is beneficial to evaluate the model output with multiple metrics.
In this case the `MetricCollection` class may come in handy. It accepts a sequence
of metrics and wraps these into a single callable metric class, with the same
interface as any other metric.

Example:

.. code-block:: python

    from pytorch_lightning.metrics import MetricCollection, Accuracy, Precision, Recall
    target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2])
    preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2])
    metric_collection = MetricCollection([Accuracy(),
                                          Precision(num_classes=3, average='macro'),
                                          Recall(num_classes=3, average='macro')])
    metric_collection(preds, target)
    # {'Accuracy': tensor(0.1250),
    #  'Precision': tensor(0.0667),
    #  'Recall': tensor(0.1111)}

Similarly, it can also reduce the amount of code required to log multiple metrics
inside your LightningModule:

.. code-block:: python

    def __init__(self):
        ...
        metrics = pl.metrics.MetricCollection(...)
        self.train_metrics = metrics.clone()
        self.valid_metrics = metrics.clone()

    def training_step(self, batch, batch_idx):
        logits = self(x)
        ...
        self.train_metrics(logits, y)
        # use log_dict instead of log
        self.log_dict(self.train_metrics, on_step=True, on_epoch=False, prefix='train')

    def validation_step(self, batch, batch_idx):
        logits = self(x)
        ...
        self.valid_metrics(logits, y)
        # use log_dict instead of log
        self.log_dict(self.valid_metrics, on_step=True, on_epoch=True, prefix='val')

.. note::

By default, `MetricCollection` assumes that all the metrics in the collection
have the same call signature. If this is not the case, input that should be
given to different metrics can be given as keyword arguments to the collection.

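The kwargs filtering described in the note can be sketched with nothing but the standard library: keep only the keyword arguments that appear in each metric's ``update`` signature. The ``filter_kwargs`` helper and the two ``update_*`` functions below are hypothetical stand-ins for illustration, not part of the library; the real mechanism lives in ``MetricCollection._filter_kwargs`` in the diff further down.

```python
import inspect


def filter_kwargs(fn, **kwargs):
    # mirrors MetricCollection._filter_kwargs from this PR:
    # keep only the kwargs that match parameters in fn's signature
    params = inspect.signature(fn).parameters
    return {k: v for k, v in kwargs.items() if k in params}


# two hypothetical update functions with different call signatures
def update_plain(preds, target):
    return ('plain', preds, target)


def update_weighted(preds, target, sample_weight):
    return ('weighted', sample_weight)


kwargs = {'preds': [1], 'target': [1], 'sample_weight': [0.5]}

# sample_weight is silently dropped for the metric that does not accept it
update_plain(**filter_kwargs(update_plain, **kwargs))
# all three kwargs pass through to the metric that declares them
update_weighted(**filter_kwargs(update_weighted, **kwargs))
```

This is why a single call like ``collection(preds=..., target=..., sample_weight=...)`` can serve metrics with differing signatures: each metric only receives the keywords it declares.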
.. autoclass:: pytorch_lightning.metrics.MetricCollection
:noindex:

**********
Metric API
2 changes: 1 addition & 1 deletion pytorch_lightning/metrics/__init__.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.metrics.metric import Metric
from pytorch_lightning.metrics.metric import Metric, MetricCollection

from pytorch_lightning.metrics.classification import (
Accuracy,
113 changes: 109 additions & 4 deletions pytorch_lightning/metrics/metric.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import inspect
from abc import ABC, abstractmethod
from collections.abc import Sequence
from copy import deepcopy
@@ -57,6 +58,7 @@ class Metric(nn.Module, ABC):
Callback that performs the allgather operation on the metric state. When `None`, DDP
will be used to perform the allgather. default: None
"""

def __init__(
self,
compute_on_step: bool = True,
@@ -72,6 +74,7 @@ def __init__(
self.dist_sync_fn = dist_sync_fn
self._to_sync = True

self._update_signature = inspect.signature(self.update)
self.update = self._wrap_update(self.update)
self.compute = self._wrap_compute(self.compute)
self._computed = None
@@ -119,7 +122,7 @@ def add_state(
"""
if (
not isinstance(default, torch.Tensor)
and not isinstance(default, list) # noqa: W503
and not isinstance(default, list) # noqa: W503
or (isinstance(default, list) and len(default) != 0) # noqa: W503
):
raise ValueError(
@@ -207,9 +210,11 @@ def wrapped_func(*args, **kwargs):
return self._computed

dist_sync_fn = self.dist_sync_fn
if (dist_sync_fn is None
and torch.distributed.is_available()
and torch.distributed.is_initialized()):
if (
dist_sync_fn is None
and torch.distributed.is_available()
and torch.distributed.is_initialized()
):
# User provided a bool, so we assume DDP if available
dist_sync_fn = gather_all_tensors

@@ -249,6 +254,10 @@ def reset(self):
else:
setattr(self, attr, deepcopy(default))

    def clone(self):
        """ Make a copy of the metric """
        return deepcopy(self)

def __getstate__(self):
# ignore update and compute functions for pickling
return {k: v for k, v in self.__dict__.items() if k not in ["update", "compute"]}
@@ -291,3 +300,99 @@ def state_dict(self, *args, **kwargs):
current_val = getattr(self, key)
state_dict.update({key: current_val})
return state_dict


class MetricCollection(nn.ModuleDict):
    """
    MetricCollection class can be used to chain metrics that have the same
    call pattern into one single class.

    Args:
        metrics: One of the following

            * list or tuple: if metrics are passed in as a list, will use the
              metrics class name as key for output dict. Therefore, two metrics
              of the same class cannot be chained this way.

            * dict: if metrics are passed in as a dict, will use each key in the
              dict as key for output dict. Use this format if you want to chain
              together multiple of the same metric with different parameters.

    Example:

        >>> from pytorch_lightning.metrics import MetricCollection, Accuracy, Precision, Recall
        >>> target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2])
        >>> preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2])
        >>> metrics = MetricCollection([Accuracy(),
        ...                             Precision(num_classes=3, average='macro'),
        ...                             Recall(num_classes=3, average='macro')])
        >>> metrics(preds, target)
        {'Accuracy': tensor(0.1250), 'Precision': tensor(0.0667), 'Recall': tensor(0.1111)}

        >>> metrics = MetricCollection({'micro_recall': Recall(num_classes=3, average='micro'),
        ...                             'macro_recall': Recall(num_classes=3, average='macro')})
        >>> metrics(preds, target)
        {'micro_recall': tensor(0.1250), 'macro_recall': tensor(0.1111)}

    """
    def __init__(self, metrics):
        super().__init__()
        if isinstance(metrics, dict):
            # Check all values are metrics
            for name, metric in metrics.items():
                if not isinstance(metric, Metric):
                    raise ValueError(f'Value {metric} belonging to key {name}'
                                     ' is not an instance of `pl.metrics.Metric`')
                self[name] = metric
        elif isinstance(metrics, (tuple, list)):
            for metric in metrics:
                if not isinstance(metric, Metric):
                    raise ValueError(f'Input {metric} to `MetricCollection` is not an instance'
                                     ' of `pl.metrics.Metric`')
                name = metric.__class__.__name__
                if name in self:
                    raise ValueError(f'Encountered two metrics both named {name}')
                self[name] = metric
        else:
            raise ValueError('Unknown input to MetricCollection.')

    def _filter_kwargs(self, metric, **kwargs):
        """ filter kwargs such that they match the update signature of the metric """
        return {k: v for k, v in kwargs.items() if k in metric._update_signature.parameters.keys()}

    def forward(self, *args, **kwargs):
        """
        Iteratively call forward for each metric. Positional arguments (*args) will
        be passed to every metric in the collection, while keyword arguments (**kwargs)
        will be filtered based on the signature of the individual metric.
        """
        return {k: m(*args, **self._filter_kwargs(m, **kwargs)) for k, m in self.items()}

    def update(self, *args, **kwargs):
        """
        Iteratively call update for each metric. Positional arguments (*args) will
        be passed to every metric in the collection, while keyword arguments (**kwargs)
        will be filtered based on the signature of the individual metric.
        """
        for _, m in self.items():
            m_kwargs = self._filter_kwargs(m, **kwargs)
            m.update(*args, **m_kwargs)

    def compute(self):
        return {k: m.compute() for k, m in self.items()}

    def reset(self):
        """ Iteratively call reset for each metric """
        for _, m in self.items():
            m.reset()

    def clone(self):
        """ Make a copy of the metric collection """
        return deepcopy(self)

    def persistent(self, mode: bool = True):
        """ Method for post-init to change if metric states should be saved to
        its state_dict
        """
        for _, m in self.items():
            m.persistent(mode)