Metric ddp bugfix (#4482)
* changes

* fix spelling

* small note

* trying to fix ddp test

* fix ddp

* fix for test

* suggestion

* CHANGELOG

* Update pytorch_lightning/metrics/metric.py

Co-authored-by: chaton <[email protected]>
Co-authored-by: Sean Naren <[email protected]>
Co-authored-by: Sean Naren <[email protected]>

(cherry picked from commit 465ec75)
SkafteNicki authored and SeanNaren committed Nov 10, 2020
1 parent 8d547c6 commit 4441330
Showing 3 changed files with 56 additions and 15 deletions.
6 changes: 6 additions & 0 deletions docs/source/metrics.rst
@@ -111,6 +111,12 @@ This metrics API is independent of PyTorch Lightning. Metrics can directly be used
     It is highly recommended to re-initialize the metric per mode as
     shown in the examples above.
 
+.. note::
+
+    Metric states are by default added to the model's ``state_dict``. To change this
+    after initializing the metric, the method ``.persistent(mode)`` can be used to
+    enable (``mode=True``) or disable (``mode=False``) this behaviour.
+
 *********************
 Implementing a Metric
 *********************
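To make the new documentation note concrete, here is a minimal sketch of the behaviour it describes, using a hypothetical toy metric (`MySum` is illustrative only, not part of this commit):

import torch
from pytorch_lightning.metrics import Metric

class MySum(Metric):
    def __init__(self):
        super().__init__()
        # persistent=True is the default, so this state is saved to state_dict
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, x):
        self.total += x

    def compute(self):
        return self.total

metric = MySum()
assert "total" in metric.state_dict()     # persisted by default

metric.persistent(False)                  # opt out after initialization
assert "total" not in metric.state_dict()

Because states become plain attributes after this change, opting out simply skips them when the overridden `state_dict` in the diff below assembles its result.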
46 changes: 37 additions & 9 deletions pytorch_lightning/metrics/metric.py
@@ -76,8 +76,9 @@ def __init__(
         self._forward_cache = None
 
         # initialize state
-        self._reductions = {}
         self._defaults = {}
+        self._persistent = {}
+        self._reductions = {}
 
     def add_state(
         self, name: str, default, dist_reduce_fx: Optional[Union[str, Callable]] = None, persistent: bool = True
@@ -133,16 +134,10 @@ def add_state(
                 "`dist_reduce_fx` must be callable or one of ['mean', 'sum', 'cat', None]"
             )
 
-        if isinstance(default, torch.Tensor):
-            if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
-                # persistent keyword is only supported in torch >= 1.6.0
-                self.register_buffer(name, default, persistent=persistent)
-            else:
-                self.register_buffer(name, default)
-        else:
-            setattr(self, name, default)
+        setattr(self, name, default)
 
         self._defaults[name] = deepcopy(default)
+        self._persistent[name] = persistent
         self._reductions[name] = dist_reduce_fx
 
     @torch.jit.unused
@@ -257,3 +252,36 @@ def __setstate__(self, state):
         self.__dict__.update(state)
         self.update = self._wrap_update(self.update)
         self.compute = self._wrap_compute(self.compute)
+
+    def _apply(self, fn):
+        """Overwrite `_apply` so that metric states are also moved to the correct
+        device when `.to`, `.cuda`, etc. are called.
+        """
+        self = super()._apply(fn)
+        # Also apply fn to metric states
+        for key in self._defaults.keys():
+            current_val = getattr(self, key)
+            if isinstance(current_val, torch.Tensor):
+                setattr(self, key, fn(current_val))
+            elif isinstance(current_val, Sequence):
+                setattr(self, key, [fn(cur_v) for cur_v in current_val])
+            else:
+                raise TypeError('Expected metric state to be either a torch.Tensor '
+                                f'or a list of torch.Tensor, but encountered {current_val}')
+        return self
+
+    def persistent(self, mode: bool = True):
+        """Change post-init whether metric states should be saved to the metric's
+        ``state_dict``.
+        """
+        for key in self._persistent.keys():
+            self._persistent[key] = mode
+
+    def state_dict(self, *args, **kwargs):
+        # Register metric states to be part of the state_dict
+        state_dict = super().state_dict(*args, **kwargs)
+        for key in self._defaults.keys():
+            if self._persistent[key]:
+                current_val = getattr(self, key)
+                state_dict.update({key: current_val})
+        return state_dict
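Because metric states are now set with setattr rather than register_buffer, plain nn.Module device handling no longer sees them; the `_apply` override above is what keeps them on the right device. A small sketch of the effect, reusing the hypothetical `MySum` metric from the docs example and assuming a CUDA device is available:

metric = MySum()
metric.update(torch.tensor(2.0))

# .to(), .cuda() and .half() all funnel through nn.Module._apply, so the
# override moves tensor states (and lists of tensors, element-wise) too
metric = metric.to("cuda:0")
assert metric.total.device == torch.device("cuda", 0)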
19 changes: 13 additions & 6 deletions tests/metrics/test_metric_lightning.py
@@ -1,9 +1,11 @@
 import os
+import pytest
 
 import torch
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.metrics import Metric
 from tests.base.boring_model import BoringModel
+import tests.base.develop_utils as tutils
 
 
 class SumMetric(Metric):
@@ -54,15 +56,19 @@ def test_metric_lightning_log(tmpdir):
     class TestModel(BoringModel):
         def __init__(self):
             super().__init__()
-            self.metric = SumMetric()
+            self.metric_step = SumMetric()
+            self.metric_epoch = SumMetric()
             self.sum = 0.0
 
         def training_step(self, batch, batch_idx):
             x = batch
-            self.metric(x.sum())
+            self.metric_step(x.sum())
             self.sum += x.sum()
-            self.log("sum", self.metric, on_epoch=True, on_step=False)
-            return self.step(x)
+            self.log("sum_step", self.metric_step, on_epoch=True, on_step=False)
+            return {'loss': self.step(x), 'data': x}
+
+        def training_epoch_end(self, outs):
+            self.log("sum_epoch", self.metric_epoch(torch.stack([o['data'] for o in outs]).sum()))
 
     model = TestModel()
     model.val_dataloader = None
@@ -78,7 +84,8 @@ def training_step(self, batch, batch_idx):
     trainer.fit(model)
 
     logged = trainer.logged_metrics
-    assert torch.allclose(torch.tensor(logged["sum"]), model.sum)
+    assert torch.allclose(torch.tensor(logged["sum_step"]), model.sum)
+    assert torch.allclose(torch.tensor(logged["sum_epoch"]), model.sum)
 
 
 def test_scriptable(tmpdir):
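A detail the reworked test relies on: calling a Metric object invokes its forward, which updates the internal state and, with the default compute_on_step=True, returns the metric computed on that call's inputs alone. That is why self.metric_epoch(...) can be logged directly inside training_epoch_end, while self.metric_step is handed to self.log as an object so that Lightning itself calls .compute() over the whole epoch. Schematically, assuming SumMetric simply accumulates its input as its use above suggests:

m = SumMetric()
out = m(torch.tensor(3.0))  # updates state and computes on this input
assert out == torch.tensor(3.0)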
