From 6c8ab1c39f81f03ebcde9e67a951ad56d1f4b917 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 19 Oct 2021 15:02:38 +0530 Subject: [PATCH 01/14] raise misconfig for trainer.eval --- .../trainer/configuration_validator.py | 82 +++++++++++-------- tests/trainer/test_config_validator.py | 50 ++++++----- 2 files changed, 69 insertions(+), 63 deletions(-) diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 88c319ac57431..533bc4a3755d2 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -30,15 +30,16 @@ def verify_loop_configurations(trainer: "pl.Trainer", model: "pl.LightningModule """ if trainer.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING): __verify_train_loop_configuration(trainer, model) - __verify_eval_loop_configuration(model, "val") + __verify_fit_val_loop_configuration(trainer, model) __verify_manual_optimization_support(trainer, model) __check_training_step_requires_dataloader_iter(model) elif trainer.state.fn == TrainerFn.VALIDATING: - __verify_eval_loop_configuration(model, "val") + __verify_eval_loop_configuration(trainer, model, "val") elif trainer.state.fn == TrainerFn.TESTING: - __verify_eval_loop_configuration(model, "test") + __verify_eval_loop_configuration(trainer, model, "test") elif trainer.state.fn == TrainerFn.PREDICTING: - __verify_predict_loop_configuration(trainer, model) + __verify_eval_loop_configuration(trainer, model, "predict") + __verify_dp_batch_transfer_support(trainer, model) _check_add_get_queue(model) # TODO(@daniellepintz): Delete _check_progress_bar in v1.7 @@ -110,8 +111,8 @@ def __verify_train_loop_configuration(trainer: "pl.Trainer", model: "pl.Lightnin if has_overriden_optimization_functions and going_to_accumulate_grad_batches and automatic_optimization: rank_zero_warn( "When using `Trainer(accumulate_grad_batches != 1)` and overriding" - "`LightningModule.optimizer_{step,zero_grad}`, the hooks will not be called on every batch" - "(rather, they are called on every optimization step)." + " `LightningModule.optimizer_{step,zero_grad}`, the hooks will not be called on every batch" + " (rather, they are called on every optimization step)." ) @@ -143,53 +144,62 @@ def _check_on_post_move_to_device(model: "pl.LightningModule") -> None: ) -def __verify_eval_loop_configuration(model: "pl.LightningModule", stage: str) -> None: - loader_name = f"{stage}_dataloader" - step_name = "validation_step" if stage == "val" else "test_step" - - has_loader = is_overridden(loader_name, model) - has_step = is_overridden(step_name, model) +def __verify_fit_val_loop_configuration(trainer, model: "pl.LightningModule") -> None: + has_val_loader = trainer.data_connector._val_dataloader_source.is_defined() + has_val_step = is_overridden("validation_step", model) - if has_loader and not has_step: - rank_zero_warn(f"you passed in a {loader_name} but have no {step_name}. Skipping {stage} loop") - if has_step and not has_loader: - rank_zero_warn(f"you defined a {step_name} but have no {loader_name}. Skipping {stage} loop") + if has_val_loader and not has_val_step: + rank_zero_warn("You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.") + if has_val_step and not has_val_loader: + rank_zero_warn("You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.") # ---------------------------------------------- - # verify model does not have - # - on_val_dataloader - # - on_test_dataloader + # verify model does not have on_val_dataloader # ---------------------------------------------- has_on_val_dataloader = is_overridden("on_val_dataloader", model) if has_on_val_dataloader: rank_zero_deprecation( - "Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0." + "Method `on_val_dataloader` is deprecated and will be removed in v1.7.0." " Please use `val_dataloader()` directly." ) - has_on_test_dataloader = is_overridden("on_test_dataloader", model) - if has_on_test_dataloader: - rank_zero_deprecation( - "Method `on_test_dataloader` in DataHooks is deprecated and will be removed in v1.7.0." - " Please use `test_dataloader()` directly." - ) +def __verify_eval_loop_configuration(trainer, model: "pl.LightningModule", stage: str) -> None: + loader_name = f"{stage}_dataloader" + step_name = "validation_step" if stage == "val" else f"{stage}_step" + trainer_method = "validate" if stage == "val" else stage + on_eval_hook = f"on_{loader_name}" + + has_loader = is_overridden(loader_name, model) + has_loader = getattr(trainer.data_connector, f'_{stage}_dataloader_source').is_defined() + has_step = is_overridden(step_name, model) + has_on_eval_dataloader = is_overridden(on_eval_hook, model) -def __verify_predict_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule") -> None: - has_predict_dataloader = trainer.data_connector._predict_dataloader_source.is_defined() - if not has_predict_dataloader: - raise MisconfigurationException("Dataloader not found for `Trainer.predict`") # ---------------------------------------------- - # verify model does not have - # - on_predict_dataloader + # verify model does not have on_eval_dataloader # ---------------------------------------------- - has_on_predict_dataloader = is_overridden("on_predict_dataloader", model) - if has_on_predict_dataloader: + if has_on_eval_dataloader: rank_zero_deprecation( - "Method `on_predict_dataloader` in DataHooks is deprecated and will be removed in v1.7.0." - " Please use `predict_dataloader()` directly." + f"Method `{on_eval_hook}` is deprecated in v1.5.0 and will" + f" be removed in v1.7.0. Please use `{loader_name}()` directly." ) + # ----------------------------------- + # verify model has an eval_dataloader + # ----------------------------------- + if not has_loader: + raise MisconfigurationException(f"No `{loader_name}()` method defined to run `Trainer.{trainer_method}`.") + + # predict_step is not required to be overridden + if stage == "predict": + return + + # ----------------------------------- + # verify model has an eval_step + # ----------------------------------- + if not has_step: + raise MisconfigurationException(f"No `{step_name}()` method defined to run `Trainer.{trainer_method}`.") + def __verify_dp_batch_transfer_support(trainer: "pl.Trainer", model: "pl.LightningModule") -> None: """Raise Misconfiguration exception since these hooks are not supported in DP mode.""" diff --git a/tests/trainer/test_config_validator.py b/tests/trainer/test_config_validator.py index 9c452052d73e7..348cdfb9fbafe 100644 --- a/tests/trainer/test_config_validator.py +++ b/tests/trainer/test_config_validator.py @@ -52,50 +52,51 @@ def test_fit_val_loop_config(tmpdir): trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) # no val data has val loop - with pytest.warns(UserWarning, match=r"you passed in a val_dataloader but have no validation_step"): + with pytest.warns(UserWarning, match=r"You passed in a `val_dataloader` but have no `validation_step`"): model = BoringModel() model.validation_step = None trainer.fit(model) # has val loop but no val data - with pytest.warns(UserWarning, match=r"you defined a validation_step but have no val_dataloader"): + with pytest.warns(UserWarning, match=r"You defined a `validation_step` but have no `val_dataloader`"): model = BoringModel() model.val_dataloader = None trainer.fit(model) -def test_test_loop_config(tmpdir): - """When either test loop or test data are missing.""" +def test_eval_loop_config(tmpdir): + """When either eval step or eval data is missing.""" trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) + # has val step but no val data + with pytest.raises(MisconfigurationException, match=r"No `val_dataloader\(\)` method defined"): + model = BoringModel() + model.val_dataloader = None + trainer.validate(model) + + # has test data but no val step + with pytest.raises(MisconfigurationException, match=r"No `validation_step\(\)` method defined"): + model = BoringModel() + model.validation_step = None + trainer.validate(model) + # has test loop but no test data - with pytest.warns(UserWarning, match=r"you defined a test_step but have no test_dataloader"): + with pytest.raises(MisconfigurationException, match=r"No `test_dataloader\(\)` method defined"): model = BoringModel() model.test_dataloader = None trainer.test(model) - # has test data but no test loop - with pytest.warns(UserWarning, match=r"you passed in a test_dataloader but have no test_step"): + # has test data but no test step + with pytest.raises(MisconfigurationException, match=r"No `test_step\(\)` method defined"): model = BoringModel() model.test_step = None trainer.test(model) - -def test_val_loop_config(tmpdir): - """When either validation loop or validation data are missing.""" - trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) - - # has val loop but no val data - with pytest.warns(UserWarning, match=r"you defined a validation_step but have no val_dataloader"): - model = BoringModel() - model.val_dataloader = None - trainer.validate(model) - - # has val data but no val loop - with pytest.warns(UserWarning, match=r"you passed in a val_dataloader but have no validation_step"): + # has predict step but no predict data + with pytest.raises(MisconfigurationException, match=r"No `predict_dataloader\(\)` method defined"): model = BoringModel() - model.validation_step = None - trainer.validate(model) + model.predict_dataloader = None + trainer.predict(model) @pytest.mark.parametrize("datamodule", [False, True]) @@ -130,11 +131,6 @@ def predict_dataloader(self): assert len(results) == 2 assert results[0][0].shape == torch.Size([1, 2]) - model.predict_dataloader = None - - with pytest.raises(MisconfigurationException, match="Dataloader not found for `Trainer.predict`"): - trainer.predict(model) - def test_trainer_manual_optimization_config(tmpdir): """Test error message when requesting Trainer features unsupported with manual optimization.""" From 9c13a8d35cd13115082bbe6a10a247b205485b17 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 19 Oct 2021 15:10:06 +0530 Subject: [PATCH 02/14] chlog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index efa420bebcd8c..dfaee9a8ddc2c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -323,6 +323,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Updated several places in the loops and trainer to access `training_type_plugin` directly instead of `accelerator` ([#9901](https://github.com/PyTorchLightning/pytorch-lightning/pull/9901)) +- Raise `MisconfigurationException` instead of warning if `trainer.{validate/test}` is missing required methods ([#10061](https://github.com/PyTorchLightning/pytorch-lightning/pull/10061)) + ### Deprecated @@ -397,6 +399,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `GPUStatsMonitor` and `XLAStatsMonitor` in favor of `DeviceStatsMonitor` callback ([#9924](https://github.com/PyTorchLightning/pytorch-lightning/pull/9924)) + ### Removed - Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/)) From 392529359aaf5acfcaeed3aa7c0d9cd20b501b89 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 19 Oct 2021 16:55:19 +0530 Subject: [PATCH 03/14] fix tests --- pytorch_lightning/trainer/configuration_validator.py | 4 ++-- tests/deprecated_api/test_remove_1-7.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 533bc4a3755d2..7731bfb5f1644 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -91,14 +91,14 @@ def __verify_train_loop_configuration(trainer: "pl.Trainer", model: "pl.Lightnin has_on_train_dataloader = is_overridden("on_train_dataloader", model) if has_on_train_dataloader: rank_zero_deprecation( - "Method `on_train_dataloader` in DataHooks is deprecated and will be removed in v1.7.0." + "Method `on_train_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0." " Please use `train_dataloader()` directly." ) has_on_val_dataloader = is_overridden("on_val_dataloader", model) if has_on_val_dataloader: rank_zero_deprecation( - "Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0." + "Method `on_val_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0." " Please use `val_dataloader()` directly." ) diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index edf85c11766e3..a8f363d756bee 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -162,27 +162,27 @@ def _run(model, task="fit"): model = CustomBoringModel() with pytest.deprecated_call( - match="Method `on_train_dataloader` in DataHooks is deprecated and will be removed in v1.7.0." + match="Method `on_train_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0." ): _run(model, "fit") with pytest.deprecated_call( - match="Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0." + match="Method `on_val_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0." ): _run(model, "fit") with pytest.deprecated_call( - match="Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0." + match="Method `on_val_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0." ): _run(model, "validate") with pytest.deprecated_call( - match="Method `on_test_dataloader` in DataHooks is deprecated and will be removed in v1.7.0." + match="Method `on_test_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0." ): _run(model, "test") with pytest.deprecated_call( - match="Method `on_predict_dataloader` in DataHooks is deprecated and will be removed in v1.7.0." + match="Method `on_predict_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0." ): _run(model, "predict") From 1fd2d3c7038e7c71ab2be12406fcf1cef926b860 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 19 Oct 2021 17:20:23 +0530 Subject: [PATCH 04/14] update tests --- tests/helpers/test_models.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/helpers/test_models.py b/tests/helpers/test_models.py index 8e5f85632bbc5..a6e5437115d5e 100644 --- a/tests/helpers/test_models.py +++ b/tests/helpers/test_models.py @@ -40,7 +40,9 @@ def test_models(tmpdir, data_class, model_class): trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) trainer.fit(model, datamodule=dm) - trainer.test(model, datamodule=dm) + + if dm is not None: + trainer.test(model, datamodule=dm) model.to_torchscript() if data_class: From 1e7d49794b16230e4b8ff0527c1199738ab70edc Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 19 Oct 2021 19:00:07 +0530 Subject: [PATCH 05/14] update tests --- tests/callbacks/test_pruning.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/callbacks/test_pruning.py b/tests/callbacks/test_pruning.py index 03070a8f1dc1f..1c1f84b5b95a0 100644 --- a/tests/callbacks/test_pruning.py +++ b/tests/callbacks/test_pruning.py @@ -30,8 +30,6 @@ class TestModel(BoringModel): - test_step = None - def __init__(self): super().__init__() self.layer = Sequential( From 7c0f2c3caa5b75552565e14ca2b6011a77e34353 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 19 Oct 2021 10:19:11 -0400 Subject: [PATCH 06/14] fix tests --- tests/models/test_restore.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index 569c861eb32f0..3a973415cb932 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -333,11 +333,11 @@ def test_running_test_pretrained_model_distrib_dp(tmpdir): # correct result and ok accuracy assert trainer.state.finished, f"Training failed with {trainer.state}" - pretrained_model = ClassificationModel.load_from_checkpoint(trainer.checkpoint_callback.best_model_path) + pretrained_model = CustomClassificationModelDP.load_from_checkpoint(trainer.checkpoint_callback.best_model_path) # run test set new_trainer = Trainer(**trainer_options) - new_trainer.test(pretrained_model) + new_trainer.test(pretrained_model, datamodule=dm) pretrained_model.cpu() dataloaders = dm.test_dataloader() @@ -385,7 +385,7 @@ def test_running_test_pretrained_model_distrib_ddp_spawn(tmpdir): # run test set new_trainer = Trainer(**trainer_options) - new_trainer.test(pretrained_model) + new_trainer.test(pretrained_model, datamodule=dm) pretrained_model.cpu() dataloaders = dm.test_dataloader() From b8f1d1dc3ad7f156a4ffbbfbe3bfdcf165b1f341 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Tue, 19 Oct 2021 20:09:56 +0530 Subject: [PATCH 07/14] chlog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index dfaee9a8ddc2c..419f63dafec66 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -323,7 +323,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Updated several places in the loops and trainer to access `training_type_plugin` directly instead of `accelerator` ([#9901](https://github.com/PyTorchLightning/pytorch-lightning/pull/9901)) -- Raise `MisconfigurationException` instead of warning if `trainer.{validate/test}` is missing required methods ([#10061](https://github.com/PyTorchLightning/pytorch-lightning/pull/10061)) +- Raise `MisconfigurationException` instead of warning if `trainer.{validate/test}` is missing required methods ([#10016](https://github.com/PyTorchLightning/pytorch-lightning/pull/10016)) ### Deprecated From dde6a6b2614819f5d2de3880a52262a62051cd7a Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Wed, 20 Oct 2021 19:40:15 +0530 Subject: [PATCH 08/14] rebase and update tests --- .../trainer/configuration_validator.py | 62 ++++++++----------- tests/plugins/test_deepspeed_plugin.py | 5 +- 2 files changed, 30 insertions(+), 37 deletions(-) diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 7731bfb5f1644..62029b54e26a7 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -29,8 +29,7 @@ def verify_loop_configurations(trainer: "pl.Trainer", model: "pl.LightningModule """ if trainer.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING): - __verify_train_loop_configuration(trainer, model) - __verify_fit_val_loop_configuration(trainer, model) + __verify_train_val_loop_configuration(trainer, model) __verify_manual_optimization_support(trainer, model) __check_training_step_requires_dataloader_iter(model) elif trainer.state.fn == TrainerFn.VALIDATING: @@ -52,7 +51,7 @@ def verify_loop_configurations(trainer: "pl.Trainer", model: "pl.LightningModule _check_dl_idx_in_on_train_batch_hooks(trainer, model) -def __verify_train_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule") -> None: +def __verify_train_val_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule") -> None: # ----------------------------------- # verify model has a training step # ----------------------------------- @@ -84,9 +83,7 @@ def __verify_train_loop_configuration(trainer: "pl.Trainer", model: "pl.Lightnin ) # ---------------------------------------------- - # verify model does not have - # - on_train_dataloader - # - on_val_dataloader + # verify model does not have on_train_dataloader # ---------------------------------------------- has_on_train_dataloader = is_overridden("on_train_dataloader", model) if has_on_train_dataloader: @@ -95,13 +92,6 @@ def __verify_train_loop_configuration(trainer: "pl.Trainer", model: "pl.Lightnin " Please use `train_dataloader()` directly." ) - has_on_val_dataloader = is_overridden("on_val_dataloader", model) - if has_on_val_dataloader: - rank_zero_deprecation( - "Method `on_val_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0." - " Please use `val_dataloader()` directly." - ) - trainer.overriden_optimizer_step = is_overridden("optimizer_step", model) trainer.overriden_optimizer_zero_grad = is_overridden("optimizer_zero_grad", model) automatic_optimization = model.automatic_optimization @@ -115,6 +105,28 @@ def __verify_train_loop_configuration(trainer: "pl.Trainer", model: "pl.Lightnin " (rather, they are called on every optimization step)." ) + # ----------------------------------- + # verify model for val loop + # ----------------------------------- + + has_val_loader = trainer.data_connector._val_dataloader_source.is_defined() + has_val_step = is_overridden("validation_step", model) + + if has_val_loader and not has_val_step: + rank_zero_warn("You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.") + if has_val_step and not has_val_loader: + rank_zero_warn("You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.") + + # ---------------------------------------------- + # verify model does not have on_val_dataloader + # ---------------------------------------------- + has_on_val_dataloader = is_overridden("on_val_dataloader", model) + if has_on_val_dataloader: + rank_zero_deprecation( + "Method `on_val_dataloader` is deprecated and will be removed in v1.7.0." + " Please use `val_dataloader()` directly." + ) + def _check_progress_bar(model: "pl.LightningModule") -> None: r""" @@ -144,34 +156,14 @@ def _check_on_post_move_to_device(model: "pl.LightningModule") -> None: ) -def __verify_fit_val_loop_configuration(trainer, model: "pl.LightningModule") -> None: - has_val_loader = trainer.data_connector._val_dataloader_source.is_defined() - has_val_step = is_overridden("validation_step", model) - - if has_val_loader and not has_val_step: - rank_zero_warn("You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.") - if has_val_step and not has_val_loader: - rank_zero_warn("You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.") - - # ---------------------------------------------- - # verify model does not have on_val_dataloader - # ---------------------------------------------- - has_on_val_dataloader = is_overridden("on_val_dataloader", model) - if has_on_val_dataloader: - rank_zero_deprecation( - "Method `on_val_dataloader` is deprecated and will be removed in v1.7.0." - " Please use `val_dataloader()` directly." - ) - - -def __verify_eval_loop_configuration(trainer, model: "pl.LightningModule", stage: str) -> None: +def __verify_eval_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule", stage: str) -> None: loader_name = f"{stage}_dataloader" step_name = "validation_step" if stage == "val" else f"{stage}_step" trainer_method = "validate" if stage == "val" else stage on_eval_hook = f"on_{loader_name}" has_loader = is_overridden(loader_name, model) - has_loader = getattr(trainer.data_connector, f'_{stage}_dataloader_source').is_defined() + has_loader = getattr(trainer.data_connector, f"_{stage}_dataloader_source").is_defined() has_step = is_overridden(step_name, model) has_on_eval_dataloader = is_overridden(on_eval_hook, model) diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index 53b7bdbf7f0f0..76cb1cd1336b2 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -914,8 +914,9 @@ def test_dataloader(self): gpus=1, fast_dev_run=True, ) - trainer.fit(model, datamodule=TestSetupIsCalledDataModule()) - trainer.test(model) + dm = TestSetupIsCalledDataModule() + trainer.fit(model, datamodule=dm) + trainer.test(model, datamodule=dm) @mock.patch("torch.optim.lr_scheduler.StepLR.step", autospec=True) From e5db8efb9df74edcf4dc0c33852f2d99661a6dbd Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Wed, 20 Oct 2021 20:05:42 +0530 Subject: [PATCH 09/14] update test --- pytorch_lightning/trainer/configuration_validator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 62029b54e26a7..7058538983464 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -123,7 +123,7 @@ def __verify_train_val_loop_configuration(trainer: "pl.Trainer", model: "pl.Ligh has_on_val_dataloader = is_overridden("on_val_dataloader", model) if has_on_val_dataloader: rank_zero_deprecation( - "Method `on_val_dataloader` is deprecated and will be removed in v1.7.0." + "Method `on_val_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0." " Please use `val_dataloader()` directly." ) From 8bc11c935f0b6c9f5005d3c57e3264f1ec3fa82e Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Wed, 20 Oct 2021 22:19:11 +0530 Subject: [PATCH 10/14] forward check --- .../trainer/configuration_validator.py | 4 ++ tests/trainer/test_config_validator.py | 38 ++++++++++++++----- 2 files changed, 32 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 7058538983464..fef78e1440be6 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -184,6 +184,10 @@ def __verify_eval_loop_configuration(trainer: "pl.Trainer", model: "pl.Lightning # predict_step is not required to be overridden if stage == "predict": + if model.predict_step is None or not is_overridden("forward", model): + raise MisconfigurationException( + "`Trainer.predict` requires either `forward` or `predict_step` method to run." + ) return # ----------------------------------- diff --git a/tests/trainer/test_config_validator.py b/tests/trainer/test_config_validator.py index 348cdfb9fbafe..d4b37035b9d14 100644 --- a/tests/trainer/test_config_validator.py +++ b/tests/trainer/test_config_validator.py @@ -69,33 +69,51 @@ def test_eval_loop_config(tmpdir): trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) # has val step but no val data + model = BoringModel() + model.val_dataloader = None with pytest.raises(MisconfigurationException, match=r"No `val_dataloader\(\)` method defined"): - model = BoringModel() - model.val_dataloader = None trainer.validate(model) # has test data but no val step + model = BoringModel() + model.validation_step = None with pytest.raises(MisconfigurationException, match=r"No `validation_step\(\)` method defined"): - model = BoringModel() - model.validation_step = None trainer.validate(model) # has test loop but no test data + model = BoringModel() + model.test_dataloader = None with pytest.raises(MisconfigurationException, match=r"No `test_dataloader\(\)` method defined"): - model = BoringModel() - model.test_dataloader = None trainer.test(model) # has test data but no test step + model = BoringModel() + model.test_step = None with pytest.raises(MisconfigurationException, match=r"No `test_step\(\)` method defined"): - model = BoringModel() - model.test_step = None trainer.test(model) # has predict step but no predict data + model = BoringModel() + model.predict_dataloader = None with pytest.raises(MisconfigurationException, match=r"No `predict_dataloader\(\)` method defined"): - model = BoringModel() - model.predict_dataloader = None + trainer.predict(model) + + # has predict data but no forward + model = BoringModel() + model.forward = None + with pytest.raises(MisconfigurationException, match=r"requires either `forward` or `predict_step` method to run."): + trainer.predict(model) + + # has predict data but no predict_step + model = BoringModel() + model.predict_step = None + with pytest.raises(MisconfigurationException, match=r"requires either `forward` or `predict_step` method to run."): + trainer.predict(model) + + # has predict data but no forward + model = BoringModel() + model.forward = None + with pytest.raises(MisconfigurationException, match=r"requires either `forward` or `predict_step` method to run."): trainer.predict(model) From e388fb60016e4a69bc6d0c1b7d655700114db488 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Thu, 21 Oct 2021 02:49:40 +0530 Subject: [PATCH 11/14] Update pytorch_lightning/trainer/configuration_validator.py --- pytorch_lightning/trainer/configuration_validator.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index fef78e1440be6..23521d6a659d6 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -162,7 +162,6 @@ def __verify_eval_loop_configuration(trainer: "pl.Trainer", model: "pl.Lightning trainer_method = "validate" if stage == "val" else stage on_eval_hook = f"on_{loader_name}" - has_loader = is_overridden(loader_name, model) has_loader = getattr(trainer.data_connector, f"_{stage}_dataloader_source").is_defined() has_step = is_overridden(step_name, model) has_on_eval_dataloader = is_overridden(on_eval_hook, model) From 36498e6134f1fed5be9e38afe2cd31b8e1b58412 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Fri, 22 Oct 2021 14:06:51 +0530 Subject: [PATCH 12/14] improve err msg --- pytorch_lightning/trainer/configuration_validator.py | 8 ++++---- tests/trainer/test_config_validator.py | 10 ++-------- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 23521d6a659d6..688721cb76716 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -183,10 +183,10 @@ def __verify_eval_loop_configuration(trainer: "pl.Trainer", model: "pl.Lightning # predict_step is not required to be overridden if stage == "predict": - if model.predict_step is None or not is_overridden("forward", model): - raise MisconfigurationException( - "`Trainer.predict` requires either `forward` or `predict_step` method to run." - ) + if model.predict_step is None: + raise MisconfigurationException("`predict_step` cannot be None to run `Trainer.predict`") + elif not has_step and not is_overridden("forward", model): + raise MisconfigurationException("`Trainer.predict` requires `forward` method to run.") return # ----------------------------------- diff --git a/tests/trainer/test_config_validator.py b/tests/trainer/test_config_validator.py index d4b37035b9d14..ed5ea11322f52 100644 --- a/tests/trainer/test_config_validator.py +++ b/tests/trainer/test_config_validator.py @@ -98,22 +98,16 @@ def test_eval_loop_config(tmpdir): with pytest.raises(MisconfigurationException, match=r"No `predict_dataloader\(\)` method defined"): trainer.predict(model) - # has predict data but no forward - model = BoringModel() - model.forward = None - with pytest.raises(MisconfigurationException, match=r"requires either `forward` or `predict_step` method to run."): - trainer.predict(model) - # has predict data but no predict_step model = BoringModel() model.predict_step = None - with pytest.raises(MisconfigurationException, match=r"requires either `forward` or `predict_step` method to run."): + with pytest.raises(MisconfigurationException, match=r"`predict_step` cannot be None."): trainer.predict(model) # has predict data but no forward model = BoringModel() model.forward = None - with pytest.raises(MisconfigurationException, match=r"requires either `forward` or `predict_step` method to run."): + with pytest.raises(MisconfigurationException, match=r"requires `forward` method to run."): trainer.predict(model) From c4a2e7edfa69dcca3154ac5d7cf4ebca724d2adc Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Mon, 25 Oct 2021 21:29:06 +0530 Subject: [PATCH 13/14] private --- pytorch_lightning/trainer/configuration_validator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 5944930d7c93b..f9a165b4571ff 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -109,7 +109,7 @@ def __verify_train_val_loop_configuration(trainer: "pl.Trainer", model: "pl.Ligh # verify model for val loop # ----------------------------------- - has_val_loader = trainer.data_connector._val_dataloader_source.is_defined() + has_val_loader = trainer._data_connector._val_dataloader_source.is_defined() has_val_step = is_overridden("validation_step", model) if has_val_loader and not has_val_step: @@ -162,7 +162,7 @@ def __verify_eval_loop_configuration(trainer: "pl.Trainer", model: "pl.Lightning trainer_method = "validate" if stage == "val" else stage on_eval_hook = f"on_{loader_name}" - has_loader = getattr(trainer.data_connector, f"_{stage}_dataloader_source").is_defined() + has_loader = getattr(trainer._data_connector, f"_{stage}_dataloader_source").is_defined() has_step = is_overridden(step_name, model) has_on_eval_dataloader = is_overridden(on_eval_hook, model) From e362da26dfcbb5ecb5fbb4df9aa7ec2df79406d1 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Mon, 25 Oct 2021 23:08:29 +0530 Subject: [PATCH 14/14] update tests --- tests/callbacks/test_quantization.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/callbacks/test_quantization.py b/tests/callbacks/test_quantization.py index ec2bb66110a2e..fa2ee767bdc8c 100644 --- a/tests/callbacks/test_quantization.py +++ b/tests/callbacks/test_quantization.py @@ -224,8 +224,8 @@ def test_quantization_val_test_predict(tmpdir): max_epochs=4, ) trainer.fit(val_test_predict_qmodel, datamodule=dm) - trainer.validate(model=val_test_predict_qmodel, verbose=False) - trainer.test(model=val_test_predict_qmodel, verbose=False) + trainer.validate(model=val_test_predict_qmodel, datamodule=dm, verbose=False) + trainer.test(model=val_test_predict_qmodel, datamodule=dm, verbose=False) trainer.predict( model=val_test_predict_qmodel, dataloaders=[torch.utils.data.DataLoader(RandomDataset(num_features, 16))] )