Load ckpt path when model provided in validate/test/predict #8352

Merged (40 commits, Jul 28, 2021)

Changes from 7 commits

Commits (40)
03a8769 - Change trainer loading behaviour for validate/test/predict (Jul 9, 2021)
a943e33 - Fix (Jul 9, 2021)
40a3446 - Fix/add tests (Jul 9, 2021)
8c24ffd - remove (Jul 9, 2021)
1879be7 - Cleanups (Jul 12, 2021)
3162ff7 - Space (Jul 12, 2021)
6dd61d6 - cleanups (Jul 12, 2021)
5772e17 - Merge branch 'master' into feat/ckpt_load (Jul 12, 2021)
b072868 - Add CHANGELOG.md (Jul 12, 2021)
de2738d - Merge branch 'master' into feat/ckpt_load (Jul 12, 2021)
f2ee8b5 - Move after setup (Jul 12, 2021)
8659426 - Cleanups on logic (Jul 12, 2021)
84d20f5 - [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 12, 2021)
9e367fd - Remve (Jul 12, 2021)
b8ffc39 - fix test (Jul 12, 2021)
b02f35b - feedback (Jul 12, 2021)
dbb03af - [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 12, 2021)
1c7b9a1 - Update pytorch_lightning/trainer/properties.py (Jul 12, 2021)
444fb55 - Feedback (Jul 12, 2021)
4632bba - Same fix (Jul 12, 2021)
e92b757 - Same fix (Jul 12, 2021)
66bea8e - Add test for behaviour, modify based on feedback (Jul 12, 2021)
0139a19 - [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 12, 2021)
d48d916 - Wording (Jul 12, 2021)
100d73b - Apply suggestions from code review (Jul 12, 2021)
f3f92a5 - Cleanup docs (Jul 12, 2021)
2849d0b - Update pytorch_lightning/trainer/trainer.py (Jul 12, 2021)
f53c896 - feedback (Jul 12, 2021)
ebc713b - Fixes to test API (Jul 12, 2021)
76e22c2 - Add carlos description (Jul 12, 2021)
9c2a0de - Move logic further (Jul 13, 2021)
e5c104c - Merge branch 'master' into feat/ckpt_load (Jul 13, 2021)
be07eec - Move checkpoint connector logic (Jul 13, 2021)
a9538d6 - revert (Jul 13, 2021)
e5a0dba - Remove 4 as this is a dupe now (Jul 27, 2021)
5136b61 - Merge branch 'master' into feat/ckpt_load (Jul 27, 2021)
9867b38 - fix (Jul 27, 2021)
3c16048 - set to best (Jul 27, 2021)
b2ba6db - Fix location (Jul 28, 2021)
33f8917 - Merge branch 'master' into feat/ckpt_load (Jul 28, 2021)
51 changes: 28 additions & 23 deletions pytorch_lightning/trainer/trainer.py
@@ -544,7 +544,7 @@ def validate(

             ckpt_path: Either ``best`` or path to the checkpoint you wish to validate.
                 If ``None``, use the current weights of the model.
-                When the model is given as argument, this parameter will not apply.
+                When the model is given as argument, we load the ckpt path.

             verbose: If True, prints the validation results.
@@ -579,7 +579,6 @@ def validate(
         if dataloaders is not None and datamodule:
             raise MisconfigurationException('You cannot pass both `trainer.validate(dataloaders=..., datamodule=...)`')

-        model_provided = model is not None
         model = model or self.lightning_module
         if model is None:
             raise MisconfigurationException(
@@ -589,8 +588,7 @@ def validate(
         # links data to the trainer
         self.data_connector.attach_data(model, val_dataloaders=dataloaders, datamodule=datamodule)

-        if not model_provided:
-            self.validated_ckpt_path = self.__load_ckpt_weights(ckpt_path)
+        self.validated_ckpt_path = self.__set_ckpt_path(ckpt_path, model_provided=model is not None)

         # run validate
         results = self._run(model)
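For orientation, here is a minimal sketch of the call pattern these hunks change. The module, dataloader, and checkpoint path below are illustrative stand-ins (not part of the PR), and the file at `ckpt_path` is assumed to exist from an earlier `fit` run:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    """Tiny stand-in module, just enough to run a validation loop."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        self.log("val_loss", torch.nn.functional.mse_loss(self.layer(x), y))


dl = DataLoader(TensorDataset(torch.randn(8, 4), torch.randn(8, 1)), batch_size=4)
trainer = pl.Trainer()

# Before this PR, passing `model` meant `ckpt_path` was silently ignored.
# After it, the weights at `ckpt_path` are loaded into the given model;
# the same holds for `trainer.test` and `trainer.predict`.
trainer.validate(LitModel(), dataloaders=dl, ckpt_path="path/to/checkpoint.ckpt")
```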
@@ -621,7 +619,7 @@ def test(

             ckpt_path: Either ``best`` or path to the checkpoint you wish to test.
                 If ``None``, use the current weights of the model.
-                When the model is given as argument, this parameter will not apply.
+                When the model is given as argument, we load the ckpt path.

             verbose: If True, prints the test results.
@@ -654,7 +652,6 @@ def test(
         if dataloaders is not None and datamodule:
             raise MisconfigurationException('You cannot pass both `trainer.test(dataloaders=..., datamodule=...)`')

-        model_provided = model is not None
         model = model or self.lightning_module
         if model is None:
             raise MisconfigurationException(
@@ -664,8 +661,7 @@ def test(
         # links data to the trainer
         self.data_connector.attach_data(model, test_dataloaders=dataloaders, datamodule=datamodule)

-        if not model_provided:
-            self.tested_ckpt_path = self.__load_ckpt_weights(ckpt_path)
+        self.tested_ckpt_path = self.__set_ckpt_path(ckpt_path, model_provided=model is not None)

         # run test
         results = self._run(model)
@@ -699,9 +695,9 @@ def predict(
             return_predictions: Whether to return predictions.
                 ``True`` by default except when an accelerator that spawns processes is used (not supported).

-            ckpt_path: Either ``best`` or path to the checkpoint you wish to use to predict.
+            ckpt_path: Either ``best`` or path to the checkpoint you wish to predict.
                 If ``None``, use the current weights of the model.
-                When the model is given as argument, this parameter will not apply.
+                When the model is given as argument, we load the ckpt path.

         Returns:
             Returns a list of dictionaries, one for each provided dataloader containing their respective predictions.
@@ -725,7 +721,6 @@ def predict(
         if dataloaders is not None and datamodule:
             raise MisconfigurationException('You cannot pass both `trainer.predict(dataloaders=..., datamodule=...)`')

-        model_provided = model is not None
         model = model or self.lightning_module
         if model is None:
             raise MisconfigurationException(
@@ -735,8 +730,7 @@ def predict(
         # links data to the trainer
         self.data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule)

-        if not model_provided:
-            self.predicted_ckpt_path = self.__load_ckpt_weights(ckpt_path)
+        self.predicted_ckpt_path = self.__set_ckpt_path(ckpt_path, model_provided=model is not None)

         results = self._run(model)

@@ -807,6 +801,15 @@ def tune(

         return result

+    @property
+    def ckpt_path(self) -> Optional[str]:
+        if self.state.fn == TrainerFn.VALIDATING:
+            return self.validated_ckpt_path
+        if self.state.fn == TrainerFn.TESTING:
+            return self.tested_ckpt_path
+        if self.state.fn == TrainerFn.PREDICTING:
+            return self.predicted_ckpt_path
+
     def _run(self, model: 'pl.LightningModule') -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
         # clean hparams
         if hasattr(model, "hparams"):
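The new property is what `_run` consults in the next hunk. As a standalone sketch of the dispatch pattern it uses, with an illustrative enum and class rather than the Lightning source:

```python
from enum import Enum
from typing import Optional


class Stage(Enum):  # illustrative stand-in for pytorch_lightning.trainer.states.TrainerFn
    VALIDATING = "validate"
    TESTING = "test"
    PREDICTING = "predict"


class EvalPaths:
    """Toy mirror of the property: one accessor, three stage-specific attributes."""

    def __init__(self, fn: Stage, path: Optional[str]) -> None:
        self.fn = fn
        self.validated_ckpt_path = path if fn is Stage.VALIDATING else None
        self.tested_ckpt_path = path if fn is Stage.TESTING else None
        self.predicted_ckpt_path = path if fn is Stage.PREDICTING else None

    @property
    def ckpt_path(self) -> Optional[str]:
        if self.fn is Stage.VALIDATING:
            return self.validated_ckpt_path
        if self.fn is Stage.TESTING:
            return self.tested_ckpt_path
        return self.predicted_ckpt_path


assert EvalPaths(Stage.TESTING, "best.ckpt").ckpt_path == "best.ckpt"
```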
@@ -835,6 +838,15 @@ def _run(self, model: 'pl.LightningModule') -> Optional[Union[_EVALUATE_OUTPUT,
         # restore callback states
         self.checkpoint_connector.restore_callbacks()

+        if self.ckpt_path:
+            # only one process running at this point for TPUs, as spawn isn't triggered yet
+            # todo: move this logic internally within the barrier.
+            if not self._device_type == DeviceType.TPU:
+                self.training_type_plugin.barrier()
+
+            rank_zero_info(f"Loading checkpoint from {self.ckpt_path}")
+            self.checkpoint_connector.restore_model_weights(self.ckpt_path)
+
         self._call_configure_sharded_model(model)  # allow user to setup in model sharded environment
         self.accelerator.setup(self, model)  # note: this sets up self.lightning_module
@@ -1059,13 +1071,13 @@ def _run_sanity_check(self, ref_model):
         # restore the previous stage when the sanity check if finished
         self.state.stage = stage

-    def __load_ckpt_weights(self, ckpt_path: Optional[str]) -> Optional[str]:
-        if ckpt_path is None:
+    def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool) -> Optional[str]:
+        if model_provided and ckpt_path is None:
             return

         fn = self.state.fn.value

-        if ckpt_path == 'best':
+        if model_provided and ckpt_path == 'best':
             # if user requests the best checkpoint but we don't have it, error
             if not self.checkpoint_callback.best_model_path:
                 if self.fast_dev_run:
@@ -1084,13 +1096,6 @@ def __load_ckpt_weights(self, ckpt_path: Optional[str]) -> Optional[str]:
                     f'`.{fn}()` found no path for the best weights: "{ckpt_path}". Please'
                     f' specify a path for a checkpoint `.{fn}(ckpt_path=PATH)`'
                 )
-
-        # only one process running at this point for TPUs, as spawn isn't triggered yet
-        # todo: move this logic internally within the barrier.
-        if not self._device_type == DeviceType.TPU:
-            self.training_type_plugin.barrier()
-
-        self.checkpoint_connector.restore_model_weights(ckpt_path)
         return ckpt_path
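Taken together with the `_run` hunk above, the resolution rules at this commit can be paraphrased in a short sketch. This is a simplified mirror, not the exact source; the branch taken when no model is passed sits in the collapsed region of this diff and is omitted here:

```python
from typing import Optional


def resolve_ckpt_path(ckpt_path: Optional[str], model_provided: bool,
                      best_model_path: Optional[str]) -> Optional[str]:
    """Sketch of `__set_ckpt_path`; the actual weight loading now happens in `_run`."""
    if model_provided and ckpt_path is None:
        # Evaluate the passed model's current weights; nothing is loaded.
        return None
    if model_provided and ckpt_path == "best":
        # Resolved via the checkpoint callback; the real method errors if unset.
        return best_model_path
    # An explicit path is recorded as-is and restored later in `_run`.
    return ckpt_path
```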

     def _call_setup_hook(self, model: 'pl.LightningModule') -> None:
3 changes: 3 additions & 0 deletions tests/trainer/test_trainer.py
@@ -701,6 +701,9 @@ def predict_step(self, batch, *_):
     trainer_fn(ckpt_path=ckpt_path)
     assert getattr(trainer, path_attr) == ckpt_path

+    trainer_fn(model, ckpt_path=ckpt_path)
+    assert getattr(trainer, path_attr) == ckpt_path
+


def test_disabled_training(tmpdir):
"""Verify that `limit_train_batches=0` disables the training loop unless `fast_dev_run=True`."""