Fix load disparity between normal and hpc (#4526)
* Add missing load functionality in hpc

* Add general file load for hpc

* Add mark in CHANGELOG

* Fix Typo Li**hg**tning

Co-authored-by: Rohit Gupta <[email protected]>

* Refactor line separation

Co-authored-by: Rohit Gupta <[email protected]>

* Fix entangled fixation commit

* Fix naming of restore_model_states

* Fix amp restore place

Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: chaton <[email protected]>
3 people authored Nov 9, 2020
1 parent 23719e3 commit 41c9bee
Showing 2 changed files with 24 additions and 22 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -47,7 +47,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
-
+- Fixed feature-lack in hpc load ([#4526](https://github.com/PyTorchLightning/pytorch-lightning/pull/4526))
 
 ## [1.0.5] - 2020-11-03
 
44 changes: 23 additions & 21 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -103,6 +103,20 @@ def restore(self, checkpoint_path: str, on_gpu: bool):
         # load model state
         model = self.trainer.get_model()
 
+        # restore model and datamodule state
+        self.restore_model_state(model, checkpoint)
+
+        if on_gpu:
+            model.cuda(self.trainer.root_gpu)
+
+        # restore training state
+        self.restore_training_state(checkpoint)
+
+    def restore_model_state(self, model: LightningModule, checkpoint) -> None:
+        """
+        Restore model states from a 'PyTorch-Lightning checkpoint' dictionary object
+        """
+
         # give the datamodule a chance to load something
         if self.trainer.datamodule is not None:
             self.trainer.datamodule.on_load_checkpoint(checkpoint)
@@ -113,18 +127,6 @@ def restore(self, checkpoint_path: str, on_gpu: bool):
         # restore the state_dict on the model
         model.load_state_dict(checkpoint['state_dict'])
 
-        if on_gpu:
-            model.cuda(self.trainer.root_gpu)
-
-        # restore amp scaling
-        if self.trainer.amp_backend == AMPType.NATIVE and 'native_amp_scaling_state' in checkpoint:
-            self.trainer.scaler.load_state_dict(checkpoint['native_amp_scaling_state'])
-        elif self.trainer.amp_backend == AMPType.APEX and 'amp_scaling_state' in checkpoint:
-            amp.load_state_dict(checkpoint['amp_scaling_state'])
-
-        # load training state (affects trainer only)
-        self.restore_training_state(checkpoint)
-
     def restore_training_state(self, checkpoint):
         """
         Restore trainer state.
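
Taken together, the two hunks above move the model- and datamodule-related restore logic into a dedicated restore_model_state method that the regular restore path now delegates to. Below is a minimal, self-contained sketch of that pattern using plain PyTorch and a hypothetical ToyConnector class (not Lightning's actual connector), just to make the intent of the refactor easier to see.

import torch
import torch.nn as nn


class ToyConnector:
    """Hypothetical stand-in for the trainer's checkpoint connector (illustration only)."""

    def __init__(self, model: nn.Module):
        self.model = model

    def restore_model_state(self, checkpoint: dict) -> None:
        # shared logic: everything model-related is restored in one place
        self.model.load_state_dict(checkpoint["state_dict"])

    def restore(self, path: str, on_gpu: bool = False) -> None:
        # regular restore path: load the checkpoint, delegate to the shared helper
        checkpoint = torch.load(path, map_location="cpu")
        self.restore_model_state(checkpoint)
        if on_gpu and torch.cuda.is_available():
            self.model.cuda()

    def hpc_load(self, path: str) -> None:
        # HPC restore path: reuses the same helper, so the two paths cannot drift apart
        checkpoint = torch.load(path, map_location="cpu")
        self.restore_model_state(checkpoint)


if __name__ == "__main__":
    torch.save({"state_dict": nn.Linear(4, 2).state_dict()}, "toy.ckpt")
    ToyConnector(nn.Linear(4, 2)).hpc_load("toy.ckpt")
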
@@ -147,6 +149,12 @@ def restore_training_state(self, checkpoint):
                 " where `model.ckpt` is your checkpoint file."
             )
 
+        # restore amp scaling
+        if self.trainer.amp_backend == AMPType.NATIVE and 'native_amp_scaling_state' in checkpoint:
+            self.trainer.scaler.load_state_dict(checkpoint['native_amp_scaling_state'])
+        elif self.trainer.amp_backend == AMPType.APEX and 'amp_scaling_state' in checkpoint:
+            amp.load_state_dict(checkpoint['amp_scaling_state'])
+
         # restore callback states
         self.trainer.on_load_checkpoint(checkpoint)
 
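The amp-scaling restore that used to sit in restore now lives here in restore_training_state, so it also runs for HPC loads. For context, a generic illustration of the native-AMP branch follows; it uses a plain torch.cuda.amp.GradScaler and a stand-alone dictionary rather than Lightning's trainer, so read it as a sketch of the checkpoint round trip, not as Lightning's code.

import torch

# a scaler only carries meaningful state when CUDA/AMP is actually enabled
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

# saving: the scaler state goes under a dedicated key, analogous to
# 'native_amp_scaling_state' in a Lightning checkpoint
checkpoint = {'native_amp_scaling_state': scaler.state_dict()}

# restoring: guard on the key, mirroring the `in checkpoint` check in the diff above
if 'native_amp_scaling_state' in checkpoint:
    scaler.load_state_dict(checkpoint['native_amp_scaling_state'])
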
@@ -336,19 +344,13 @@ def hpc_load(self, folderpath, on_gpu):
         filepath = '{}/hpc_ckpt_{}.ckpt'.format(folderpath, self.max_ckpt_in_folder(folderpath))
 
         # load on CPU first
-        checkpoint = torch.load(filepath, map_location=lambda storage, loc: storage)
+        checkpoint = pl_load(filepath, map_location=lambda storage, loc: storage)
 
         # load model state
         model = self.trainer.get_model()
 
-        # load the state_dict on the model automatically
-        model.load_state_dict(checkpoint['state_dict'])
-
-        # restore amp scaling
-        if self.trainer.amp_backend == AMPType.NATIVE and 'native_amp_scaling_state' in checkpoint:
-            self.trainer.scaler.load_state_dict(checkpoint['native_amp_scaling_state'])
-        elif self.trainer.amp_backend == AMPType.APEX and 'amp_scaling_state' in checkpoint:
-            amp.load_state_dict(checkpoint['amp_scaling_state'])
+        # restore states from 'PyTorch-Lightning checkpoint' dictionary object
+        self.restore_model_state(model, checkpoint)
 
         if self.trainer.root_gpu is not None:
             model.cuda(self.trainer.root_gpu)
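
Two details in this hunk are worth spelling out: the checkpoint is now read with pl_load (Lightning's cloud-aware wrapper around torch.load), and every tensor is mapped to CPU before the model is moved to the GPU. A plain-PyTorch sketch of that load-on-CPU-first pattern follows, with a throwaway nn.Linear standing in for the LightningModule.

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
torch.save({'state_dict': model.state_dict()}, 'hpc_ckpt_1.ckpt')

# map every storage to CPU regardless of the device it was saved from,
# so the load never requires the original GPU to be present
checkpoint = torch.load('hpc_ckpt_1.ckpt', map_location=lambda storage, loc: storage)

model.load_state_dict(checkpoint['state_dict'])

# only after the state is loaded is the model moved onto the target GPU, if any
if torch.cuda.is_available():
    model = model.cuda()
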
