diff --git a/CHANGELOG.md b/CHANGELOG.md index 7872af715d68a..129744c10f0f1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,7 +28,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Replace `iteration_count` and other index attributes in the loops with progress dataclasses ([#8477](https://github.com/PyTorchLightning/pytorch-lightning/pull/8477)) -- +- Load ckpt path when model provided in validate/test/predict ([#8352](https://github.com/PyTorchLightning/pytorch-lightning/pull/8352))) + - @@ -164,6 +165,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added private `prevent_trainer_and_dataloaders_deepcopy` context manager on the `LightningModule` ([#8472](https://github.com/PyTorchLightning/pytorch-lightning/pull/8472)) - Added support for providing callables to the Lightning CLI instead of types ([#8400](https://github.com/PyTorchLightning/pytorch-lightning/pull/8400)) + ### Changed - Decoupled device parsing logic from Accelerator connector to Trainer ([#8180](https://github.com/PyTorchLightning/pytorch-lightning/pull/8180)) diff --git a/docs/source/common/test_set.rst b/docs/source/common/test_set.rst index 5703d71d956de..dab9f15b0b27b 100644 --- a/docs/source/common/test_set.rst +++ b/docs/source/common/test_set.rst @@ -20,15 +20,12 @@ To run the test set after training completes, use this method. trainer.fit(model) # (1) load the best checkpoint automatically (lightning tracks this for you) - trainer.test() + trainer.test(ckpt_path='best') - # (2) don't load a checkpoint, instead use the model with the latest weights - trainer.test(ckpt_path=None) - - # (3) test using a specific checkpoint + # (2) test using a specific checkpoint trainer.test(ckpt_path='/path/to/my_checkpoint.ckpt') - # (4) test with an explicit model (will use this model and not load a checkpoint) + # (3) test with an explicit model (will use this model and not load a checkpoint) trainer.test(model) ---------- diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 7d1ae185c42e6..097138d9ad508 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -69,6 +69,11 @@ class TrainerProperties(ABC): logger: LightningLoggerBase logger_connector: LoggerConnector state: TrainerState + + # .validate() and .test() set this when they load a checkpoint + validated_ckpt_path: Optional[str] = None + tested_ckpt_path: Optional[str] = None + predicted_ckpt_path: Optional[str] = None """ Accelerator properties """ @@ -614,6 +619,15 @@ def _active_loop(self) -> Optional[Union[FitLoop, EvaluationLoop, PredictionLoop if self.predicting: return self.predict_loop + @property + def _ckpt_path(self) -> Optional[str]: + if self.state.fn == TrainerFn.VALIDATING: + return self.validated_ckpt_path + if self.state.fn == TrainerFn.TESTING: + return self.tested_ckpt_path + if self.state.fn == TrainerFn.PREDICTING: + return self.predicted_ckpt_path + """ Logging properties """ diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 5f3d18ebc4e66..50ec5dec81376 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -489,15 +489,10 @@ def _setup_on_init(self, num_sanity_val_steps: int) -> None: self.test_dataloaders = None self.val_dataloaders = None - # .validate() and .test() set this when they load a checkpoint - self.validated_ckpt_path = None - self.tested_ckpt_path = None - # when true, print evaluation results in .validate() and .test() self.verbose_evaluate = True self.num_predict_batches = [] - self.predicted_ckpt_path = None def fit( self, @@ -559,7 +554,7 @@ def validate( self, model: Optional["pl.LightningModule"] = None, dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, - ckpt_path: Optional[str] = "best", + ckpt_path: Optional[str] = None, verbose: bool = True, datamodule: Optional[LightningDataModule] = None, val_dataloaders=None, # noqa TODO: remove with 1.6 @@ -574,8 +569,8 @@ def validate( or a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying validation samples. ckpt_path: Either ``best`` or path to the checkpoint you wish to validate. - If ``None``, use the current weights of the model. - When the model is given as argument, this parameter will not apply. + If ``None`` and the model instance was passed, use the current weights. + Otherwise, the best model from the previous ``trainer.fit`` call will be loaded. verbose: If True, prints the validation results. @@ -621,8 +616,9 @@ def validate( # links data to the trainer self.data_connector.attach_data(model, val_dataloaders=dataloaders, datamodule=datamodule) - if not model_provided: - self.validated_ckpt_path = self.__load_ckpt_weights(ckpt_path) + self.validated_ckpt_path = self.__set_ckpt_path( + ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None + ) # run validate results = self._run(model) @@ -636,7 +632,7 @@ def test( self, model: Optional["pl.LightningModule"] = None, dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, - ckpt_path: Optional[str] = "best", + ckpt_path: Optional[str] = None, verbose: bool = True, datamodule: Optional[LightningDataModule] = None, test_dataloaders=None, # noqa TODO: remove with 1.6 @@ -652,8 +648,8 @@ def test( or a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying test samples. ckpt_path: Either ``best`` or path to the checkpoint you wish to test. - If ``None``, use the current weights of the model. - When the model is given as argument, this parameter will not apply. + If ``None`` and the model instance was passed, use the current weights. + Otherwise, the best model from the previous ``trainer.fit`` call will be loaded. verbose: If True, prints the test results. @@ -699,8 +695,9 @@ def test( # links data to the trainer self.data_connector.attach_data(model, test_dataloaders=dataloaders, datamodule=datamodule) - if not model_provided: - self.tested_ckpt_path = self.__load_ckpt_weights(ckpt_path) + self.tested_ckpt_path = self.__set_ckpt_path( + ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None + ) # run test results = self._run(model) @@ -716,7 +713,7 @@ def predict( dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, datamodule: Optional[LightningDataModule] = None, return_predictions: Optional[bool] = None, - ckpt_path: Optional[str] = "best", + ckpt_path: Optional[str] = None, ) -> Optional[_PREDICT_OUTPUT]: r""" @@ -734,9 +731,9 @@ def predict( return_predictions: Whether to return predictions. ``True`` by default except when an accelerator that spawns processes is used (not supported). - ckpt_path: Either ``best`` or path to the checkpoint you wish to use to predict. - If ``None``, use the current weights of the model. - When the model is given as argument, this parameter will not apply. + ckpt_path: Either ``best`` or path to the checkpoint you wish to predict. + If ``None`` and the model instance was passed, use the current weights. + Otherwise, the best model from the previous ``trainer.fit`` call will be loaded. Returns: Returns a list of dictionaries, one for each provided dataloader containing their respective predictions. @@ -770,8 +767,9 @@ def predict( # links data to the trainer self.data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule) - if not model_provided: - self.predicted_ckpt_path = self.__load_ckpt_weights(ckpt_path) + self.predicted_ckpt_path = self.__set_ckpt_path( + ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None + ) results = self._run(model) @@ -856,6 +854,15 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, self.data_connector.prepare_data(model) self.callback_connector._attach_model_callbacks(model, self) + if self._ckpt_path: + # only one process running at this point for TPUs, as spawn isn't triggered yet + # todo: move this logic internally within the barrier. + if not self._device_type == DeviceType.TPU: + self.training_type_plugin.barrier() + + rank_zero_info(f"Loading checkpoint from {self._ckpt_path}") + self.checkpoint_connector.restore_model_weights(self._ckpt_path) + # ---------------------------- # SET UP TRAINING # ---------------------------- @@ -910,7 +917,6 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, # plugin will setup fitting (e.g. ddp will launch child processes) self._pre_dispatch() - # restore optimizers, etc. self.checkpoint_connector.restore_training_state() @@ -1126,12 +1132,22 @@ def _run_sanity_check(self, ref_model): # restore the previous stage when the sanity check if finished self.state.stage = stage - def __load_ckpt_weights(self, ckpt_path: Optional[str]) -> Optional[str]: - if ckpt_path is None: + def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_connected: bool) -> Optional[str]: + if model_provided and ckpt_path is None: + # use passed model to function without loading weights return fn = self.state.fn.value + if model_connected and ckpt_path is None: + rank_zero_warn( + f"`.{fn}(ckpt_path=None)` was called without a model. " + "The best model of the previous `fit` call will be used. " + f"You can pass `{fn}(ckpt_path='best')` to avoid this warning " + "or `ckpt_path=trainer.model_checkpoint.last_model_path` to use the last model." + ) + ckpt_path = "best" + if ckpt_path == "best": # if user requests the best checkpoint but we don't have it, error if not self.checkpoint_callback.best_model_path: @@ -1151,13 +1167,6 @@ def __load_ckpt_weights(self, ckpt_path: Optional[str]) -> Optional[str]: f'`.{fn}()` found no path for the best weights: "{ckpt_path}". Please' f" specify a path for a checkpoint `.{fn}(ckpt_path=PATH)`" ) - - # only one process running at this point for TPUs, as spawn isn't triggered yet - # todo: move this logic internally within the barrier. - if not self._device_type == DeviceType.TPU: - self.training_type_plugin.barrier() - - self.checkpoint_connector.restore_model_weights(ckpt_path) return ckpt_path def _call_setup_hook(self, model: "pl.LightningModule") -> None: diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index 75f0f64b7d7f5..dedc74f021f81 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -98,6 +98,6 @@ def configure_callbacks(self): callbacks_after = trainer.callbacks.copy() assert callbacks_after == callbacks_after_fit - trainer_fn(ckpt_path=None) + trainer_fn(model) callbacks_after = trainer.callbacks.copy() assert callbacks_after == callbacks_after_fit diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 608639d2459df..032db105c8fc4 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -796,7 +796,7 @@ def predict_dataloader(self): trainer.fit(model) assert trainer.state.finished, f"Training failed with {trainer.state}" - trainer.test(ckpt_path=None) + trainer.test(model) preds = trainer.predict(model) assert len(preds) == 2 diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index aaa374836def1..d1d870fa30116 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -360,7 +360,7 @@ def test_load_model_from_checkpoint(tmpdir, model_template): # fit model trainer = Trainer(**trainer_options) trainer.fit(model) - trainer.test(ckpt_path=None) + trainer.test(model) # correct result and ok accuracy assert trainer.state.finished, f"Training failed with {trainer.state}" diff --git a/tests/trainer/flags/test_fast_dev_run.py b/tests/trainer/flags/test_fast_dev_run.py index e9d427361ed50..cff0c8a43727d 100644 --- a/tests/trainer/flags/test_fast_dev_run.py +++ b/tests/trainer/flags/test_fast_dev_run.py @@ -107,7 +107,7 @@ def _make_fast_dev_run_assertions(trainer, model): train_val_step_model = FastDevRunModel() trainer = Trainer(**trainer_config) trainer.fit(train_val_step_model) - trainer.test(ckpt_path=None) + trainer.test(train_val_step_model) assert trainer.state.finished, f"Training failed with {trainer.state}" _make_fast_dev_run_assertions(trainer, train_val_step_model) @@ -120,7 +120,7 @@ def _make_fast_dev_run_assertions(trainer, model): trainer = Trainer(**trainer_config) trainer.fit(train_step_only_model) - trainer.test(ckpt_path=None) + trainer.test(train_step_only_model) assert trainer.state.finished, f"Training failed with {trainer.state}" _make_fast_dev_run_assertions(trainer, train_step_only_model) diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index f07344397b178..00811f736891f 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -201,7 +201,7 @@ def test_dataloader(self): default_root_dir=tmpdir, accelerator="dp", gpus=2, limit_train_batches=2, limit_val_batches=2, max_epochs=1 ) trainer.fit(model) - trainer.test(model, ckpt_path=None) + trainer.test(model) def test_can_return_tensor_with_more_than_one_element(tmpdir): diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 44910119672eb..93b6a1a9288c9 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -439,7 +439,7 @@ def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, lim assert trainer.num_training_batches == expected_train_batches assert trainer.num_val_batches == expected_val_batches - trainer.test(ckpt_path=None) + trainer.test(model) expected_test_batches = [int(len(dataloader) * limit_test_batches) for dataloader in trainer.test_dataloaders] assert trainer.num_test_batches == expected_test_batches @@ -474,7 +474,7 @@ def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_v # ------------------------------------------- assert trainer.num_training_batches == limit_train_batches assert trainer.num_val_batches == [limit_val_batches] * len(trainer.val_dataloaders) - trainer.test(ckpt_path=None) + trainer.test(model) # when the limit is greater than the number of test batches it should be the num in loaders test_dataloader_lengths = [len(x) for x in model.test_dataloader()] @@ -549,7 +549,7 @@ def test_dataloaders_with_fast_dev_run(tmpdir, fast_dev_run): assert trainer.num_training_batches == fast_dev_run assert trainer.num_val_batches == [fast_dev_run] * len(trainer.val_dataloaders) - trainer.test(ckpt_path=None) + trainer.test(model) assert trainer.num_test_batches == [fast_dev_run] * len(trainer.test_dataloaders) # verify sanity check batches match as expected @@ -685,6 +685,8 @@ def test_warning_with_few_workers(_, tmpdir, ckpt_path, stage): match=f'The dataloader, {stage} dataloader{" 0" if stage != "train" else ""}, does not have many workers', ): if stage == "test": + if ckpt_path in ("specific", "best"): + trainer.fit(model, train_dataloader=train_dl, val_dataloaders=val_dl) ckpt_path = trainer.checkpoint_callback.best_model_path if ckpt_path == "specific" else ckpt_path trainer.test(model, test_dataloaders=train_dl, ckpt_path=ckpt_path) else: @@ -722,6 +724,8 @@ def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage): match=f'The dataloader, {stage} dataloader{" 0" if stage != "train" else ""}, does not have many workers', ): if stage == "test": + if ckpt_path in ("specific", "best"): + trainer.fit(model, train_dataloader=train_multi_dl, val_dataloaders=val_multi_dl) ckpt_path = trainer.checkpoint_callback.best_model_path if ckpt_path == "specific" else ckpt_path trainer.test(model, test_dataloaders=test_multi_dl, ckpt_path=ckpt_path) else: @@ -950,7 +954,7 @@ def test_dataloader_distributed_sampler(tmpdir): callbacks=[DistribSamplerCallback(expected_seeds=(123, 123, 123))], ) trainer.fit(model) - trainer.test(ckpt_path=None) + trainer.test(model) class ModelWithDataLoaderDistributedSampler(EvalModelTemplate): @@ -1444,7 +1448,7 @@ def predict_dataloader(self): trainer.fit(model) assert trainer.state.finished, f"Training failed with {trainer.state}" - trainer.test(ckpt_path=None) + trainer.test(model) preds = trainer.predict(model) assert len(preds) == 2 diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 4ea233c8c2471..86ca0d1fc5618 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -639,14 +639,24 @@ def predict_step(self, batch, *_): if save_top_k == 0: with pytest.raises(MisconfigurationException, match=".*is not configured to save the best.*"): trainer_fn(ckpt_path=ckpt_path) + with pytest.raises(MisconfigurationException, match=".*is not configured to save the best.*"): + trainer_fn(model, ckpt_path=ckpt_path) else: trainer_fn(ckpt_path=ckpt_path) assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path + + trainer_fn(model, ckpt_path=ckpt_path) + assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path elif ckpt_path is None: - # ckpt_path is None, meaning we don't load any checkpoints and - # use the weights from the end of training - trainer_fn(ckpt_path=ckpt_path) + # ckpt_path is None, meaning we don't load any checkpoints and use the provided model + trainer_fn(model, ckpt_path=ckpt_path) assert getattr(trainer, path_attr) is None + + if save_top_k > 0: + # ckpt_path is None with no model provided means load the best weights + with pytest.warns(UserWarning, match="The best model of the previous `fit` call will be used"): + trainer_fn(ckpt_path=ckpt_path) + assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path else: # specific checkpoint, pick one from saved ones if save_top_k == 0: @@ -661,6 +671,9 @@ def predict_step(self, batch, *_): trainer_fn(ckpt_path=ckpt_path) assert getattr(trainer, path_attr) == ckpt_path + trainer_fn(model, ckpt_path=ckpt_path) + assert getattr(trainer, path_attr) == ckpt_path + def test_disabled_training(tmpdir): """Verify that `limit_train_batches=0` disables the training loop unless `fast_dev_run=True`.""" @@ -1223,9 +1236,9 @@ def setup(self, model, stage): if stage == "fit": trainer.fit(model) elif stage == "validate": - trainer.validate(model, ckpt_path=None) + trainer.validate(model) else: - trainer.test(model, ckpt_path=None) + trainer.test(model) assert trainer.stage == stage assert trainer.lightning_module.stage == stage