[1 / 3] improvements to saving and loading callback state (#6886)
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
3 people authored Jul 28, 2021
1 parent 0c0b24c commit 8c27fa7
Showing 9 changed files with 68 additions and 26 deletions.
7 changes: 2 additions & 5 deletions CHANGELOG.md
@@ -9,10 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

-


-
- Added `state_id` property to the `Callback` base class ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886))


-
@@ -32,7 +29,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).



-
- Saved checkpoints will no longer use the type of a `Callback` as the key to avoid issues with unpickling ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886))


-
17 changes: 16 additions & 1 deletion pytorch_lightning/callbacks/base.py
@@ -17,7 +17,7 @@
"""

import abc
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Type

import torch
from torch.optim import Optimizer
@@ -33,6 +33,21 @@ class Callback(abc.ABC):
Subclass this class and override any of the relevant hooks
"""

@property
def state_id(self) -> str:
"""
Identifier for the state of the callback. Used to store and retrieve a callback's state from the
checkpoint dictionary by ``checkpoint["callbacks"][state_id]``. Implementations of a callback need to
provide a unique state id if 1) the callback has state and 2) it is desired to maintain the state of
multiple instances of that callback.
"""
return self.__class__.__qualname__

@property
def _legacy_state_id(self) -> Type["Callback"]:
"""State identifier for checkpoints saved prior to version 1.5.0."""
return type(self)

def on_configure_sharded_model(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called before configure sharded model"""

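As a usage note, the new ``state_id`` property makes it possible to keep the state of multiple instances of the same callback class in one checkpoint. Below is a minimal sketch of how a user-defined callback might override it; the ``MyCounter`` callback, its ``prefix`` argument, and the id format are illustrative assumptions, not part of this commit.

from pytorch_lightning.callbacks import Callback


class MyCounter(Callback):
    """Hypothetical stateful callback; two differently-configured instances can share one Trainer."""

    def __init__(self, prefix: str):
        self.prefix = prefix
        self.count = 0

    @property
    def state_id(self) -> str:
        # include the configuration in the id so each instance gets its own
        # slot in checkpoint["callbacks"] instead of overwriting the other
        return f"{self.__class__.__qualname__}[prefix={self.prefix}]"

    def on_train_batch_end(self, trainer, pl_module, *args, **kwargs):
        self.count += 1

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        return {"count": self.count}

    def on_load_checkpoint(self, trainer, pl_module, callback_state):
        self.count = callback_state["count"]

With two such instances, e.g. ``MyCounter("a")`` and ``MyCounter("b")``, each would save and restore its own count rather than sharing a single entry keyed by the class.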
13 changes: 6 additions & 7 deletions pytorch_lightning/trainer/callback_hook.py
@@ -15,7 +15,7 @@
from abc import ABC
from copy import deepcopy
from inspect import signature
from typing import Any, Callable, Dict, List, Optional, Type
from typing import Any, Callable, Dict, List, Optional, Type, Union

import torch

@@ -247,7 +247,7 @@ def __is_old_signature_on_load_checkpoint(fn: Callable) -> bool:
parameters = list(signature(fn).parameters)
return len(parameters) == 1 and parameters[0] != "args"

def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[Type, dict]:
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]:
"""Called when saving a model checkpoint."""
callback_states = {}
for callback in self.callbacks:
@@ -261,16 +261,15 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[Type, dict]:
else:
state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint)
if state:
callback_states[type(callback)] = state
callback_states[callback.state_id] = state
return callback_states

def on_load_checkpoint(self, checkpoint):
def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
"""Called when loading a model checkpoint."""

# Todo: the `callback_states` are dropped with TPUSpawn as they
# can't be saved using `xm.save`
# https://github.com/pytorch/xla/issues/2773
callback_states = checkpoint.get("callbacks")
callback_states: Dict[Union[Type, str], Dict] = checkpoint.get("callbacks")

if callback_states is None:
return
@@ -287,7 +286,7 @@ def on_load_checkpoint(self, checkpoint):
)

for callback in self.callbacks:
state = callback_states.get(type(callback))
state = callback_states.get(callback.state_id, callback_states.get(callback._legacy_state_id))
if state:
state = deepcopy(state)
if self.__is_old_signature_on_load_checkpoint(callback.on_load_checkpoint):
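The loading path above first looks up the new string id and only then falls back to the legacy key (the callback type), so checkpoints written before this change keep loading. A condensed sketch of that lookup; ``lookup_callback_state`` is an illustrative helper, not code from this commit.

from copy import deepcopy


def lookup_callback_state(callback, callback_states):
    """Return the saved state for ``callback`` from a checkpoint's ``callbacks`` mapping.

    New checkpoints key the mapping by ``callback.state_id`` (a string); checkpoints
    saved before 1.5.0 keyed it by ``type(callback)``, which ``_legacy_state_id`` returns.
    """
    state = callback_states.get(callback.state_id, callback_states.get(callback._legacy_state_id))
    return deepcopy(state) if state else None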
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -308,7 +308,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
structured dictionary: {
'epoch': training epoch
'global_step': training global step
'pytorch-lightning_version': PyTorch Lightning's version
'pytorch-lightning_version': The version of PyTorch Lightning that produced this checkpoint
'callbacks': "callback specific state"[] # if not weights_only
'optimizer_states': "PT optim's state_dict"[] # if not weights_only
'lr_schedulers': "PT sched's state_dict"[] # if not weights_only
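With the ``callbacks`` entry keyed by strings, a checkpoint produced after this change can be inspected without importing the callback classes that wrote it. A small example, where ``example.ckpt`` is a placeholder path for any Trainer-saved checkpoint:

import torch

checkpoint = torch.load("example.ckpt", map_location="cpu")

# keys are now plain strings such as "ModelCheckpoint" or "EarlyStopping"
# rather than callback classes, which avoids unpickling issues
for state_id, state in checkpoint.get("callbacks", {}).items():
    print(state_id, sorted(state))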
33 changes: 32 additions & 1 deletion tests/callbacks/test_callbacks.py
@@ -11,9 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from unittest.mock import call, Mock

from pytorch_lightning import Trainer
from pytorch_lightning import Callback, Trainer
from tests.helpers import BoringModel


@@ -101,3 +102,33 @@ def configure_callbacks(self):
trainer_fn(model)
callbacks_after = trainer.callbacks.copy()
assert callbacks_after == callbacks_after_fit


class OldStatefulCallback(Callback):
def __init__(self, state):
self.state = state

@property
def state_id(self):
return type(self)

def on_save_checkpoint(self, *args):
return {"state": self.state}

def on_load_checkpoint(self, trainer, pl_module, callback_state):
self.state = callback_state["state"]


def test_resume_callback_state_saved_by_type(tmpdir):
"""Test that a legacy checkpoint that didn't use a state identifier before can still be loaded."""
model = BoringModel()
callback = OldStatefulCallback(state=111)
trainer = Trainer(default_root_dir=tmpdir, max_steps=1, callbacks=[callback])
trainer.fit(model)
ckpt_path = Path(trainer.checkpoint_callback.best_model_path)
assert ckpt_path.exists()

callback = OldStatefulCallback(state=222)
trainer = Trainer(default_root_dir=tmpdir, max_steps=2, callbacks=[callback], resume_from_checkpoint=ckpt_path)
trainer.fit(model)
assert callback.state == 111
3 changes: 2 additions & 1 deletion tests/callbacks/test_early_stopping.py
@@ -76,7 +76,8 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
checkpoint = torch.load(checkpoint_filepath)
# the checkpoint saves "epoch + 1"
early_stop_callback_state = early_stop_callback.saved_states[checkpoint["epoch"] - 1]
assert checkpoint["callbacks"][type(early_stop_callback)] == early_stop_callback_state
assert 4 == len(early_stop_callback.saved_states)
assert checkpoint["callbacks"]["EarlyStoppingTestRestore"] == early_stop_callback_state

# ensure state is reloaded properly (assertion in the callback)
early_stop_callback = EarlyStoppingTestRestore(early_stop_callback_state, monitor="train_loss")
10 changes: 4 additions & 6 deletions tests/checkpointing/test_model_checkpoint.py
@@ -148,7 +148,7 @@ def on_validation_epoch_end(self):
assert chk["epoch"] == epoch + 1
assert chk["global_step"] == limit_train_batches * (epoch + 1)

mc_specific_data = chk["callbacks"][type(checkpoint)]
mc_specific_data = chk["callbacks"]["ModelCheckpoint"]
assert mc_specific_data["dirpath"] == checkpoint.dirpath
assert mc_specific_data["monitor"] == monitor
assert mc_specific_data["current_score"] == score
@@ -259,7 +259,7 @@ def _make_assertions(epoch, ix, version=""):
expected_global_step = per_val_train_batches * (global_ix + 1) + (leftover_train_batches * epoch_num)
assert chk["global_step"] == expected_global_step

mc_specific_data = chk["callbacks"][type(checkpoint)]
mc_specific_data = chk["callbacks"]["ModelCheckpoint"]
assert mc_specific_data["dirpath"] == checkpoint.dirpath
assert mc_specific_data["monitor"] == monitor
assert mc_specific_data["current_score"] == score
@@ -857,9 +857,7 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):

assert ckpt_last_epoch["epoch"] == ckpt_last["epoch"]
assert ckpt_last_epoch["global_step"] == ckpt_last["global_step"]

ch_type = type(model_checkpoint)
assert ckpt_last["callbacks"][ch_type] == ckpt_last_epoch["callbacks"][ch_type]
assert ckpt_last["callbacks"]["ModelCheckpoint"] == ckpt_last_epoch["callbacks"]["ModelCheckpoint"]

# it is easier to load the model objects than to iterate over the raw dict of tensors
model_last_epoch = LogInTwoMethods.load_from_checkpoint(path_last_epoch)
@@ -1097,7 +1095,7 @@ def training_step(self, *args):
trainer.fit(TestModel())
assert model_checkpoint.current_score == 0.3
ckpts = [torch.load(str(ckpt)) for ckpt in tmpdir.listdir()]
ckpts = [ckpt["callbacks"][type(model_checkpoint)] for ckpt in ckpts]
ckpts = [ckpt["callbacks"]["ModelCheckpoint"] for ckpt in ckpts]
assert sorted(ckpt["current_score"] for ckpt in ckpts) == [0.1, 0.2, 0.3]


6 changes: 3 additions & 3 deletions tests/trainer/connectors/test_callback_connector.py
@@ -76,11 +76,11 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmpdir):
trainer.fit(model)

ckpt = torch.load(str(tmpdir / "all_states.ckpt"))
state0 = ckpt["callbacks"][type(callback0)]
state1 = ckpt["callbacks"][type(callback1)]
state0 = ckpt["callbacks"]["StatefulCallback0"]
state1 = ckpt["callbacks"]["StatefulCallback1"]
assert "content0" in state0 and state0["content0"] == 0
assert "content1" in state1 and state1["content1"] == 1
assert type(checkpoint_callback) in ckpt["callbacks"]
assert "ModelCheckpoint" in ckpt["callbacks"]


def test_attach_model_callbacks():
3 changes: 2 additions & 1 deletion tests/trainer/logging_/test_logger_connector.py
@@ -26,10 +26,11 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel, RandomDataset
from tests.helpers.runif import RunIf
from tests.models.test_hooks import get_members


def test_fx_validator(tmpdir):
funcs_name = sorted(f for f in dir(Callback) if not f.startswith("_"))
funcs_name = sorted(get_members(Callback))

callbacks_func = [
"on_before_backward",
