Skip to content

Commit

Permalink
Merge branch 'feature/migration' into bugfix/callback-state
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli committed Apr 21, 2021
2 parents d0fbc13 + 792e3b3 commit 5f33eff
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 45 deletions.
6 changes: 6 additions & 0 deletions pytorch_lightning/core/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.migration.migrations import upgrade_checkpoint
from pytorch_lightning.utilities.parsing import parse_class_init_keys

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -134,6 +135,9 @@ def load_from_checkpoint(
else:
checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)

# convert legacy checkpoints to the new format
checkpoint = upgrade_checkpoint(checkpoint)

if hparams_file is not None:
extension = hparams_file.split('.')[-1]
if extension.lower() == 'csv':
Expand All @@ -148,6 +152,7 @@ def load_from_checkpoint(
# overwrite hparams by the given file
checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = hparams

# TODO: make this a migration:
# for past checkpoint need to add the new key
if cls.CHECKPOINT_HYPER_PARAMS_KEY not in checkpoint:
checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = {}
Expand All @@ -171,6 +176,7 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], strict: bool = True, **cl
if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:

# 1. (backward compatibility) Try to restore model hparams from checkpoint using old/past keys
# TODO: make this a migration:
for _old_hparam_key in CHECKPOINT_PAST_HPARAMS_KEYS:
cls_kwargs_loaded.update(checkpoint.get(_old_hparam_key, {}))

Expand Down
1 change: 0 additions & 1 deletion pytorch_lightning/trainer/callback_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,6 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]:
def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
"""Called when loading a model checkpoint."""
callback_states = checkpoint.get('callbacks')
version = checkpoint.get('pytorch-lightning_version')
# Todo: the `callback_states` are dropped with TPUSpawn as they
# can't be saved using `xm.save`
# https://github.com/pytorch/xla/issues/2773
Expand Down
10 changes: 5 additions & 5 deletions pytorch_lightning/utilities/argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,11 +286,11 @@ def _gpus_allowed_type(x) -> Union[int, str]:
return int(x)


def _gpus_arg_default(x) -> Union[int, str]: # pragma: no-cover
# unused, but here for backward compatibility with old checkpoints that need to be able to
# unpickle the function from the checkpoint, as it was not filtered out in versions < 1.2.8
# see: https://github.com/PyTorchLightning/pytorch-lightning/pull/6898
pass
# def _gpus_arg_default(x) -> Union[int, str]: # pragma: no-cover
# # unused, but here for backward compatibility with old checkpoints that need to be able to
# # unpickle the function from the checkpoint, as it was not filtered out in versions < 1.2.8
# # see: https://github.com/PyTorchLightning/pytorch-lightning/pull/6898
# pass


def _int_or_float_type(x) -> Union[int, float]:
Expand Down
56 changes: 30 additions & 26 deletions pytorch_lightning/utilities/migration/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,16 @@
"1.2.6",
"1.2.7",
"1.2.8",
"1.2.9",
"1.3.0rc0",
"1.3.0rc1",
pytorch_lightning.__version__,
]

# The running release may be newer than the last hard-coded entry above;
# append it so lookups keyed on the current version still succeed.
if pytorch_lightning.__version__ not in version_history:
    version_history.append(pytorch_lightning.__version__)

def default_upgrade_rule(checkpoint):

def default_migration(checkpoint):
""" Upgrades to the next version by only replacing the current version with the new one. """
# TODO: find more elegant version for the if below
current = get_version(checkpoint)
Expand All @@ -110,6 +113,9 @@ def default_upgrade_rule(checkpoint):
return checkpoint


all_migrations = dict((ver, default_migration) for ver in version_history)


def get_version(checkpoint: dict) -> str:
    """Return the pytorch-lightning version string stored in *checkpoint*.

    Raises ``KeyError`` if the checkpoint carries no version entry.
    """
    stored_version = checkpoint["pytorch-lightning_version"]
    return stored_version

Expand All @@ -121,27 +127,25 @@ def set_version(checkpoint: dict, version: str):
class Migration:
""" Decorator for a function that upgrades a checkpoint from one version to the next. """

all_migrations = dict.fromkeys(version_history, [default_upgrade_rule])

def __init__(self, requires: Optional[str]):
    """Decorator configuration: remember which checkpoint version is required.

    Args:
        requires: version string the checkpoint must carry for the decorated
            migration to run; a falsy value means the migration applies to
            any version (see the guard in ``__call__``).
    """
    self.required_version = requires

