
Load ckpt path when model provided in validate/test/predict #8352

Merged Jul 28, 2021 (40 commits)
Changes from 29 commits

Commits
03a8769
Change trainer loading behaviour for validate/test/predict
Jul 9, 2021
a943e33
Fix
Jul 9, 2021
40a3446
Fix/add tests
Jul 9, 2021
8c24ffd
remove
Jul 9, 2021
1879be7
Cleanups
Jul 12, 2021
3162ff7
Space
Jul 12, 2021
6dd61d6
cleanups
Jul 12, 2021
5772e17
Merge branch 'master' into feat/ckpt_load
Jul 12, 2021
b072868
Add CHANGELOG.md
Jul 12, 2021
de2738d
Merge branch 'master' into feat/ckpt_load
Jul 12, 2021
f2ee8b5
Move after setup
Jul 12, 2021
8659426
Cleanups on logic
Jul 12, 2021
84d20f5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 12, 2021
9e367fd
Remve
Jul 12, 2021
b8ffc39
fix test
Jul 12, 2021
b02f35b
feedback
Jul 12, 2021
dbb03af
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 12, 2021
1c7b9a1
Update pytorch_lightning/trainer/properties.py
Jul 12, 2021
444fb55
Feedback
Jul 12, 2021
4632bba
Same fix
Jul 12, 2021
e92b757
Same fix
Jul 12, 2021
66bea8e
Add test for behaviour, modify based on feedback
Jul 12, 2021
0139a19
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 12, 2021
d48d916
Wording
Jul 12, 2021
100d73b
Apply suggestions from code review
Jul 12, 2021
f3f92a5
Cleanup docs
Jul 12, 2021
2849d0b
Update pytorch_lightning/trainer/trainer.py
Jul 12, 2021
f53c896
feedback
Jul 12, 2021
ebc713b
Fixes to test API
Jul 12, 2021
76e22c2
Add carlos description
Jul 12, 2021
9c2a0de
Move logic further
Jul 13, 2021
e5c104c
Merge branch 'master' into feat/ckpt_load
Jul 13, 2021
be07eec
Move checkpoint connector logic
Jul 13, 2021
a9538d6
revert
Jul 13, 2021
e5a0dba
Remove 4 as this is a dupe now
Jul 27, 2021
5136b61
Merge branch 'master' into feat/ckpt_load
Jul 27, 2021
9867b38
fix
Jul 27, 2021
3c16048
set to best
Jul 27, 2021
b2ba6db
Fix location
Jul 28, 2021
33f8917
Merge branch 'master' into feat/ckpt_load
Jul 28, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -158,6 +158,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for `accelerator='cpu'|'gpu'|'tpu'|'ipu'|'auto'` ([#7808](https://github.com/PyTorchLightning/pytorch-lightning/pull/7808))


- Load ckpt path when model provided in validate/test/predict ([#8352](https://github.com/PyTorchLightning/pytorch-lightning/pull/8352))


### Changed


2 changes: 1 addition & 1 deletion docs/source/common/test_set.rst
@@ -23,7 +23,7 @@ To run the test set after training completes, use this method.
trainer.test()

# (2) don't load a checkpoint, instead use the model with the latest weights
trainer.test(ckpt_path=None)
trainer.test(model)

# (3) test using a specific checkpoint
trainer.test(ckpt_path='/path/to/my_checkpoint.ckpt')
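
The `validate` and `predict` docstrings in this diff gain the same semantics; roughly (a sketch, reusing the placeholders from above):

    trainer.validate(model)                                           # current weights of the given model
    trainer.validate(ckpt_path="best")                                # best checkpoint from the preceding fit
    trainer.predict(model, ckpt_path="/path/to/my_checkpoint.ckpt")   # load this checkpoint into the given model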
14 changes: 14 additions & 0 deletions pytorch_lightning/trainer/properties.py
@@ -68,6 +68,11 @@ class TrainerProperties(ABC):
validate_loop: EvaluationLoop
test_loop: EvaluationLoop
predict_loop: PredictionLoop

# .validate() and .test() set this when they load a checkpoint
validated_ckpt_path: Optional[str] = None
tested_ckpt_path: Optional[str] = None
predicted_ckpt_path: Optional[str] = None
"""
Accelerator properties
"""
@@ -548,6 +553,15 @@ def _active_loop(self) -> Optional[Union[FitLoop, EvaluationLoop, PredictionLoop
if self.predicting:
return self.predict_loop

@property
def _ckpt_path(self) -> Optional[str]:
if self.state.fn == TrainerFn.VALIDATING:
return self.validated_ckpt_path
if self.state.fn == TrainerFn.TESTING:
return self.tested_ckpt_path
if self.state.fn == TrainerFn.PREDICTING:
return self.predicted_ckpt_path

"""
Logging properties
"""
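
How the new per-function attributes and the `_ckpt_path` property fit together — a small sketch (assumes a prior `fit` with checkpointing enabled and reuses the placeholders from above):

    trainer.validate(ckpt_path="best")
    # the loaded path is recorded, and `_ckpt_path` resolves to it while the trainer is validating
    assert trainer.validated_ckpt_path == trainer.checkpoint_callback.best_model_path

    trainer.test(model, ckpt_path="/path/to/my_checkpoint.ckpt")  # placeholder path
    assert trainer.tested_ckpt_path == "/path/to/my_checkpoint.ckpt"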
68 changes: 39 additions & 29 deletions pytorch_lightning/trainer/trainer.py
@@ -458,15 +458,10 @@ def _setup_on_init(
self.test_dataloaders = None
self.val_dataloaders = None

# .validate() and .test() set this when they load a checkpoint
self.validated_ckpt_path = None
self.tested_ckpt_path = None

# when true, print evaluation results in .validate() and .test()
self.verbose_evaluate = True

self.num_predict_batches = []
self.predicted_ckpt_path = None

def fit(
self,
@@ -528,7 +523,7 @@ def validate(
self,
model: Optional['pl.LightningModule'] = None,
dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
ckpt_path: Optional[str] = 'best',
ckpt_path: Optional[str] = None,
verbose: bool = True,
datamodule: Optional[LightningDataModule] = None,
val_dataloaders=None, # noqa TODO: remove with 1.6
@@ -543,8 +538,8 @@
or a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying validation samples.

ckpt_path: Either ``best`` or path to the checkpoint you wish to validate.
If ``None``, use the current weights of the model.
When the model is given as argument, this parameter will not apply.
If ``None``, use the current weights of the model if provided, or the best model from ``trainer.fit``.
When the model and the ckpt path are passed as arguments, we load the ckpt path.

verbose: If True, prints the validation results.

@@ -589,8 +584,9 @@
# links data to the trainer
self.data_connector.attach_data(model, val_dataloaders=dataloaders, datamodule=datamodule)

if not model_provided:
self.validated_ckpt_path = self.__load_ckpt_weights(ckpt_path)
self.validated_ckpt_path = self.__set_ckpt_path(
ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
)

# run validate
results = self._run(model)
@@ -604,7 +600,7 @@ def test(
self,
model: Optional['pl.LightningModule'] = None,
dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
ckpt_path: Optional[str] = 'best',
ckpt_path: Optional[str] = None,
verbose: bool = True,
datamodule: Optional[LightningDataModule] = None,
test_dataloaders=None, # noqa TODO: remove with 1.6
@@ -620,8 +616,8 @@
or a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying test samples.

ckpt_path: Either ``best`` or path to the checkpoint you wish to test.
If ``None``, use the current weights of the model.
When the model is given as argument, this parameter will not apply.
If ``None``, use the current weights of the model if provided, or the best model from ``trainer.fit``.
When the model and the ckpt path are passed as arguments, we load the ckpt path.

verbose: If True, prints the test results.

@@ -664,8 +660,9 @@
# links data to the trainer
self.data_connector.attach_data(model, test_dataloaders=dataloaders, datamodule=datamodule)

if not model_provided:
self.tested_ckpt_path = self.__load_ckpt_weights(ckpt_path)
self.tested_ckpt_path = self.__set_ckpt_path(
ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
)

# run test
results = self._run(model)
@@ -681,7 +678,7 @@ def predict(
dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
datamodule: Optional[LightningDataModule] = None,
return_predictions: Optional[bool] = None,
ckpt_path: Optional[str] = 'best',
ckpt_path: Optional[str] = None,
) -> Optional[_PREDICT_OUTPUT]:
r"""

@@ -699,9 +696,9 @@
return_predictions: Whether to return predictions.
``True`` by default except when an accelerator that spawns processes is used (not supported).

ckpt_path: Either ``best`` or path to the checkpoint you wish to use to predict.
ckpt_path: Either ``best`` or path to the checkpoint you wish to predict.
If ``None``, use the current weights of the model.
When the model is given as argument, this parameter will not apply.
When the model is given as argument, we load the ckpt path.

Returns:
Returns a list of dictionaries, one for each provided dataloader containing their respective predictions.
@@ -735,8 +732,9 @@
# links data to the trainer
self.data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule)

if not model_provided:
self.predicted_ckpt_path = self.__load_ckpt_weights(ckpt_path)
self.predicted_ckpt_path = self.__set_ckpt_path(
ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
)

results = self._run(model)

@@ -838,6 +836,15 @@ def _run(self, model: 'pl.LightningModule') -> Optional[Union[_EVALUATE_OUTPUT,
self._call_configure_sharded_model(model) # allow user to setup in model sharded environment
self.accelerator.setup(self, model) # note: this sets up self.lightning_module

if self._ckpt_path:
# only one process running at this point for TPUs, as spawn isn't triggered yet
# todo: move this logic internally within the barrier.
if not self._device_type == DeviceType.TPU:
self.training_type_plugin.barrier()

rank_zero_info(f"Loading checkpoint from {self._ckpt_path}")
self.checkpoint_connector.restore_model_weights(self._ckpt_path)

# ----------------------------
# INSPECT THE CORE LOOPS
# ----------------------------
@@ -1072,12 +1079,22 @@ def _run_sanity_check(self, ref_model):
# restore the previous stage when the sanity check if finished
self.state.stage = stage

def __load_ckpt_weights(self, ckpt_path: Optional[str]) -> Optional[str]:
if ckpt_path is None:
def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_connected: bool) -> Optional[str]:
if model_provided and ckpt_path is None:
# use passed model to function without loading weights
return

fn = self.state.fn.value

if model_connected and ckpt_path is None:
rank_zero_warn(
f"`.{fn}(ckpt_path=None)` was called without a model. "
"The best model of the previous `fit` call will be used. "
f"You can pass `{fn}(ckpt_path='best')` to avoid this warning "
"or `ckpt_path=trainer.model_checkpoint.last_model_path` to use the last model."
)
ckpt_path = 'best'

if ckpt_path == 'best':
# if user requests the best checkpoint but we don't have it, error
if not self.checkpoint_callback.best_model_path:
@@ -1097,13 +1114,6 @@ def __load_ckpt_weights(self, ckpt_path: Optional[str]) -> Optional[str]:
f'`.{fn}()` found no path for the best weights: "{ckpt_path}". Please'
f' specify a path for a checkpoint `.{fn}(ckpt_path=PATH)`'
)

# only one process running at this point for TPUs, as spawn isn't triggered yet
# todo: move this logic internally within the barrier.
if not self._device_type == DeviceType.TPU:
self.training_type_plugin.barrier()

self.checkpoint_connector.restore_model_weights(ckpt_path)
return ckpt_path

def _call_setup_hook(self, model: 'pl.LightningModule') -> None:
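
Two edge cases of the `__set_ckpt_path` resolution above, sketched with assumed trainer settings (not taken verbatim from the diff):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint

    # 1) 'best' requested (explicitly or via the None fallback) but no best checkpoint was ever saved,
    #    e.g. ModelCheckpoint(save_top_k=0) -> MisconfigurationException
    trainer = Trainer(callbacks=[ModelCheckpoint(save_top_k=0)], max_epochs=1)
    trainer.fit(model)
    trainer.test(ckpt_path="best")

    # 2) no model and ckpt_path=None while a best checkpoint exists -> UserWarning
    #    ("The best model of the previous `fit` call will be used"), then that checkpoint is loaded
    trainer = Trainer(max_epochs=1)
    trainer.fit(model)
    trainer.test()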
2 changes: 1 addition & 1 deletion tests/callbacks/test_callbacks.py
@@ -109,6 +109,6 @@ def configure_callbacks(self):
callbacks_after = trainer.callbacks.copy()
assert callbacks_after == callbacks_after_fit

trainer_fn(ckpt_path=None)
trainer_fn(model)
callbacks_after = trainer.callbacks.copy()
assert callbacks_after == callbacks_after_fit
2 changes: 1 addition & 1 deletion tests/models/test_hooks.py
@@ -830,7 +830,7 @@ def predict_dataloader(self):

trainer.fit(model)
assert trainer.state.finished, f"Training failed with {trainer.state}"
trainer.test(ckpt_path=None)
trainer.test(model)

preds = trainer.predict(model)
assert len(preds) == 2
2 changes: 1 addition & 1 deletion tests/models/test_restore.py
@@ -364,7 +364,7 @@ def test_load_model_from_checkpoint(tmpdir, model_template):
# fit model
trainer = Trainer(**trainer_options)
trainer.fit(model)
trainer.test(ckpt_path=None)
trainer.test(model)

# correct result and ok accuracy
assert trainer.state.finished, f"Training failed with {trainer.state}"
4 changes: 2 additions & 2 deletions tests/trainer/flags/test_fast_dev_run.py
@@ -108,7 +108,7 @@ def _make_fast_dev_run_assertions(trainer, model):
train_val_step_model = FastDevRunModel()
trainer = Trainer(**trainer_config)
trainer.fit(train_val_step_model)
trainer.test(ckpt_path=None)
trainer.test(train_val_step_model)

assert trainer.state.finished, f"Training failed with {trainer.state}"
_make_fast_dev_run_assertions(trainer, train_val_step_model)
@@ -121,7 +121,7 @@ def _make_fast_dev_run_assertions(trainer, model):

trainer = Trainer(**trainer_config)
trainer.fit(train_step_only_model)
trainer.test(ckpt_path=None)
trainer.test(train_step_only_model)

assert trainer.state.finished, f"Training failed with {trainer.state}"
_make_fast_dev_run_assertions(trainer, train_step_only_model)
2 changes: 1 addition & 1 deletion tests/trainer/logging_/test_logger_connector.py
@@ -202,7 +202,7 @@ def test_dataloader(self):
max_epochs=1,
)
trainer.fit(model)
trainer.test(model, ckpt_path=None)
trainer.test(model)


def test_can_return_tensor_with_more_than_one_element(tmpdir):
14 changes: 9 additions & 5 deletions tests/trainer/test_dataloaders.py
@@ -470,7 +470,7 @@ def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, lim
assert trainer.num_training_batches == expected_train_batches
assert trainer.num_val_batches == expected_val_batches

trainer.test(ckpt_path=None)
trainer.test(model)
expected_test_batches = [int(len(dataloader) * limit_test_batches) for dataloader in trainer.test_dataloaders]
assert trainer.num_test_batches == expected_test_batches

@@ -507,7 +507,7 @@ def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_v
# -------------------------------------------
assert trainer.num_training_batches == limit_train_batches
assert trainer.num_val_batches == [limit_val_batches] * len(trainer.val_dataloaders)
trainer.test(ckpt_path=None)
trainer.test(model)

# when the limit is greater than the number of test batches it should be the num in loaders
test_dataloader_lengths = [len(x) for x in model.test_dataloader()]
@@ -586,7 +586,7 @@ def test_dataloaders_with_fast_dev_run(tmpdir, fast_dev_run):
assert trainer.num_training_batches == fast_dev_run
assert trainer.num_val_batches == [fast_dev_run] * len(trainer.val_dataloaders)

trainer.test(ckpt_path=None)
trainer.test(model)
assert trainer.num_test_batches == [fast_dev_run] * len(trainer.test_dataloaders)

# verify sanity check batches match as expected
@@ -740,6 +740,8 @@ def test_warning_with_few_workers(_, tmpdir, ckpt_path, stage):
match=f'The dataloader, {stage} dataloader{" 0" if stage != "train" else ""}, does not have many workers'
):
if stage == 'test':
if ckpt_path in ('specific', 'best'):
trainer.fit(model, train_dataloader=train_dl, val_dataloaders=val_dl)
ckpt_path = trainer.checkpoint_callback.best_model_path if ckpt_path == 'specific' else ckpt_path
trainer.test(model, test_dataloaders=train_dl, ckpt_path=ckpt_path)
else:
@@ -782,6 +784,8 @@ def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage):
match=f'The dataloader, {stage} dataloader{" 0" if stage != "train" else ""}, does not have many workers'
):
if stage == 'test':
if ckpt_path in ('specific', 'best'):
trainer.fit(model, train_dataloader=train_multi_dl, val_dataloaders=val_multi_dl)
ckpt_path = trainer.checkpoint_callback.best_model_path if ckpt_path == 'specific' else ckpt_path
trainer.test(model, test_dataloaders=test_multi_dl, ckpt_path=ckpt_path)
else:
@@ -1093,7 +1097,7 @@ def test_dataloader_distributed_sampler(tmpdir):
callbacks=[DistribSamplerCallback(expected_seeds=(123, 123, 123))],
)
trainer.fit(model)
trainer.test(ckpt_path=None)
trainer.test(model)


class ModelWithDataLoaderDistributedSampler(EvalModelTemplate):
@@ -1613,7 +1617,7 @@ def predict_dataloader(self):

trainer.fit(model)
assert trainer.state.finished, f"Training failed with {trainer.state}"
trainer.test(ckpt_path=None)
trainer.test(model)

preds = trainer.predict(model)
assert len(preds) == 2
23 changes: 18 additions & 5 deletions tests/trainer/test_trainer.py
@@ -680,14 +680,24 @@ def predict_step(self, batch, *_):
if save_top_k == 0:
with pytest.raises(MisconfigurationException, match=".*is not configured to save the best.*"):
trainer_fn(ckpt_path=ckpt_path)
with pytest.raises(MisconfigurationException, match=".*is not configured to save the best.*"):
trainer_fn(model, ckpt_path=ckpt_path)
else:
trainer_fn(ckpt_path=ckpt_path)
assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path

trainer_fn(model, ckpt_path=ckpt_path)
assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path
elif ckpt_path is None:
# ckpt_path is None, meaning we don't load any checkpoints and
# use the weights from the end of training
trainer_fn(ckpt_path=ckpt_path)
# ckpt_path is None, meaning we don't load any checkpoints and use the provided model
trainer_fn(model, ckpt_path=ckpt_path)
assert getattr(trainer, path_attr) is None

if save_top_k > 0:
# ckpt_path is None with no model provided means load the best weights
with pytest.warns(UserWarning, match="The best model of the previous `fit` call will be used"):
trainer_fn(ckpt_path=ckpt_path)
assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path
else:
# specific checkpoint, pick one from saved ones
if save_top_k == 0:
Expand All @@ -701,6 +711,9 @@ def predict_step(self, batch, *_):
trainer_fn(ckpt_path=ckpt_path)
assert getattr(trainer, path_attr) == ckpt_path

trainer_fn(model, ckpt_path=ckpt_path)
assert getattr(trainer, path_attr) == ckpt_path


def test_disabled_training(tmpdir):
"""Verify that `limit_train_batches=0` disables the training loop unless `fast_dev_run=True`."""
@@ -1301,9 +1314,9 @@ def setup(self, model, stage):
if stage == "fit":
trainer.fit(model)
elif stage == "validate":
trainer.validate(model, ckpt_path=None)
trainer.validate(model)
else:
trainer.test(model, ckpt_path=None)
trainer.test(model)

assert trainer.stage == stage
assert trainer.lightning_module.stage == stage