diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4495a8f961232..da64a70c3231c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -47,7 +47,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
-
+- Fixed feature-lack in `hpc_load` ([#4526](https://github.com/PyTorchLightning/pytorch-lightning/pull/4526))
 
 ## [1.0.5] - 2020-11-03
 
diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index 3bf76e4c30630..3b44ce96c02ad 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -103,6 +103,20 @@ def restore(self, checkpoint_path: str, on_gpu: bool):
         # load model state
         model = self.trainer.get_model()
 
+        # restore model and datamodule state
+        self.restore_model_state(model, checkpoint)
+
+        if on_gpu:
+            model.cuda(self.trainer.root_gpu)
+
+        # restore training state
+        self.restore_training_state(checkpoint)
+
+    def restore_model_state(self, model: LightningModule, checkpoint) -> None:
+        """
+        Restore model states from a 'PyTorch-Lightning checkpoint' dictionary object
+        """
+
         # give the datamodule a chance to load something
         if self.trainer.datamodule is not None:
             self.trainer.datamodule.on_load_checkpoint(checkpoint)
@@ -113,18 +127,6 @@ def restore(self, checkpoint_path: str, on_gpu: bool):
         # restore the state_dict on the model
         model.load_state_dict(checkpoint['state_dict'])
 
-        if on_gpu:
-            model.cuda(self.trainer.root_gpu)
-
-        # restore amp scaling
-        if self.trainer.amp_backend == AMPType.NATIVE and 'native_amp_scaling_state' in checkpoint:
-            self.trainer.scaler.load_state_dict(checkpoint['native_amp_scaling_state'])
-        elif self.trainer.amp_backend == AMPType.APEX and 'amp_scaling_state' in checkpoint:
-            amp.load_state_dict(checkpoint['amp_scaling_state'])
-
-        # load training state (affects trainer only)
-        self.restore_training_state(checkpoint)
-
     def restore_training_state(self, checkpoint):
         """
         Restore trainer state.
@@ -147,6 +149,12 @@ def restore_training_state(self, checkpoint):
                 " where `model.ckpt` is your checkpoint file."
             )
 
+        # restore amp scaling
+        if self.trainer.amp_backend == AMPType.NATIVE and 'native_amp_scaling_state' in checkpoint:
+            self.trainer.scaler.load_state_dict(checkpoint['native_amp_scaling_state'])
+        elif self.trainer.amp_backend == AMPType.APEX and 'amp_scaling_state' in checkpoint:
+            amp.load_state_dict(checkpoint['amp_scaling_state'])
+
         # restore callback states
         self.trainer.on_load_checkpoint(checkpoint)
 
@@ -336,19 +344,13 @@ def hpc_load(self, folderpath, on_gpu):
         filepath = '{}/hpc_ckpt_{}.ckpt'.format(folderpath, self.max_ckpt_in_folder(folderpath))
 
         # load on CPU first
-        checkpoint = torch.load(filepath, map_location=lambda storage, loc: storage)
+        checkpoint = pl_load(filepath, map_location=lambda storage, loc: storage)
 
         # load model state
         model = self.trainer.get_model()
 
-        # load the state_dict on the model automatically
-        model.load_state_dict(checkpoint['state_dict'])
-
-        # restore amp scaling
-        if self.trainer.amp_backend == AMPType.NATIVE and 'native_amp_scaling_state' in checkpoint:
-            self.trainer.scaler.load_state_dict(checkpoint['native_amp_scaling_state'])
-        elif self.trainer.amp_backend == AMPType.APEX and 'amp_scaling_state' in checkpoint:
-            amp.load_state_dict(checkpoint['amp_scaling_state'])
+        # restore states from 'PyTorch-Lightning checkpoint' dictionary object
+        self.restore_model_state(model, checkpoint)
 
         if self.trainer.root_gpu is not None:
             model.cuda(self.trainer.root_gpu)
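
Notes on the refactor: before this change, `hpc_load` restored raw weights with `torch.load` and skipped the datamodule and `on_load_checkpoint` hooks that the regular `restore` path ran (the "feature-lack" in the changelog entry). Extracting `restore_model_state` routes both paths through the same sequence, and the `torch.load` -> `pl_load` swap uses Lightning's cloud-aware loader so the HPC path also accepts remote filesystem paths. Below is a minimal sketch of the shared sequence as a standalone function; `TinyModule` and the flat-dict checkpoint layout are illustrative assumptions, not the connector's actual API:

import torch
from torch import nn

class TinyModule(nn.Module):
    """Hypothetical stand-in for a LightningModule."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 2)

    def on_load_checkpoint(self, checkpoint):
        # hook runs before the weights are restored, mirroring Lightning's order
        print('restoring checkpoint from epoch', checkpoint.get('epoch'))

def restore_model_state(model, checkpoint, datamodule=None):
    # sketch of the extracted helper's order of operations (assumed names)
    if datamodule is not None:
        datamodule.on_load_checkpoint(checkpoint)    # datamodule hook first
    model.on_load_checkpoint(checkpoint)             # then the module hook
    model.load_state_dict(checkpoint['state_dict'])  # finally the weights

# a 'PyTorch-Lightning checkpoint' is a plain dict; minimal example
ckpt = {'epoch': 5, 'state_dict': TinyModule().state_dict()}
torch.save(ckpt, 'hpc_ckpt_1.ckpt')

restored = TinyModule()
restore_model_state(restored, torch.load('hpc_ckpt_1.ckpt', map_location='cpu'))

With the helper in place, `restore` keeps sole ownership of trainer-side state (the amp-scaling restore now lives in `restore_training_state`), while `hpc_load` gets the model-side hooks for free.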