[1 / 3] improvements to saving and loading callback state #6886

Merged
merged 44 commits from bugfix/callback-state into master on Jul 28, 2021
Changes from 9 commits

Commits (44)
89131c2
class name as key
awaelchli Apr 8, 2021
63fb983
string state identifier
awaelchli Apr 8, 2021
7dc218a
add legacy state loading
awaelchli Apr 8, 2021
04b588b
update test
awaelchli Apr 8, 2021
bb11e28
update tests
awaelchli Apr 8, 2021
271360c
Merge branch 'master' into bugfix/callback-state
awaelchli Apr 15, 2021
20b66f0
Merge branch 'master' into bugfix/callback-state
awaelchli Apr 16, 2021
f585a28
Merge branch 'master' into bugfix/callback-state
awaelchli Apr 16, 2021
e1d518b
Merge branch 'master' into bugfix/callback-state
awaelchli Apr 17, 2021
880066b
Merge branch 'master' into bugfix/callback-state
awaelchli Apr 21, 2021
0259ecb
flake8
awaelchli Apr 21, 2021
d56e5e4
add test
awaelchli Apr 21, 2021
24a2cc8
Merge branch 'master' into bugfix/callback-state
awaelchli Apr 22, 2021
81b5e36
Merge branch 'master' into bugfix/callback-state
carmocca Apr 22, 2021
72ba440
Apply suggestions from code review
awaelchli Apr 22, 2021
79d8568
improve test
awaelchli Apr 22, 2021
d9ea8c1
flake
awaelchli Apr 22, 2021
98f7fe6
Merge branch 'master' into bugfix/callback-state
awaelchli Apr 22, 2021
68f571c
Merge branch 'master' into bugfix/callback-state
awaelchli Jul 26, 2021
0851f0d
fix merge
awaelchli Jul 26, 2021
82d5658
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 26, 2021
334fd4a
use qualname
awaelchli Jul 26, 2021
090b169
Merge remote-tracking branch 'origin/bugfix/callback-state' into bugf…
awaelchli Jul 26, 2021
f144fd1
rename state_id
awaelchli Jul 26, 2021
6154986
fix diff
awaelchli Jul 26, 2021
0ec9bd2
update fx validator
awaelchli Jul 26, 2021
049f14d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 26, 2021
71f0bcc
Merge branch 'master' into bugfix/callback-state
awaelchli Jul 26, 2021
3eca3c5
black
awaelchli Jul 26, 2021
ff190fa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 26, 2021
a1b5b23
update test to ignore properties
awaelchli Jul 26, 2021
bffbd53
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 26, 2021
b529489
Merge branch 'master' into bugfix/callback-state
awaelchli Jul 27, 2021
b10a45a
update changelog
awaelchli Jul 27, 2021
d40b2cc
update test_fx_validator test
awaelchli Jul 27, 2021
a52ad31
add docs for state id
awaelchli Jul 28, 2021
a3ec571
update docs for state id
awaelchli Jul 28, 2021
140c71b
Update pytorch_lightning/callbacks/base.py
awaelchli Jul 28, 2021
d1b59db
Update tests/trainer/logging_/test_logger_connector.py
awaelchli Jul 28, 2021
8dcb54e
Update tests/checkpointing/test_model_checkpoint.py
awaelchli Jul 28, 2021
eea2dce
remove an empty line
awaelchli Jul 28, 2021
af59b07
Merge branch 'master' into bugfix/callback-state
awaelchli Jul 28, 2021
e94f6df
fix import error
awaelchli Jul 28, 2021
302e724
move test
awaelchli Jul 28, 2021
10 changes: 9 additions & 1 deletion pytorch_lightning/callbacks/base.py
@@ -17,7 +17,7 @@
 """

 import abc
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Type

 from pytorch_lightning.core.lightning import LightningModule

@@ -29,6 +29,14 @@ class Callback(abc.ABC):
     Subclass this class and override any of the relevant hooks
     """

+    @property
+    def state_identifier(self) -> str:
+        return self.__class__.__name__
+
+    @property
+    def _legacy_state_identifier(self) -> Type:
+        return type(self)
+
     def on_configure_sharded_model(self, trainer, pl_module: LightningModule) -> None:
         """Called before configure sharded model"""
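With this change a callback's state is saved under the string returned by `state_identifier` (its class name) instead of under its type object. A minimal sketch of a user-defined stateful callback under the new scheme; `CounterCallback` and its `count` field are hypothetical and only illustrate the hook signatures used in this diff:

from pytorch_lightning.callbacks import Callback


class CounterCallback(Callback):
    """Hypothetical callback that counts training batches."""

    def __init__(self):
        self.count = 0

    def on_train_batch_end(self, trainer, pl_module, *args, **kwargs):
        self.count += 1

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        # the returned dict is stored under checkpoint["callbacks"]["CounterCallback"]
        return {"count": self.count}

    def on_load_checkpoint(self, callback_state):
        self.count = callback_state["count"]


assert CounterCallback().state_identifier == "CounterCallback"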
13 changes: 9 additions & 4 deletions pytorch_lightning/trainer/callback_hook.py
@@ -14,6 +14,7 @@

 from abc import ABC
 from copy import deepcopy
+from distutils.version import LooseVersion
 from inspect import signature
 from typing import Any, Callable, Dict, List, Optional, Type

@@ -243,7 +244,7 @@ def __is_old_signature(fn: Callable) -> bool:
                 return True
         return False

-    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[Type, dict]:
+    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]:
         """Called when saving a model checkpoint."""
         callback_states = {}
         for callback in self.callbacks:
@@ -257,18 +258,22 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[Type, dict]:
             else:
                 state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint)
             if state:
-                callback_states[type(callback)] = state
+                callback_states[callback.state_identifier] = state
         return callback_states

-    def on_load_checkpoint(self, checkpoint):
+    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         """Called when loading a model checkpoint."""
         callback_states = checkpoint.get('callbacks')
+        version = checkpoint.get('pytorch-lightning_version')
         # Todo: the `callback_states` are dropped with TPUSpawn as they
         # can't be saved using `xm.save`
         # https://github.com/pytorch/xla/issues/2773
         if callback_states is not None:
             for callback in self.callbacks:
-                state = callback_states.get(type(callback))
+                state = (
+                    callback_states.get(callback.state_identifier)
+                    or callback_states.get(callback._legacy_state_identifier)
+                )
                 if state:
                     state = deepcopy(state)
                     callback.on_load_checkpoint(state)
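Checkpoints written before this change keyed callback state by `type(callback)`, while new checkpoints use the class-name string; the `or`-fallback above keeps both formats loadable. A minimal sketch of the two layouts and the lookup, assuming this branch of the library (the score value is made up):

from pytorch_lightning.callbacks import ModelCheckpoint

callback = ModelCheckpoint()

# new format: callback state keyed by the class-name string
new_ckpt = {"callbacks": {"ModelCheckpoint": {"current_score": 0.3}}}
# legacy format: callback state keyed by the callback's type
old_ckpt = {"callbacks": {ModelCheckpoint: {"current_score": 0.3}}}

for ckpt in (new_ckpt, old_ckpt):
    state = (
        ckpt["callbacks"].get(callback.state_identifier)  # "ModelCheckpoint"
        or ckpt["callbacks"].get(callback._legacy_state_identifier)  # the ModelCheckpoint class
    )
    assert state == {"current_score": 0.3}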
@@ -244,7 +244,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
         structured dictionary: {
             'epoch':                     training epoch
             'global_step':               training global step
-            'pytorch-lightning_version': PyTorch Lightning's version
+            'pytorch-lightning_version': The version of PyTorch Lightning that produced this checkpoint
             'callbacks':                 "callback specific state"[]  # if not weights_only
             'optimizer_states':          "PT optim's state_dict"[]    # if not weights_only
             'lr_schedulers':             "PT sched's state_dict"[]    # if not weights_only
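For reference, the docstring above corresponds to a checkpoint dictionary of roughly the following shape; this is a partial sketch limited to the documented fields, with placeholder values:

checkpoint = {
    "epoch": 3,  # training epoch
    "global_step": 1200,  # training global step
    "pytorch-lightning_version": "1.4.0",  # version that produced the checkpoint (placeholder)
    "callbacks": {"ModelCheckpoint": {}},  # callback-specific state, skipped if weights_only
    "optimizer_states": [{}],  # torch optimizer state_dicts, skipped if weights_only
    "lr_schedulers": [{}],  # torch scheduler state_dicts, skipped if weights_only
}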
2 changes: 1 addition & 1 deletion tests/callbacks/test_early_stopping.py
@@ -76,7 +76,7 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
     # the checkpoint saves "epoch + 1"
     early_stop_callback_state = early_stop_callback.saved_states[checkpoint["epoch"] - 1]
     assert 4 == len(early_stop_callback.saved_states)
-    assert checkpoint["callbacks"][type(early_stop_callback)] == early_stop_callback_state
+    assert checkpoint["callbacks"]["EarlyStoppingTestRestore"] == early_stop_callback_state

     # ensure state is reloaded properly (assertion in the callback)
     early_stop_callback = EarlyStoppingTestRestore(early_stop_callback_state, monitor='train_loss')
1 change: 1 addition & 0 deletions tests/checkpointing/test_legacy_checkpoints.py
@@ -60,6 +60,7 @@
         "1.2.5",
         "1.2.6",
         "1.2.7",
+        "1.2.8",
     ]
 )
 def test_resume_legacy_checkpoints(tmpdir, pl_version: str):
8 changes: 4 additions & 4 deletions tests/checkpointing/test_model_checkpoint.py
@@ -147,7 +147,7 @@ def configure_optimizers(self):
         assert chk['epoch'] == epoch + 1
         assert chk['global_step'] == limit_train_batches * (epoch + 1)

-        mc_specific_data = chk['callbacks'][type(checkpoint)]
+        mc_specific_data = chk['callbacks']["ModelCheckpoint"]
         assert mc_specific_data['dirpath'] == checkpoint.dirpath
         assert mc_specific_data['monitor'] == monitor
         assert mc_specific_data['current_score'] == score
@@ -251,7 +251,7 @@ def configure_optimizers(self):
         assert chk['epoch'] == epoch + 1
         assert chk['global_step'] == per_epoch_steps * (global_ix + 1)

-        mc_specific_data = chk['callbacks'][type(checkpoint)]
+        mc_specific_data = chk['callbacks']["ModelCheckpoint"]
         assert mc_specific_data['dirpath'] == checkpoint.dirpath
         assert mc_specific_data['monitor'] == monitor
         assert mc_specific_data['current_score'] == score
@@ -840,7 +840,7 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
     ckpt_last = torch.load(path_last)
     assert all(ckpt_last_epoch[k] == ckpt_last[k] for k in ("epoch", "global_step"))

-    ch_type = type(model_checkpoint)
+    ch_type = "ModelCheckpoint"
     assert ckpt_last["callbacks"][ch_type] == ckpt_last_epoch["callbacks"][ch_type]

     # it is easier to load the model objects than to iterate over the raw dict of tensors
@@ -1098,7 +1098,7 @@ def training_step(self, *args):
     trainer.fit(TestModel())
     assert model_checkpoint.current_score == 0.3
     ckpts = [torch.load(str(ckpt)) for ckpt in tmpdir.listdir()]
-    ckpts = [ckpt["callbacks"][type(model_checkpoint)] for ckpt in ckpts]
+    ckpts = [ckpt["callbacks"]["ModelCheckpoint"] for ckpt in ckpts]
     assert sorted(ckpt["current_score"] for ckpt in ckpts) == [0.1, 0.2, 0.3]
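As these updated assertions show, callback state in new checkpoints is addressed with plain string keys. A quick usage sketch for inspecting a saved checkpoint (the path is hypothetical; the state keys are the ones asserted above):

import torch

ckpt = torch.load("lightning_logs/version_0/checkpoints/epoch=2-step=1199.ckpt")
mc_state = ckpt["callbacks"]["ModelCheckpoint"]  # previously keyed by the ModelCheckpoint type
print(mc_state["monitor"], mc_state["current_score"], mc_state["dirpath"])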
6 changes: 3 additions & 3 deletions tests/trainer/connectors/test_callback_connector.py
@@ -68,11 +68,11 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmpdir):
     trainer.fit(model)

     ckpt = torch.load(str(tmpdir / "all_states.ckpt"))
-    state0 = ckpt["callbacks"][type(callback0)]
-    state1 = ckpt["callbacks"][type(callback1)]
+    state0 = ckpt["callbacks"]["StatefulCallback0"]
+    state1 = ckpt["callbacks"]["StatefulCallback1"]
     assert "content0" in state0 and state0["content0"] == 0
     assert "content1" in state1 and state1["content1"] == 1
-    assert type(checkpoint_callback) in ckpt["callbacks"]
+    assert "ModelCheckpoint" in ckpt["callbacks"]


 def test_attach_model_callbacks():
2 changes: 1 addition & 1 deletion tests/trainer/logging_/test_logger_connector.py
@@ -272,7 +272,7 @@ def test_dataloader(self):

 def test_call_back_validator(tmpdir):

-    funcs_name = sorted([f for f in dir(Callback) if not f.startswith('_')])
+    funcs_name = sorted([f for f in dir(Callback) if not f.startswith('_') and callable(getattr(Callback, f))])

     callbacks_func = [
         'on_after_backward',