Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Metric ddp bugfix #4482

Merged
merged 19 commits into from
Nov 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added metrics aggregation in Horovod and fixed early stopping ([#3775](https://github.com/PyTorchLightning/pytorch-lightning/pull/3775))


- Added `persistent(mode)` method to metrics, to enable and disable metric states being added to `state_dict` ([#4482](https://github.com/PyTorchLightning/pytorch-lightning/pull/4482))


### Changed

- Tuner algorithms will be skipped if `fast_dev_run=True` ([#3903](https://github.com/PyTorchLightning/pytorch-lightning/pull/3903))
Expand All @@ -49,6 +52,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed feature-lack in hpc load ([#4526](https://github.com/PyTorchLightning/pytorch-lightning/pull/4526))


- Fixed metrics states being overridden in ddp mode ([#4482](https://github.com/PyTorchLightning/pytorch-lightning/pull/4482))


## [1.0.5] - 2020-11-03

### Added
Expand Down
6 changes: 6 additions & 0 deletions docs/source/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,12 @@ This metrics API is independent of PyTorch Lightning. Metrics can directly be us
It is highly recommended to re-initialize the metric per mode as
shown in the examples above.

.. note::

Metric states will as default add their internal state to the models ``state_dict``.
To change this after initializing the metric the method ``.persistent(mode)`` can
be used to enable (``mode=True``) or disable (``mode=False``) this behaviour.

*********************
Implementing a Metric
*********************
Expand Down
46 changes: 37 additions & 9 deletions pytorch_lightning/metrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,9 @@ def __init__(
self._forward_cache = None

# initialize state
self._reductions = {}
self._defaults = {}
self._persistent = {}
self._reductions = {}

def add_state(
self, name: str, default, dist_reduce_fx: Optional[Union[str, Callable]] = None, persistent: bool = True
Expand Down Expand Up @@ -138,16 +139,10 @@ def add_state(
"`dist_reduce_fx` must be callable or one of ['mean', 'sum', 'cat', None]"
)

if isinstance(default, torch.Tensor):
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
# persistent keyword is only supported in torch >= 1.6.0
self.register_buffer(name, default, persistent=persistent)
else:
self.register_buffer(name, default)
else:
setattr(self, name, default)
setattr(self, name, default)

self._defaults[name] = deepcopy(default)
self._persistent[name] = persistent
self._reductions[name] = dist_reduce_fx

@torch.jit.unused
Expand Down Expand Up @@ -265,3 +260,36 @@ def __setstate__(self, state):
self.__dict__.update(state)
self.update = self._wrap_update(self.update)
self.compute = self._wrap_compute(self.compute)

def _apply(self, fn):
""" Overwrite _apply function such that we can also move metric states
to the correct device when `.to`, `.cuda`, etc methods are called
"""
self = super()._apply(fn)
# Also apply fn to metric states
for key in self._defaults.keys():
current_val = getattr(self, key)
if isinstance(current_val, torch.Tensor):
setattr(self, key, fn(current_val))
elif isinstance(current_val, Sequence):
setattr(self, key, [fn(cur_v) for cur_v in current_val])
else:
raise TypeError('Expected metric state to be either a torch.Tensor'
f'or a list of torch.Tensor, but encountered {current_val}')
return self

def persistent(self, mode: bool = True):
""" Method for post-init to change if metric states should be saved to
its state_dict
"""
for key in self._persistent.keys():
self._persistent[key] = mode

def state_dict(self, *args, **kwargs):
# Register metric states to be part of the state_dict
state_dict = super().state_dict()
for key in self._defaults.keys():
if self._persistent[key]:
current_val = getattr(self, key)
state_dict.update({key: current_val})
return state_dict
19 changes: 13 additions & 6 deletions tests/metrics/test_metric_lightning.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import os
import pytest

import torch

from pytorch_lightning import Trainer
from pytorch_lightning.metrics import Metric
from tests.base.boring_model import BoringModel
import tests.base.develop_utils as tutils


class SumMetric(Metric):
Expand Down Expand Up @@ -54,15 +56,19 @@ def test_metric_lightning_log(tmpdir):
class TestModel(BoringModel):
def __init__(self):
super().__init__()
self.metric = SumMetric()
self.metric_step = SumMetric()
self.metric_epoch = SumMetric()
self.sum = 0.0

def training_step(self, batch, batch_idx):
x = batch
self.metric(x.sum())
self.metric_step(x.sum())
self.sum += x.sum()
self.log("sum", self.metric, on_epoch=True, on_step=False)
return self.step(x)
self.log("sum_step", self.metric_step, on_epoch=True, on_step=False)
return {'loss': self.step(x), 'data': x}

def training_epoch_end(self, outs):
self.log("sum_epoch", self.metric_epoch(torch.stack([o['data'] for o in outs]).sum()))

model = TestModel()
model.val_dataloader = None
Expand All @@ -78,7 +84,8 @@ def training_step(self, batch, batch_idx):
trainer.fit(model)

logged = trainer.logged_metrics
assert torch.allclose(torch.tensor(logged["sum"]), model.sum)
assert torch.allclose(torch.tensor(logged["sum_step"]), model.sum)
assert torch.allclose(torch.tensor(logged["sum_epoch"]), model.sum)


def test_scriptable(tmpdir):
Expand Down