[1 / 3] improvements to saving and loading callback state #6886

Merged Jul 28, 2021 (44 commits)

Commits
89131c2 class name as key (awaelchli, Apr 8, 2021)
63fb983 string state identifier (awaelchli, Apr 8, 2021)
7dc218a add legacy state loading (awaelchli, Apr 8, 2021)
04b588b update test (awaelchli, Apr 8, 2021)
bb11e28 update tests (awaelchli, Apr 8, 2021)
271360c Merge branch 'master' into bugfix/callback-state (awaelchli, Apr 15, 2021)
20b66f0 Merge branch 'master' into bugfix/callback-state (awaelchli, Apr 16, 2021)
f585a28 Merge branch 'master' into bugfix/callback-state (awaelchli, Apr 16, 2021)
e1d518b Merge branch 'master' into bugfix/callback-state (awaelchli, Apr 17, 2021)
880066b Merge branch 'master' into bugfix/callback-state (awaelchli, Apr 21, 2021)
0259ecb flake8 (awaelchli, Apr 21, 2021)
d56e5e4 add test (awaelchli, Apr 21, 2021)
24a2cc8 Merge branch 'master' into bugfix/callback-state (awaelchli, Apr 22, 2021)
81b5e36 Merge branch 'master' into bugfix/callback-state (carmocca, Apr 22, 2021)
72ba440 Apply suggestions from code review (awaelchli, Apr 22, 2021)
79d8568 improve test (awaelchli, Apr 22, 2021)
d9ea8c1 flake (awaelchli, Apr 22, 2021)
98f7fe6 Merge branch 'master' into bugfix/callback-state (awaelchli, Apr 22, 2021)
68f571c Merge branch 'master' into bugfix/callback-state (awaelchli, Jul 26, 2021)
0851f0d fix merge (awaelchli, Jul 26, 2021)
82d5658 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 26, 2021)
334fd4a use qualname (awaelchli, Jul 26, 2021)
090b169 Merge remote-tracking branch 'origin/bugfix/callback-state' into bugf… (awaelchli, Jul 26, 2021)
f144fd1 rename state_id (awaelchli, Jul 26, 2021)
6154986 fix diff (awaelchli, Jul 26, 2021)
0ec9bd2 update fx validator (awaelchli, Jul 26, 2021)
049f14d [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 26, 2021)
71f0bcc Merge branch 'master' into bugfix/callback-state (awaelchli, Jul 26, 2021)
3eca3c5 black (awaelchli, Jul 26, 2021)
ff190fa [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 26, 2021)
a1b5b23 update test to ignore properties (awaelchli, Jul 26, 2021)
bffbd53 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 26, 2021)
b529489 Merge branch 'master' into bugfix/callback-state (awaelchli, Jul 27, 2021)
b10a45a update changelog (awaelchli, Jul 27, 2021)
d40b2cc update test_fx_validator test (awaelchli, Jul 27, 2021)
a52ad31 add docs for state id (awaelchli, Jul 28, 2021)
a3ec571 update docs for state id (awaelchli, Jul 28, 2021)
140c71b Update pytorch_lightning/callbacks/base.py (awaelchli, Jul 28, 2021)
d1b59db Update tests/trainer/logging_/test_logger_connector.py (awaelchli, Jul 28, 2021)
8dcb54e Update tests/checkpointing/test_model_checkpoint.py (awaelchli, Jul 28, 2021)
eea2dce remove an empty line (awaelchli, Jul 28, 2021)
af59b07 Merge branch 'master' into bugfix/callback-state (awaelchli, Jul 28, 2021)
e94f6df fix import error (awaelchli, Jul 28, 2021)
302e724 move test (awaelchli, Jul 28, 2021)
7 changes: 2 additions & 5 deletions CHANGELOG.md
@@ -9,10 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Added

-
+- Added `state_id` property to the `Callback` base class ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886))

@@ -32,7 +29,7 @@

-
+- Saved checkpoints will no longer use the type of a `Callback` as the key to avoid issues with unpickling ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886))
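The second entry above is the motivation for this PR: pickling a dict keyed by a class object embeds a reference to that class, so loading the checkpoint can fail if the class was renamed, moved, or is no longer importable. A minimal sketch of the failure mode with plain pickle (the MyCallback class is hypothetical and stands in for any user-defined callback):

import pickle


class MyCallback:
    """Stands in for a user-defined callback class."""


# Pre-1.5.0 layout: the class object itself is the dict key. Unpickling this
# blob must re-import MyCallback by module path; if the class was renamed or
# its module deleted, loading raises even though only the state is wanted.
legacy_blob = pickle.dumps({MyCallback: {"count": 1}})

# New layout: a plain string key; loading never needs to import the class.
blob = pickle.dumps({"MyCallback": {"count": 1}})
assert pickle.loads(blob)["MyCallback"] == {"count": 1}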
17 changes: 16 additions & 1 deletion pytorch_lightning/callbacks/base.py
@@ -17,7 +17,7 @@
"""

import abc
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Type

import torch
from torch.optim import Optimizer
@@ -33,6 +33,21 @@ class Callback(abc.ABC):
    Subclass this class and override any of the relevant hooks
    """

+    @property
+    def state_id(self) -> str:
+        """
+        Identifier for the state of the callback. Used to store and retrieve a callback's state from the
+        checkpoint dictionary by ``checkpoint["callbacks"][state_id]``. Implementations of a callback need to
+        provide a unique state id if 1) the callback has state and 2) it is desired to maintain the state of
+        multiple instances of that callback.
+        """
+        return self.__class__.__qualname__
+
+    @property
+    def _legacy_state_id(self) -> Type["Callback"]:
+        """State identifier for checkpoints saved prior to version 1.5.0."""
+        return type(self)
+
    def on_configure_sharded_model(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called before configure sharded model"""
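The docstring's caveat about multiple instances deserves an example. A minimal sketch (not part of the diff; the Counter callback and its attributes are hypothetical) of overriding state_id so two instances keep separate entries in checkpoint["callbacks"]:

from pytorch_lightning import Callback


class Counter(Callback):
    """Hypothetical stateful callback; two instances may coexist in one Trainer."""

    def __init__(self, what: str):
        self.what = what  # e.g. "train_batches" or "val_batches"
        self.count = 0

    @property
    def state_id(self) -> str:
        # The default would be "Counter" for every instance; include the
        # configuration so each instance gets its own checkpoint entry.
        return f"Counter[what={self.what}]"

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        return {"count": self.count}

    def on_load_checkpoint(self, trainer, pl_module, callback_state):
        self.count = callback_state["count"]

Without the override, both instances would default to the qualname "Counter", and the state of the second one saved would overwrite the first.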
13 changes: 6 additions & 7 deletions pytorch_lightning/trainer/callback_hook.py
@@ -15,7 +15,7 @@
from abc import ABC
from copy import deepcopy
from inspect import signature
-from typing import Any, Callable, Dict, List, Optional, Type
+from typing import Any, Callable, Dict, List, Optional, Type, Union

import torch

@@ -263,7 +263,7 @@ def __is_old_signature_on_load_checkpoint(fn: Callable) -> bool:
        parameters = list(signature(fn).parameters)
        return len(parameters) == 1 and parameters[0] != "args"

-    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[Type, dict]:
+    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]:
        """Called when saving a model checkpoint."""
        callback_states = {}
        for callback in self.callbacks:
@@ -277,16 +277,15 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[Type, dict]:
            else:
                state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint)
            if state:
-                callback_states[type(callback)] = state
+                callback_states[callback.state_id] = state
        return callback_states

-    def on_load_checkpoint(self, checkpoint):
+    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """Called when loading a model checkpoint."""

        # Todo: the `callback_states` are dropped with TPUSpawn as they
        # can't be saved using `xm.save`
        # https://github.com/pytorch/xla/issues/2773
-        callback_states = checkpoint.get("callbacks")
+        callback_states: Dict[Union[Type, str], Dict] = checkpoint.get("callbacks")

        if callback_states is None:
            return
@@ -303,7 +302,7 @@
            )

        for callback in self.callbacks:
-            state = callback_states.get(type(callback))
+            state = callback_states.get(callback.state_id, callback_states.get(callback._legacy_state_id))
            if state:
                state = deepcopy(state)
                if self.__is_old_signature_on_load_checkpoint(callback.on_load_checkpoint):
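The loading path above tries the new string identifier first and falls back to the legacy class key, so checkpoints written before this change still restore. A standalone sketch of that resolution order (_resolve_callback_state is a hypothetical helper, not a function in this PR):

from typing import Dict, Optional, Type, Union

from pytorch_lightning import Callback


def _resolve_callback_state(
    callback_states: Dict[Union[Type, str], dict], callback: Callback
) -> Optional[dict]:
    # New checkpoints key the state by `state_id` (the class qualname by
    # default); checkpoints written before 1.5.0 keyed it by the class itself.
    return callback_states.get(callback.state_id, callback_states.get(callback._legacy_state_id))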
pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -308,7 +308,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
structured dictionary: {
    'epoch': training epoch
    'global_step': training global step
-   'pytorch-lightning_version': PyTorch Lightning's version
+   'pytorch-lightning_version': The version of PyTorch Lightning that produced this checkpoint
    'callbacks': "callback specific state"[] # if not weights_only
    'optimizer_states': "PT optim's state_dict"[] # if not weights_only
    'lr_schedulers': "PT sched's state_dict"[] # if not weights_only
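For orientation, this is how the new layout looks when a checkpoint file is inspected directly; a short sketch (the path is hypothetical, and the listed key assumes a ModelCheckpoint callback was active, as in the tests below):

import torch

ckpt = torch.load("example.ckpt")  # hypothetical path
print(ckpt["pytorch-lightning_version"])
# Callback state is now keyed by string identifiers instead of classes:
print(list(ckpt["callbacks"].keys()))  # e.g. ['ModelCheckpoint']
mc_state = ckpt["callbacks"]["ModelCheckpoint"]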
33 changes: 32 additions & 1 deletion tests/callbacks/test_callbacks.py
@@ -11,9 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from pathlib import Path
from unittest.mock import call, Mock

-from pytorch_lightning import Trainer
+from pytorch_lightning import Callback, Trainer
from tests.helpers import BoringModel


@@ -101,3 +102,33 @@ def configure_callbacks(self):
        trainer_fn(model)
        callbacks_after = trainer.callbacks.copy()
        assert callbacks_after == callbacks_after_fit
+
+
+class OldStatefulCallback(Callback):
+    def __init__(self, state):
+        self.state = state
+
+    @property
+    def state_id(self):
+        return type(self)
+
+    def on_save_checkpoint(self, *args):
+        return {"state": self.state}
+
+    def on_load_checkpoint(self, trainer, pl_module, callback_state):
+        self.state = callback_state["state"]
+
+
+def test_resume_callback_state_saved_by_type(tmpdir):
+    """Test that a legacy checkpoint that didn't use a state identifier before can still be loaded."""
+    model = BoringModel()
+    callback = OldStatefulCallback(state=111)
+    trainer = Trainer(default_root_dir=tmpdir, max_steps=1, callbacks=[callback])
+    trainer.fit(model)
+    ckpt_path = Path(trainer.checkpoint_callback.best_model_path)
+    assert ckpt_path.exists()
+
+    callback = OldStatefulCallback(state=222)
+    trainer = Trainer(default_root_dir=tmpdir, max_steps=2, callbacks=[callback], resume_from_checkpoint=ckpt_path)
+    trainer.fit(model)
+    assert callback.state == 111
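A note on the mechanics of this test: the overridden state_id makes the first run store its state under type(self), reproducing the pre-1.5.0 checkpoint layout. On resume, the lookup in on_load_checkpoint still resolves that class key (for an unmodified callback, the _legacy_state_id fallback performs the same role), so state == 111 is restored.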
3 changes: 2 additions & 1 deletion tests/callbacks/test_early_stopping.py
@@ -76,7 +76,8 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
    checkpoint = torch.load(checkpoint_filepath)
    # the checkpoint saves "epoch + 1"
    early_stop_callback_state = early_stop_callback.saved_states[checkpoint["epoch"] - 1]
-    assert checkpoint["callbacks"][type(early_stop_callback)] == early_stop_callback_state
+    assert 4 == len(early_stop_callback.saved_states)
+    assert checkpoint["callbacks"]["EarlyStoppingTestRestore"] == early_stop_callback_state

    # ensure state is reloaded properly (assertion in the callback)
    early_stop_callback = EarlyStoppingTestRestore(early_stop_callback_state, monitor="train_loss")
10 changes: 4 additions & 6 deletions tests/checkpointing/test_model_checkpoint.py
@@ -148,7 +148,7 @@ def on_validation_epoch_end(self):
        assert chk["epoch"] == epoch + 1
        assert chk["global_step"] == limit_train_batches * (epoch + 1)

-        mc_specific_data = chk["callbacks"][type(checkpoint)]
+        mc_specific_data = chk["callbacks"]["ModelCheckpoint"]
        assert mc_specific_data["dirpath"] == checkpoint.dirpath
        assert mc_specific_data["monitor"] == monitor
        assert mc_specific_data["current_score"] == score
@@ -259,7 +259,7 @@ def _make_assertions(epoch, ix, version=""):
        expected_global_step = per_val_train_batches * (global_ix + 1) + (leftover_train_batches * epoch_num)
        assert chk["global_step"] == expected_global_step

-        mc_specific_data = chk["callbacks"][type(checkpoint)]
+        mc_specific_data = chk["callbacks"]["ModelCheckpoint"]
        assert mc_specific_data["dirpath"] == checkpoint.dirpath
        assert mc_specific_data["monitor"] == monitor
        assert mc_specific_data["current_score"] == score
@@ -857,9 +857,7 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):

    assert ckpt_last_epoch["epoch"] == ckpt_last["epoch"]
    assert ckpt_last_epoch["global_step"] == ckpt_last["global_step"]
-
-    ch_type = type(model_checkpoint)
-    assert ckpt_last["callbacks"][ch_type] == ckpt_last_epoch["callbacks"][ch_type]
+    assert ckpt_last["callbacks"]["ModelCheckpoint"] == ckpt_last_epoch["callbacks"]["ModelCheckpoint"]

    # it is easier to load the model objects than to iterate over the raw dict of tensors
    model_last_epoch = LogInTwoMethods.load_from_checkpoint(path_last_epoch)
@@ -1097,7 +1095,7 @@ def training_step(self, *args):
    trainer.fit(TestModel())
    assert model_checkpoint.current_score == 0.3
    ckpts = [torch.load(str(ckpt)) for ckpt in tmpdir.listdir()]
-    ckpts = [ckpt["callbacks"][type(model_checkpoint)] for ckpt in ckpts]
+    ckpts = [ckpt["callbacks"]["ModelCheckpoint"] for ckpt in ckpts]
    assert sorted(ckpt["current_score"] for ckpt in ckpts) == [0.1, 0.2, 0.3]
6 changes: 3 additions & 3 deletions tests/trainer/connectors/test_callback_connector.py
@@ -76,11 +76,11 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmpdir):
    trainer.fit(model)

    ckpt = torch.load(str(tmpdir / "all_states.ckpt"))
-    state0 = ckpt["callbacks"][type(callback0)]
-    state1 = ckpt["callbacks"][type(callback1)]
+    state0 = ckpt["callbacks"]["StatefulCallback0"]
+    state1 = ckpt["callbacks"]["StatefulCallback1"]
    assert "content0" in state0 and state0["content0"] == 0
    assert "content1" in state1 and state1["content1"] == 1
-    assert type(checkpoint_callback) in ckpt["callbacks"]
+    assert "ModelCheckpoint" in ckpt["callbacks"]


def test_attach_model_callbacks():
3 changes: 2 additions & 1 deletion tests/trainer/logging_/test_logger_connector.py
@@ -26,10 +26,11 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel, RandomDataset
from tests.helpers.runif import RunIf
+from tests.models.test_hooks import get_members


def test_fx_validator(tmpdir):
-    funcs_name = sorted(f for f in dir(Callback) if not f.startswith("_"))
+    funcs_name = sorted(get_members(Callback))

    callbacks_func = [
        "on_before_backward",