diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst
index 4fadfaa5071689..660dec028886e6 100644
--- a/docs/source/metrics.rst
+++ b/docs/source/metrics.rst
@@ -111,6 +111,12 @@ This metrics API is independent of PyTorch Lightning. Metrics can directly be us
 
 It is highly recommended to re-initialize the metric per mode as shown in the examples above.
 
+.. note::
+
+    Metric states are by default added to the model's ``state_dict``. To change
+    this after the metric has been initialized, the ``.persistent(mode)`` method
+    can be used to enable (``mode=True``) or disable (``mode=False``) this behaviour.
+
 *********************
 Implementing a Metric
 *********************
diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py
index b716817427230d..75db7db12de59f 100644
--- a/pytorch_lightning/metrics/metric.py
+++ b/pytorch_lightning/metrics/metric.py
@@ -76,8 +76,9 @@ def __init__(
         self._forward_cache = None
 
         # initialize state
-        self._reductions = {}
         self._defaults = {}
+        self._persistent = {}
+        self._reductions = {}
 
     def add_state(
         self, name: str, default, dist_reduce_fx: Optional[Union[str, Callable]] = None, persistent: bool = True
@@ -133,16 +134,10 @@ def add_state(
                 "`dist_reduce_fx` must be callable or one of ['mean', 'sum', 'cat', None]"
             )
 
-        if isinstance(default, torch.Tensor):
-            if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
-                # persistent keyword is only supported in torch >= 1.6.0
-                self.register_buffer(name, default, persistent=persistent)
-            else:
-                self.register_buffer(name, default)
-        else:
-            setattr(self, name, default)
+        setattr(self, name, default)
 
         self._defaults[name] = deepcopy(default)
+        self._persistent[name] = persistent
         self._reductions[name] = dist_reduce_fx
 
     @torch.jit.unused
@@ -257,3 +252,36 @@ def __setstate__(self, state):
         self.__dict__.update(state)
         self.update = self._wrap_update(self.update)
         self.compute = self._wrap_compute(self.compute)
+
+    def _apply(self, fn):
+        """Overwrite `_apply` so that metric states are also moved to the
+        correct device when `.to`, `.cuda`, etc. are called.
+        """
+        self = super()._apply(fn)
+        # also apply fn to the metric states themselves
+        for key in self._defaults.keys():
+            current_val = getattr(self, key)
+            if isinstance(current_val, torch.Tensor):
+                setattr(self, key, fn(current_val))
+            elif isinstance(current_val, Sequence):
+                setattr(self, key, [fn(cur_v) for cur_v in current_val])
+            else:
+                raise TypeError('Expected metric state to be either a torch.Tensor '
+                                f'or a list of torch.Tensor, but encountered {current_val}')
+        return self
+
+    def persistent(self, mode: bool = True):
+        """Change post-init whether metric states should be included in the
+        module's ``state_dict``.
+        """
+        for key in self._persistent.keys():
+            self._persistent[key] = mode
+
+    def state_dict(self, *args, **kwargs):
+        # register persistent metric states as part of the state_dict
+        state_dict = super().state_dict(*args, **kwargs)
+        for key in self._defaults.keys():
+            if self._persistent[key]:
+                current_val = getattr(self, key)
+                state_dict.update({key: current_val})
+        return state_dict
diff --git a/tests/metrics/test_metric_lightning.py b/tests/metrics/test_metric_lightning.py
index 3c6938734be107..a35562327d7173 100644
--- a/tests/metrics/test_metric_lightning.py
+++ b/tests/metrics/test_metric_lightning.py
@@ -1,9 +1,11 @@
-import os
+import pytest
 import torch
+
 from pytorch_lightning import Trainer
 from pytorch_lightning.metrics import Metric
 from tests.base.boring_model import BoringModel
 
+import tests.base.develop_utils as tutils
 
 
 class SumMetric(Metric):
@@ -54,15 +56,19 @@ def test_metric_lightning_log(tmpdir):
     class TestModel(BoringModel):
         def __init__(self):
             super().__init__()
-            self.metric = SumMetric()
+            self.metric_step = SumMetric()
+            self.metric_epoch = SumMetric()
             self.sum = 0.0
 
         def training_step(self, batch, batch_idx):
             x = batch
-            self.metric(x.sum())
+            self.metric_step(x.sum())
             self.sum += x.sum()
-            self.log("sum", self.metric, on_epoch=True, on_step=False)
-            return self.step(x)
+            self.log("sum_step", self.metric_step, on_epoch=True, on_step=False)
+            return {'loss': self.step(x), 'data': x}
+
+        def training_epoch_end(self, outs):
+            self.log("sum_epoch", self.metric_epoch(torch.stack([o['data'] for o in outs]).sum()))
 
     model = TestModel()
     model.val_dataloader = None
@@ -78,7 +84,8 @@ def training_step(self, batch, batch_idx):
     trainer.fit(model)
 
     logged = trainer.logged_metrics
-    assert torch.allclose(torch.tensor(logged["sum"]), model.sum)
+    assert torch.allclose(torch.tensor(logged["sum_step"]), model.sum)
+    assert torch.allclose(torch.tensor(logged["sum_epoch"]), model.sum)
 
 
 def test_scriptable(tmpdir):
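
For reference, the snippet below sketches the behaviour this patch enables. It is a minimal, hypothetical example (the `ExampleMetric` class and its values are made up for illustration); `add_state(..., persistent=...)`, `.persistent(mode)` and the `state_dict` override are the pieces added or changed by the diff above:

    import torch
    from pytorch_lightning.metrics import Metric

    class ExampleMetric(Metric):
        def __init__(self):
            super().__init__()
            # persistent=True (the default) marks this state for inclusion
            # in the module's state_dict
            self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

        def update(self, value):
            self.total += value

        def compute(self):
            return self.total

    metric = ExampleMetric()
    metric.update(torch.tensor(2.0))
    assert "total" in metric.state_dict()

    # post-init toggle introduced by this patch
    metric.persistent(False)
    assert "total" not in metric.state_dict()

Since `add_state` now stores states as plain attributes rather than registered buffers, the `state_dict` override is what keeps persistent states in checkpoints, and the torch >= 1.6 `register_buffer(..., persistent=...)` version check becomes unnecessary.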