Commit 26aac08

deconflict

williamFalcon committed Aug 15, 2020
2 parents 8aad173 + ba9c51f
Showing 3 changed files with 62 additions and 23 deletions.
38 changes: 21 additions & 17 deletions pytorch_lightning/core/saving.py
@@ -145,35 +145,39 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], strict: bool = True, *cls
         cls_init_args_name = inspect.signature(cls).parameters.keys()
         # pass in the values we saved automatically
         if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
-            model_args = {}
+            cls_kwargs_loaded = {}

-            # add some back compatibility, the actual one shall be last
-            for hparam_key in CHECKPOINT_PAST_HPARAMS_KEYS + (cls.CHECKPOINT_HYPER_PARAMS_KEY,):
-                if hparam_key in checkpoint:
-                    model_args.update(checkpoint[hparam_key])
+            # 1. (backward compatibility) Try to restore model hparams from checkpoint using old/past keys
+            for _old_hparam_key in CHECKPOINT_PAST_HPARAMS_KEYS:
+                if _old_hparam_key in checkpoint:
+                    cls_kwargs_loaded.update(checkpoint[_old_hparam_key])

-            model_args = _convert_loaded_hparams(model_args, checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_TYPE))
+            # 2. Try to restore model hparams from checkpoint using the new key
+            _new_hparam_key = cls.CHECKPOINT_HYPER_PARAMS_KEY
+            cls_kwargs_loaded.update(checkpoint[_new_hparam_key])

-            args_name = checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_NAME)
+            # 3. Ensure that `cls_kwargs_old` has the right type
+            cls_kwargs_loaded = _convert_loaded_hparams(cls_kwargs_loaded, checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_TYPE))

-            if args_name == 'kwargs':
-                # in case the class cannot take any extra argument filter only the possible
-                cls_kwargs.update(**model_args)
-            elif args_name:
-                if args_name in cls_init_args_name:
-                    cls_kwargs.update({args_name: model_args})
+            # 4. Update cls_kwargs_new with cls_kwargs_old
+            args_name = checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_NAME)
+            if args_name and args_name in cls_init_args_name:
+                cls_kwargs.update({args_name: cls_kwargs_loaded})
             else:
-                cls_args = (model_args,) + cls_args
+                cls_kwargs.update(cls_kwargs_loaded)

         if not cls_spec.varkw:
             # filter kwargs according to class init unless it allows any argument via kwargs
-            cls_kwargs = {k: v for k, v in cls_kwargs.items() if k in cls_init_args_name}
+            cls_kwargs_extra = {k: v for k, v in cls_kwargs.items() if k in cls_init_args_name}

         # prevent passing positional arguments if class does not accept any
         if len(cls_spec.args) <= 1 and not cls_spec.varargs and not cls_spec.kwonlyargs:
-            cls_args, cls_kwargs = [], {}
+            _cls_args_new, _cls_kwargs_new = [], {}
+        else:
+            _cls_args_new, _cls_kwargs_new = cls_args, cls_kwargs_extra
+
+        model = cls(*_cls_args_new, **_cls_kwargs_new)

-        model = cls(*cls_args, **cls_kwargs)
         # load the state_dict on the model automatically
         model.load_state_dict(checkpoint['state_dict'], strict=strict)
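To make the intent of the reworked hparams-restoration flow above concrete, here is a minimal stand-alone Python sketch. It is not the pytorch_lightning implementation: the key names ('hparams', 'module_arguments', 'hyper_parameters') and the helper merge_checkpoint_hparams are illustrative stand-ins, and only the overall effect is modelled: hyperparameters saved under legacy or current checkpoint keys are restored, and keyword arguments supplied by the caller take precedence, which is what the new test_load_model_from_checkpoint_extra_args below exercises.

    # Hypothetical stand-alone sketch (not pytorch_lightning code): restore hparams
    # from a checkpoint dict, honouring legacy keys, then let caller kwargs win.

    PAST_HPARAMS_KEYS = ('hparams', 'module_arguments')  # assumed legacy key names
    HPARAMS_KEY = 'hyper_parameters'                      # assumed current key name


    def merge_checkpoint_hparams(checkpoint: dict, user_kwargs: dict) -> dict:
        loaded = {}

        # 1. backward compatibility: older checkpoints stored hparams under past keys
        for old_key in PAST_HPARAMS_KEYS:
            if old_key in checkpoint:
                loaded.update(checkpoint[old_key])

        # 2. the current key is applied last, so it overrides any legacy values
        loaded.update(checkpoint.get(HPARAMS_KEY, {}))

        # 3. kwargs supplied by the caller override whatever was stored in the checkpoint
        loaded.update(user_kwargs)
        return loaded


    # usage: b2 from the checkpoint is overridden, b1 keeps its stored value
    ckpt = {HPARAMS_KEY: {'b1': 0.5, 'b2': 0.999}}
    assert merge_checkpoint_hparams(ckpt, {'b2': 0.888}) == {'b1': 0.5, 'b2': 0.888}

In the real API this corresponds to calling EvalModelTemplate.load_from_checkpoint(last_checkpoint, b1=0.5, b2=0.888) in the new test: b2 is overridden while b1 and the model weights keep their checkpointed values.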
5 changes: 3 additions & 2 deletions tests/base/model_template.py
@@ -47,7 +47,8 @@ def __init__(
             out_features: int = 10,
             hidden_dim: int = 1000,
             b1: float = 0.5,
-            b2: float = 0.999
+            b2: float = 0.999,
+            save_hparams=True,
     ):
         # init superclass
         super().__init__()

@@ -126,7 +127,7 @@ def get_default_hparams(continue_training: bool = False, hpc_exp_number: int = 0
             hidden_dim=1000,
             b1=0.5,
             b2=0.999,
-            save_hparams=True
+            save_hparams=True,
         )

         if continue_training:
42 changes: 38 additions & 4 deletions tests/models/test_restore.py
@@ -10,6 +10,7 @@

 import tests.base.develop_pipelines as tpipes
 import tests.base.develop_utils as tutils
+from pytorch_lightning import LightningModule
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
 from tests.base import EvalModelTemplate

@@ -175,20 +176,23 @@ def test_load_model_from_checkpoint(tmpdir):

     # load last checkpoint
     last_checkpoint = sorted(glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, "*.ckpt")))[-1]
+
+    # Since `EvalModelTemplate` has `_save_hparams = True` by default, check that ckpt has hparams
+    ckpt = torch.load(last_checkpoint)
+    assert LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in ckpt.keys(), 'module_arguments missing from checkpoints'
+
+    # Ensure that model can be correctly restored from checkpoint
     pretrained_model = EvalModelTemplate.load_from_checkpoint(last_checkpoint)

-    # test that hparams loaded correctly
-    for k, v in hparams.items():
-        assert getattr(pretrained_model, k) == v

     # assert weights are the same
     for (old_name, old_p), (new_name, new_p) in zip(model.named_parameters(), pretrained_model.named_parameters()):
         assert torch.all(torch.eq(old_p, new_p)), 'loaded weights are not the same as the saved weights'

+    # Check `test` on pretrained model:
     new_trainer = Trainer(**trainer_options)
     new_trainer.test(pretrained_model)

-    # test we have good test accuracy
     tutils.assert_ok_model_acc(new_trainer)

@@ -419,6 +423,36 @@ def __init__(self):
     )


+def test_load_model_from_checkpoint_extra_args(tmpdir):
+    """Check that model args can be passed/changed when `load_from_checkpoint` is called."""
+
+    model = EvalModelTemplate()
+
+    trainer_options = dict(
+        progress_bar_refresh_rate=0,
+        max_epochs=2,
+        limit_train_batches=0.4,
+        limit_val_batches=0.2,
+        checkpoint_callback=ModelCheckpoint(tmpdir, save_top_k=-1),
+        default_root_dir=tmpdir,
+    )
+
+    # Fit model
+    trainer = Trainer(**trainer_options)
+    trainer.fit(model)
+
+    # Load last checkpoint
+    last_checkpoint = sorted(glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, "*.ckpt")))[-1]
+    pretrained_model = EvalModelTemplate.load_from_checkpoint(last_checkpoint, b1=0.5, b2=0.888)
+
+    # Assert that model args were changed accordingly
+    # Assert that model weights did not change
+    assert pretrained_model.b1 == 0.5  # `b1` arg did not change
+    assert pretrained_model.b2 == 0.888  # `b2` arg changed
+    for (old_name, old_p), (new_name, new_p) in zip(model.named_parameters(), pretrained_model.named_parameters()):
+        assert torch.all(torch.eq(old_p, new_p)), 'loaded weights are not the same as the saved weights'
+
+
 def test_model_pickle(tmpdir):
     model = EvalModelTemplate()
     pickle.dumps(model)