def __call__(self, fn: callable) -> callable:
    """Wrap ``fn`` so it only runs on matching versions, then register it."""

    @wraps(fn)
    def wrapper(ckpt):
        # Pass the checkpoint through untouched when its version does not
        # match the version this migration requires.
        current_version = get_version(ckpt)
        if self.required_version and current_version != self.required_version:
            # NOTE(review): logging a routine skip at ERROR level looks too
            # severe — consider log.debug; left unchanged here.
            log.error(f"skipping, {current_version}")
            return ckpt
        new_ckpt = fn(ckpt)
        return new_ckpt

    # Insert at the front: the most recently decorated migration for a given
    # version runs first when ``migrate`` iterates the bucket.
    self.all_migrations[self.required_version].insert(0, wrapper)
    return wrapper

@staticmethod
def migrate(checkpoint: dict) -> dict:
    """Feed *checkpoint* through every registered migration, version by version."""
    for migrations_for_version in Migration.all_migrations.values():
        for upgrade in migrations_for_version:
            checkpoint = upgrade(checkpoint)
    return checkpoint
def __init__(self, target: Optional[str]):
    """Decorator configuration for registering a checkpoint-upgrade function.

    Args:
        target: version string the decorated function upgrades a checkpoint
            to; used as the key into ``all_migrations`` when registering.
    """
    self.target_version = target

def __call__(self, upgrade_fn: callable) -> callable:
    """Register *upgrade_fn* as the migration for ``self.target_version``.

    Returns the function unchanged so it remains usable as a plain callable.

    Raises:
        ValueError: if a non-default migration is already registered for
            ``self.target_version`` (duplicate registration).
    """
    # Bug fix: the original inspected ``upgrade_fn._migration_registered`` —
    # an attribute a freshly defined function never carries — so registering
    # two *different* functions for the same version was never detected, and
    # an unknown version raised KeyError.  Check the registry itself instead.
    existing = all_migrations.get(self.target_version, default_migration)
    if existing is not default_migration:
        raise ValueError(
            f"Tried to register a new migration {upgrade_fn.__name__}, but"
            f" there is already a migration for version {self.target_version}:"
            f" {existing.__name__}"
        )
    all_migrations[self.target_version] = upgrade_fn
    # Marker kept for backward compatibility with any code that reads it.
    upgrade_fn._migration_registered = True
    return upgrade_fn


def upgrade_checkpoint(checkpoint: dict) -> dict:
    """Run every registered migration over *checkpoint*, in registration order.

    Returns the upgraded checkpoint dict (migrations may mutate it in place).
    """
    for migration in all_migrations.values():
        # Bug fix: the original fell through after the None-check and then
        # called ``migration(checkpoint)`` with ``migration is None``, which
        # would raise TypeError.  A None entry now maps to the default rule.
        if migration is None:
            checkpoint = default_migration(checkpoint)
        else:
            checkpoint = migration(checkpoint)
    return checkpoint

27 changes: 14 additions & 13 deletions pytorch_lightning/utilities/migration/migrations.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,29 @@
import torch

from pytorch_lightning.utilities.migration.base import Migration, get_version
from pytorch_lightning.utilities.migration.base import Migration, upgrade_checkpoint
from pytorch_lightning.utilities.migration.patch import pl_legacy_patch


@Migration(requires="1.2.7")
def upgrade_callback_names(checkpoint: dict) -> dict:
if "callbacks" not in checkpoint:
return checkpoint
checkpoint["callbacks"] = reversed(checkpoint["callbacks"])
print(get_version(checkpoint))
@Migration(target="1.2.8")
def upgrade_something_else(checkpoint: dict) -> dict:
return checkpoint


@Migration(requires="1.2.8")
def upgrade_something_else(checkpoint: dict) -> dict:
@Migration(target="1.2.9")
def upgrade_callback_state_identifiers(checkpoint):
if "callbacks" not in checkpoint:
return
callbacks = checkpoint["callbacks"]
checkpoint["callbacks"] = dict((callback_type.__name__, state) for callback_type, state in callbacks.items())
return checkpoint


if __name__ == "__main__":
with pl_legacy_patch():
checkpoint = torch.load("gpus-default-legacy.ckpt")
ckpt = torch.load("gpus-default-legacy.ckpt")
# checkpoint = torch.load("example.ckpt")
# checkpoint["pytorch-lightning_version"] = "1.2.6"
# getattr()
checkpoint = Migration.migrate(checkpoint)
print(checkpoint)

ckpt = upgrade_checkpoint(ckpt)
from pprint import pprint
pprint(ckpt)

0 comments on commit 5f33eff

Please sign in to comment.