Load ckpt path when model provided in validate/test/predict (#8352)
* Change trainer loading behaviour for validate/test/predict

* Fix

* Fix/add tests

* remove

* Cleanups

* Space

* cleanups

* Add CHANGELOG.md

* Move after setup

* Cleanups on logic

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove

* fix test

* feedback

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update pytorch_lightning/trainer/properties.py

Co-authored-by: Carlos Mocholí <[email protected]>

* Feedback

* Same fix

* Same fix

* Add test for behaviour, modify based on feedback

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Wording

* Apply suggestions from code review

Co-authored-by: Kaushik B <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>

* Cleanup docs

* Update pytorch_lightning/trainer/trainer.py

Co-authored-by: Kaushik B <[email protected]>

* feedback

* Fixes to test API

* Add carlos description

* Move logic further

* Move checkpoint connector logic

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Kaushik B <[email protected]>
4 people authored Jul 28, 2021
1 parent b256d6a commit aadd2a9
Showing 11 changed files with 93 additions and 54 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
@@ -28,7 +28,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Replace `iteration_count` and other index attributes in the loops with progress dataclasses ([#8477](https://github.com/PyTorchLightning/pytorch-lightning/pull/8477))


- -
+ - Load ckpt path when model provided in validate/test/predict ([#8352](https://github.com/PyTorchLightning/pytorch-lightning/pull/8352))



-
@@ -164,6 +165,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added private `prevent_trainer_and_dataloaders_deepcopy` context manager on the `LightningModule` ([#8472](https://github.com/PyTorchLightning/pytorch-lightning/pull/8472))
- Added support for providing callables to the Lightning CLI instead of types ([#8400](https://github.com/PyTorchLightning/pytorch-lightning/pull/8400))


### Changed

- Decoupled device parsing logic from Accelerator connector to Trainer ([#8180](https://github.com/PyTorchLightning/pytorch-lightning/pull/8180))
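In practice, the changelog entry boils down to the following call patterns (a minimal sketch, assuming an already-constructed `trainer` and `model`; the comments describe the behaviour after this PR):

    trainer.fit(model)

    # model passed, ckpt_path left as None: evaluate the in-memory weights
    trainer.test(model)

    # no model passed: fall back to the best checkpoint of the previous fit
    # (a warning now suggests passing ckpt_path='best' explicitly)
    trainer.test()

    # an explicit checkpoint is now honoured even when a model is passed
    trainer.test(model, ckpt_path='/path/to/my_checkpoint.ckpt')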
9 changes: 3 additions & 6 deletions docs/source/common/test_set.rst
@@ -20,15 +20,12 @@ To run the test set after training completes, use this method.
trainer.fit(model)
# (1) load the best checkpoint automatically (lightning tracks this for you)
- trainer.test()
+ trainer.test(ckpt_path='best')
- # (2) don't load a checkpoint, instead use the model with the latest weights
- trainer.test(ckpt_path=None)
- # (3) test using a specific checkpoint
+ # (2) test using a specific checkpoint
trainer.test(ckpt_path='/path/to/my_checkpoint.ckpt')
- # (4) test with an explicit model (will use this model and not load a checkpoint)
+ # (3) test with an explicit model (will use this model and not load a checkpoint)
trainer.test(model)
----------
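The same three patterns carry over to ``.validate()`` and ``.predict()``. For instance (a sketch; ``predict_dl`` is a hypothetical dataloader, not part of the docs above):

    trainer.validate(ckpt_path='best')
    trainer.predict(model, dataloaders=predict_dl)  # uses the passed model's current weights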
14 changes: 14 additions & 0 deletions pytorch_lightning/trainer/properties.py
@@ -69,6 +69,11 @@ class TrainerProperties(ABC):
logger: LightningLoggerBase
logger_connector: LoggerConnector
state: TrainerState

+ # .validate() and .test() set this when they load a checkpoint
+ validated_ckpt_path: Optional[str] = None
+ tested_ckpt_path: Optional[str] = None
+ predicted_ckpt_path: Optional[str] = None
"""
Accelerator properties
"""
@@ -614,6 +619,15 @@ def _active_loop(self) -> Optional[Union[FitLoop, EvaluationLoop, PredictionLoop
if self.predicting:
return self.predict_loop

+ @property
+ def _ckpt_path(self) -> Optional[str]:
+     if self.state.fn == TrainerFn.VALIDATING:
+         return self.validated_ckpt_path
+     if self.state.fn == TrainerFn.TESTING:
+         return self.tested_ckpt_path
+     if self.state.fn == TrainerFn.PREDICTING:
+         return self.predicted_ckpt_path

"""
Logging properties
"""
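The dispatch in ``_ckpt_path`` reads as a small lookup table keyed on the running ``TrainerFn``. A standalone sketch (illustrative only; assumes ``TrainerFn`` from ``pytorch_lightning.trainer.states``, as used in the diff):

    from typing import Optional

    from pytorch_lightning.trainer.states import TrainerFn

    def resolve_ckpt_path(trainer) -> Optional[str]:
        # mirrors TrainerProperties._ckpt_path: return the path recorded for the
        # Trainer function currently running; None while fitting/tuning
        return {
            TrainerFn.VALIDATING: trainer.validated_ckpt_path,
            TrainerFn.TESTING: trainer.tested_ckpt_path,
            TrainerFn.PREDICTING: trainer.predicted_ckpt_path,
        }.get(trainer.state.fn)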
71 changes: 40 additions & 31 deletions pytorch_lightning/trainer/trainer.py
@@ -489,15 +489,10 @@ def _setup_on_init(self, num_sanity_val_steps: int) -> None:
self.test_dataloaders = None
self.val_dataloaders = None

- # .validate() and .test() set this when they load a checkpoint
- self.validated_ckpt_path = None
- self.tested_ckpt_path = None

# when true, print evaluation results in .validate() and .test()
self.verbose_evaluate = True

self.num_predict_batches = []
- self.predicted_ckpt_path = None

def fit(
self,
@@ -559,7 +554,7 @@ def validate(
self,
model: Optional["pl.LightningModule"] = None,
dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
- ckpt_path: Optional[str] = "best",
+ ckpt_path: Optional[str] = None,
verbose: bool = True,
datamodule: Optional[LightningDataModule] = None,
val_dataloaders=None, # noqa TODO: remove with 1.6
@@ -574,8 +569,8 @@
or a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying validation samples.
ckpt_path: Either ``best`` or path to the checkpoint you wish to validate.
- If ``None``, use the current weights of the model.
- When the model is given as argument, this parameter will not apply.
+ If ``None`` and the model instance was passed, use the current weights.
+ Otherwise, the best model from the previous ``trainer.fit`` call will be loaded.
verbose: If True, prints the validation results.
@@ -621,8 +616,9 @@
# links data to the trainer
self.data_connector.attach_data(model, val_dataloaders=dataloaders, datamodule=datamodule)

- if not model_provided:
-     self.validated_ckpt_path = self.__load_ckpt_weights(ckpt_path)
+ self.validated_ckpt_path = self.__set_ckpt_path(
+     ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
+ )

# run validate
results = self._run(model)
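With this hunk, ``.validate()`` records a checkpoint path in every case rather than only when no model was passed. A usage sketch (hypothetical ``model`` and ``dm`` objects):

    trainer.validate(model, datamodule=dm)                    # validated_ckpt_path stays None
    trainer.validate(datamodule=dm)                           # warns, resolves to the best checkpoint
    trainer.validate(model, ckpt_path='best', datamodule=dm)  # loads 'best' even though a model was passed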
@@ -636,7 +632,7 @@ def test(
self,
model: Optional["pl.LightningModule"] = None,
dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
- ckpt_path: Optional[str] = "best",
+ ckpt_path: Optional[str] = None,
verbose: bool = True,
datamodule: Optional[LightningDataModule] = None,
test_dataloaders=None, # noqa TODO: remove with 1.6
@@ -652,8 +648,8 @@
or a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying test samples.
ckpt_path: Either ``best`` or path to the checkpoint you wish to test.
- If ``None``, use the current weights of the model.
- When the model is given as argument, this parameter will not apply.
+ If ``None`` and the model instance was passed, use the current weights.
+ Otherwise, the best model from the previous ``trainer.fit`` call will be loaded.
verbose: If True, prints the test results.
@@ -699,8 +695,9 @@
# links data to the trainer
self.data_connector.attach_data(model, test_dataloaders=dataloaders, datamodule=datamodule)

- if not model_provided:
-     self.tested_ckpt_path = self.__load_ckpt_weights(ckpt_path)
+ self.tested_ckpt_path = self.__set_ckpt_path(
+     ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
+ )

# run test
results = self._run(model)
@@ -716,7 +713,7 @@ def predict(
dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
datamodule: Optional[LightningDataModule] = None,
return_predictions: Optional[bool] = None,
- ckpt_path: Optional[str] = "best",
+ ckpt_path: Optional[str] = None,
) -> Optional[_PREDICT_OUTPUT]:
r"""
@@ -734,9 +731,9 @@
return_predictions: Whether to return predictions.
``True`` by default except when an accelerator that spawns processes is used (not supported).
- ckpt_path: Either ``best`` or path to the checkpoint you wish to use to predict.
- If ``None``, use the current weights of the model.
- When the model is given as argument, this parameter will not apply.
+ ckpt_path: Either ``best`` or path to the checkpoint you wish to predict.
+ If ``None`` and the model instance was passed, use the current weights.
+ Otherwise, the best model from the previous ``trainer.fit`` call will be loaded.
Returns:
Returns a list of dictionaries, one for each provided dataloader containing their respective predictions.
@@ -770,8 +767,9 @@ def predict(
# links data to the trainer
self.data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule)

- if not model_provided:
-     self.predicted_ckpt_path = self.__load_ckpt_weights(ckpt_path)
+ self.predicted_ckpt_path = self.__set_ckpt_path(
+     ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
+ )

results = self._run(model)
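``.predict()`` follows the same resolution; for example (a sketch with a hypothetical ``dm`` datamodule):

    preds = trainer.predict(model, datamodule=dm)             # current in-memory weights
    preds = trainer.predict(ckpt_path='best', datamodule=dm)  # weights restored from the best checkpoint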

@@ -856,6 +854,15 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,
self.data_connector.prepare_data(model)
self.callback_connector._attach_model_callbacks(model, self)

+ if self._ckpt_path:
+     # only one process running at this point for TPUs, as spawn isn't triggered yet
+     # todo: move this logic internally within the barrier.
+     if not self._device_type == DeviceType.TPU:
+         self.training_type_plugin.barrier()
+
+     rank_zero_info(f"Loading checkpoint from {self._ckpt_path}")
+     self.checkpoint_connector.restore_model_weights(self._ckpt_path)

# ----------------------------
# SET UP TRAINING
# ----------------------------
@@ -910,7 +917,6 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,

# plugin will setup fitting (e.g. ddp will launch child processes)
self._pre_dispatch()
-
# restore optimizers, etc.
self.checkpoint_connector.restore_training_state()

@@ -1126,12 +1132,22 @@ def _run_sanity_check(self, ref_model):
# restore the previous stage when the sanity check if finished
self.state.stage = stage

- def __load_ckpt_weights(self, ckpt_path: Optional[str]) -> Optional[str]:
-     if ckpt_path is None:
+ def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_connected: bool) -> Optional[str]:
+     if model_provided and ckpt_path is None:
# use passed model to function without loading weights
return

fn = self.state.fn.value

+ if model_connected and ckpt_path is None:
+     rank_zero_warn(
+         f"`.{fn}(ckpt_path=None)` was called without a model. "
+         "The best model of the previous `fit` call will be used. "
+         f"You can pass `{fn}(ckpt_path='best')` to avoid this warning "
+         "or `ckpt_path=trainer.model_checkpoint.last_model_path` to use the last model."
+     )
+     ckpt_path = "best"

if ckpt_path == "best":
# if user requests the best checkpoint but we don't have it, error
if not self.checkpoint_callback.best_model_path:
@@ -1151,13 +1167,6 @@ def __load_ckpt_weights(self, ckpt_path: Optional[str]) -> Optional[str]:
f'`.{fn}()` found no path for the best weights: "{ckpt_path}". Please'
f" specify a path for a checkpoint `.{fn}(ckpt_path=PATH)`"
)

- # only one process running at this point for TPUs, as spawn isn't triggered yet
- # todo: move this logic internally within the barrier.
- if not self._device_type == DeviceType.TPU:
-     self.training_type_plugin.barrier()
-
- self.checkpoint_connector.restore_model_weights(ckpt_path)
return ckpt_path

def _call_setup_hook(self, model: "pl.LightningModule") -> None:
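Taken together, ``__set_ckpt_path`` plus the restore block moved into ``_run`` implement the resolution order below; a pure-function sketch of those rules (illustrative only, with a plain ``ValueError`` standing in for the ``MisconfigurationException`` the Trainer raises):

    from typing import Optional

    def pick_ckpt_path(
        ckpt_path: Optional[str], model_provided: bool, model_connected: bool, best_model_path: str
    ) -> Optional[str]:
        if model_provided and ckpt_path is None:
            return None  # use the passed model's current weights, no checkpoint load
        if model_connected and ckpt_path is None:
            ckpt_path = "best"  # nothing requested: warn and fall back to the best checkpoint
        if ckpt_path == "best":
            if not best_model_path:
                raise ValueError("no best checkpoint available, pass an explicit ckpt_path")
            ckpt_path = best_model_path
        if not ckpt_path:
            raise ValueError("found no path for the weights, pass an explicit ckpt_path")
        return ckpt_path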
2 changes: 1 addition & 1 deletion tests/callbacks/test_callbacks.py
@@ -98,6 +98,6 @@ def configure_callbacks(self):
callbacks_after = trainer.callbacks.copy()
assert callbacks_after == callbacks_after_fit

- trainer_fn(ckpt_path=None)
+ trainer_fn(model)
callbacks_after = trainer.callbacks.copy()
assert callbacks_after == callbacks_after_fit
2 changes: 1 addition & 1 deletion tests/models/test_hooks.py
@@ -796,7 +796,7 @@ def predict_dataloader(self):

trainer.fit(model)
assert trainer.state.finished, f"Training failed with {trainer.state}"
- trainer.test(ckpt_path=None)
+ trainer.test(model)

preds = trainer.predict(model)
assert len(preds) == 2
2 changes: 1 addition & 1 deletion tests/models/test_restore.py
@@ -360,7 +360,7 @@ def test_load_model_from_checkpoint(tmpdir, model_template):
# fit model
trainer = Trainer(**trainer_options)
trainer.fit(model)
- trainer.test(ckpt_path=None)
+ trainer.test(model)

# correct result and ok accuracy
assert trainer.state.finished, f"Training failed with {trainer.state}"
4 changes: 2 additions & 2 deletions tests/trainer/flags/test_fast_dev_run.py
@@ -107,7 +107,7 @@ def _make_fast_dev_run_assertions(trainer, model):
train_val_step_model = FastDevRunModel()
trainer = Trainer(**trainer_config)
trainer.fit(train_val_step_model)
- trainer.test(ckpt_path=None)
+ trainer.test(train_val_step_model)

assert trainer.state.finished, f"Training failed with {trainer.state}"
_make_fast_dev_run_assertions(trainer, train_val_step_model)
@@ -120,7 +120,7 @@ def _make_fast_dev_run_assertions(trainer, model):

trainer = Trainer(**trainer_config)
trainer.fit(train_step_only_model)
- trainer.test(ckpt_path=None)
+ trainer.test(train_step_only_model)

assert trainer.state.finished, f"Training failed with {trainer.state}"
_make_fast_dev_run_assertions(trainer, train_step_only_model)
2 changes: 1 addition & 1 deletion tests/trainer/logging_/test_logger_connector.py
@@ -201,7 +201,7 @@ def test_dataloader(self):
default_root_dir=tmpdir, accelerator="dp", gpus=2, limit_train_batches=2, limit_val_batches=2, max_epochs=1
)
trainer.fit(model)
- trainer.test(model, ckpt_path=None)
+ trainer.test(model)


def test_can_return_tensor_with_more_than_one_element(tmpdir):
14 changes: 9 additions & 5 deletions tests/trainer/test_dataloaders.py
@@ -439,7 +439,7 @@ def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, lim
assert trainer.num_training_batches == expected_train_batches
assert trainer.num_val_batches == expected_val_batches

- trainer.test(ckpt_path=None)
+ trainer.test(model)
expected_test_batches = [int(len(dataloader) * limit_test_batches) for dataloader in trainer.test_dataloaders]
assert trainer.num_test_batches == expected_test_batches

@@ -474,7 +474,7 @@ def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_v
# -------------------------------------------
assert trainer.num_training_batches == limit_train_batches
assert trainer.num_val_batches == [limit_val_batches] * len(trainer.val_dataloaders)
- trainer.test(ckpt_path=None)
+ trainer.test(model)

# when the limit is greater than the number of test batches it should be the num in loaders
test_dataloader_lengths = [len(x) for x in model.test_dataloader()]
@@ -549,7 +549,7 @@ def test_dataloaders_with_fast_dev_run(tmpdir, fast_dev_run):
assert trainer.num_training_batches == fast_dev_run
assert trainer.num_val_batches == [fast_dev_run] * len(trainer.val_dataloaders)

- trainer.test(ckpt_path=None)
+ trainer.test(model)
assert trainer.num_test_batches == [fast_dev_run] * len(trainer.test_dataloaders)

# verify sanity check batches match as expected
@@ -685,6 +685,8 @@ def test_warning_with_few_workers(_, tmpdir, ckpt_path, stage):
match=f'The dataloader, {stage} dataloader{" 0" if stage != "train" else ""}, does not have many workers',
):
if stage == "test":
+ if ckpt_path in ("specific", "best"):
+     trainer.fit(model, train_dataloader=train_dl, val_dataloaders=val_dl)
ckpt_path = trainer.checkpoint_callback.best_model_path if ckpt_path == "specific" else ckpt_path
trainer.test(model, test_dataloaders=train_dl, ckpt_path=ckpt_path)
else:
@@ -722,6 +724,8 @@ def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage):
match=f'The dataloader, {stage} dataloader{" 0" if stage != "train" else ""}, does not have many workers',
):
if stage == "test":
+ if ckpt_path in ("specific", "best"):
+     trainer.fit(model, train_dataloader=train_multi_dl, val_dataloaders=val_multi_dl)
ckpt_path = trainer.checkpoint_callback.best_model_path if ckpt_path == "specific" else ckpt_path
trainer.test(model, test_dataloaders=test_multi_dl, ckpt_path=ckpt_path)
else:
@@ -950,7 +954,7 @@ def test_dataloader_distributed_sampler(tmpdir):
callbacks=[DistribSamplerCallback(expected_seeds=(123, 123, 123))],
)
trainer.fit(model)
- trainer.test(ckpt_path=None)
+ trainer.test(model)


class ModelWithDataLoaderDistributedSampler(EvalModelTemplate):
@@ -1444,7 +1448,7 @@ def predict_dataloader(self):

trainer.fit(model)
assert trainer.state.finished, f"Training failed with {trainer.state}"
- trainer.test(ckpt_path=None)
+ trainer.test(model)

preds = trainer.predict(model)
assert len(preds) == 2
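The ``trainer.fit(...)`` calls added to the two worker-warning tests are needed because ``ckpt_path='best'`` (or a concrete path) now triggers a real checkpoint load even when a model instance is passed, so a checkpoint must exist first. A sketch of the failure mode this avoids (hypothetical ``tmpdir`` and ``model``):

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
    # without a prior fit there is no best checkpoint, so this raises a
    # MisconfigurationException instead of silently using the current weights
    trainer.test(model, ckpt_path='best')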