Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Raise MisconfigurationException if trainer.eval is missing required methods #10016

Merged
merged 18 commits into from
Oct 26, 2021
Merged
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Updated several places in the loops and trainer to access `training_type_plugin` directly instead of `accelerator` ([#9901](https://github.com/PyTorchLightning/pytorch-lightning/pull/9901))


- Raise `MisconfigurationException` instead of warning if `trainer.{validate/test}` is missing required methods ([#10016](https://github.com/PyTorchLightning/pytorch-lightning/pull/10016))
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved


### Deprecated

Expand Down Expand Up @@ -397,6 +399,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Deprecated `GPUStatsMonitor` and `XLAStatsMonitor` in favor of `DeviceStatsMonitor` callback ([#9924](https://github.com/PyTorchLightning/pytorch-lightning/pull/9924))


### Removed

- Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/))
Expand Down
108 changes: 55 additions & 53 deletions pytorch_lightning/trainer/configuration_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,16 @@ def verify_loop_configurations(trainer: "pl.Trainer", model: "pl.LightningModule

"""
if trainer.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING):
__verify_train_loop_configuration(trainer, model)
__verify_eval_loop_configuration(model, "val")
__verify_train_val_loop_configuration(trainer, model)
__verify_manual_optimization_support(trainer, model)
__check_training_step_requires_dataloader_iter(model)
elif trainer.state.fn == TrainerFn.VALIDATING:
__verify_eval_loop_configuration(model, "val")
__verify_eval_loop_configuration(trainer, model, "val")
elif trainer.state.fn == TrainerFn.TESTING:
__verify_eval_loop_configuration(model, "test")
__verify_eval_loop_configuration(trainer, model, "test")
elif trainer.state.fn == TrainerFn.PREDICTING:
__verify_predict_loop_configuration(trainer, model)
__verify_eval_loop_configuration(trainer, model, "predict")

__verify_dp_batch_transfer_support(trainer, model)
_check_add_get_queue(model)
# TODO(@daniellepintz): Delete _check_progress_bar in v1.7
Expand All @@ -51,7 +51,7 @@ def verify_loop_configurations(trainer: "pl.Trainer", model: "pl.LightningModule
_check_dl_idx_in_on_train_batch_hooks(trainer, model)


def __verify_train_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
def __verify_train_val_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
# -----------------------------------
# verify model has a training step
# -----------------------------------
Expand Down Expand Up @@ -83,24 +83,15 @@ def __verify_train_loop_configuration(trainer: "pl.Trainer", model: "pl.Lightnin
)

# ----------------------------------------------
# verify model does not have
# - on_train_dataloader
# - on_val_dataloader
# verify model does not have on_train_dataloader
# ----------------------------------------------
has_on_train_dataloader = is_overridden("on_train_dataloader", model)
if has_on_train_dataloader:
rank_zero_deprecation(
"Method `on_train_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
"Method `on_train_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0."
" Please use `train_dataloader()` directly."
)

has_on_val_dataloader = is_overridden("on_val_dataloader", model)
if has_on_val_dataloader:
rank_zero_deprecation(
"Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
" Please use `val_dataloader()` directly."
)

trainer.overriden_optimizer_step = is_overridden("optimizer_step", model)
trainer.overriden_optimizer_zero_grad = is_overridden("optimizer_zero_grad", model)
automatic_optimization = model.automatic_optimization
Expand All @@ -110,8 +101,30 @@ def __verify_train_loop_configuration(trainer: "pl.Trainer", model: "pl.Lightnin
if has_overriden_optimization_functions and going_to_accumulate_grad_batches and automatic_optimization:
rank_zero_warn(
"When using `Trainer(accumulate_grad_batches != 1)` and overriding"
"`LightningModule.optimizer_{step,zero_grad}`, the hooks will not be called on every batch"
"(rather, they are called on every optimization step)."
" `LightningModule.optimizer_{step,zero_grad}`, the hooks will not be called on every batch"
" (rather, they are called on every optimization step)."
)

# -----------------------------------
# verify model for val loop
# -----------------------------------

has_val_loader = trainer.data_connector._val_dataloader_source.is_defined()
has_val_step = is_overridden("validation_step", model)

if has_val_loader and not has_val_step:
rank_zero_warn("You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.")
if has_val_step and not has_val_loader:
rank_zero_warn("You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.")

daniellepintz marked this conversation as resolved.
Show resolved Hide resolved
# ----------------------------------------------
# verify model does not have on_val_dataloader
# ----------------------------------------------
has_on_val_dataloader = is_overridden("on_val_dataloader", model)
if has_on_val_dataloader:
rank_zero_deprecation(
"Method `on_val_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0."
" Please use `val_dataloader()` directly."
)


Expand Down Expand Up @@ -143,52 +156,41 @@ def _check_on_post_move_to_device(model: "pl.LightningModule") -> None:
)


def __verify_eval_loop_configuration(model: "pl.LightningModule", stage: str) -> None:
def __verify_eval_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule", stage: str) -> None:
loader_name = f"{stage}_dataloader"
step_name = "validation_step" if stage == "val" else "test_step"
step_name = "validation_step" if stage == "val" else f"{stage}_step"
daniellepintz marked this conversation as resolved.
Show resolved Hide resolved
trainer_method = "validate" if stage == "val" else stage
on_eval_hook = f"on_{loader_name}"

has_loader = is_overridden(loader_name, model)
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
has_loader = getattr(trainer.data_connector, f"_{stage}_dataloader_source").is_defined()
daniellepintz marked this conversation as resolved.
Show resolved Hide resolved
has_step = is_overridden(step_name, model)

if has_loader and not has_step:
rank_zero_warn(f"you passed in a {loader_name} but have no {step_name}. Skipping {stage} loop")
if has_step and not has_loader:
rank_zero_warn(f"you defined a {step_name} but have no {loader_name}. Skipping {stage} loop")
has_on_eval_dataloader = is_overridden(on_eval_hook, model)

# ----------------------------------------------
# verify model does not have
# - on_val_dataloader
# - on_test_dataloader
# verify model does not have on_eval_dataloader
# ----------------------------------------------
has_on_val_dataloader = is_overridden("on_val_dataloader", model)
if has_on_val_dataloader:
if has_on_eval_dataloader:
rank_zero_deprecation(
"Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
" Please use `val_dataloader()` directly."
f"Method `{on_eval_hook}` is deprecated in v1.5.0 and will"
f" be removed in v1.7.0. Please use `{loader_name}()` directly."
)

has_on_test_dataloader = is_overridden("on_test_dataloader", model)
if has_on_test_dataloader:
rank_zero_deprecation(
"Method `on_test_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
" Please use `test_dataloader()` directly."
)
# -----------------------------------
# verify model has an eval_dataloader
# -----------------------------------
if not has_loader:
raise MisconfigurationException(f"No `{loader_name}()` method defined to run `Trainer.{trainer_method}`.")

# predict_step is not required to be overridden
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
if stage == "predict":
return

def __verify_predict_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
has_predict_dataloader = trainer.data_connector._predict_dataloader_source.is_defined()
if not has_predict_dataloader:
raise MisconfigurationException("Dataloader not found for `Trainer.predict`")
# ----------------------------------------------
# verify model does not have
# - on_predict_dataloader
# ----------------------------------------------
has_on_predict_dataloader = is_overridden("on_predict_dataloader", model)
if has_on_predict_dataloader:
rank_zero_deprecation(
"Method `on_predict_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
" Please use `predict_dataloader()` directly."
)
# -----------------------------------
# verify model has an eval_step
# -----------------------------------
if not has_step:
raise MisconfigurationException(f"No `{step_name}()` method defined to run `Trainer.{trainer_method}`.")


def __verify_dp_batch_transfer_support(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
Expand Down
2 changes: 0 additions & 2 deletions tests/callbacks/test_pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@


class TestModel(BoringModel):
test_step = None

def __init__(self):
super().__init__()
self.layer = Sequential(
Expand Down
10 changes: 5 additions & 5 deletions tests/deprecated_api/test_remove_1-7.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,27 +162,27 @@ def _run(model, task="fit"):
model = CustomBoringModel()

with pytest.deprecated_call(
match="Method `on_train_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
match="Method `on_train_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0."
):
_run(model, "fit")

with pytest.deprecated_call(
match="Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
match="Method `on_val_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0."
):
_run(model, "fit")

with pytest.deprecated_call(
match="Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
match="Method `on_val_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0."
):
_run(model, "validate")

with pytest.deprecated_call(
match="Method `on_test_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
match="Method `on_test_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0."
):
_run(model, "test")

with pytest.deprecated_call(
match="Method `on_predict_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
match="Method `on_predict_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0."
):
_run(model, "predict")

Expand Down
4 changes: 3 additions & 1 deletion tests/helpers/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ def test_models(tmpdir, data_class, model_class):
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)

trainer.fit(model, datamodule=dm)
trainer.test(model, datamodule=dm)

if dm is not None:
trainer.test(model, datamodule=dm)
carmocca marked this conversation as resolved.
Show resolved Hide resolved

model.to_torchscript()
if data_class:
Expand Down
6 changes: 3 additions & 3 deletions tests/models/test_restore.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,11 +333,11 @@ def test_running_test_pretrained_model_distrib_dp(tmpdir):

# correct result and ok accuracy
assert trainer.state.finished, f"Training failed with {trainer.state}"
pretrained_model = ClassificationModel.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
pretrained_model = CustomClassificationModelDP.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved

# run test set
new_trainer = Trainer(**trainer_options)
new_trainer.test(pretrained_model)
new_trainer.test(pretrained_model, datamodule=dm)
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
pretrained_model.cpu()

dataloaders = dm.test_dataloader()
Expand Down Expand Up @@ -385,7 +385,7 @@ def test_running_test_pretrained_model_distrib_ddp_spawn(tmpdir):

# run test set
new_trainer = Trainer(**trainer_options)
new_trainer.test(pretrained_model)
new_trainer.test(pretrained_model, datamodule=dm)
pretrained_model.cpu()

dataloaders = dm.test_dataloader()
Expand Down
5 changes: 3 additions & 2 deletions tests/plugins/test_deepspeed_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -914,8 +914,9 @@ def test_dataloader(self):
gpus=1,
fast_dev_run=True,
)
trainer.fit(model, datamodule=TestSetupIsCalledDataModule())
trainer.test(model)
dm = TestSetupIsCalledDataModule()
trainer.fit(model, datamodule=dm)
trainer.test(model, datamodule=dm)


@mock.patch("torch.optim.lr_scheduler.StepLR.step", autospec=True)
Expand Down
50 changes: 23 additions & 27 deletions tests/trainer/test_config_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,50 +52,51 @@ def test_fit_val_loop_config(tmpdir):
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)

# no val data has val loop
with pytest.warns(UserWarning, match=r"you passed in a val_dataloader but have no validation_step"):
with pytest.warns(UserWarning, match=r"You passed in a `val_dataloader` but have no `validation_step`"):
model = BoringModel()
model.validation_step = None
trainer.fit(model)

# has val loop but no val data
with pytest.warns(UserWarning, match=r"you defined a validation_step but have no val_dataloader"):
with pytest.warns(UserWarning, match=r"You defined a `validation_step` but have no `val_dataloader`"):
model = BoringModel()
model.val_dataloader = None
trainer.fit(model)


def test_test_loop_config(tmpdir):
"""When either test loop or test data are missing."""
def test_eval_loop_config(tmpdir):
"""When either eval step or eval data is missing."""
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)

# has val step but no val data
with pytest.raises(MisconfigurationException, match=r"No `val_dataloader\(\)` method defined"):
model = BoringModel()
model.val_dataloader = None
trainer.validate(model)

# has val data but no val step
with pytest.raises(MisconfigurationException, match=r"No `validation_step\(\)` method defined"):
model = BoringModel()
model.validation_step = None
trainer.validate(model)

# has test step but no test data
with pytest.warns(UserWarning, match=r"you defined a test_step but have no test_dataloader"):
with pytest.raises(MisconfigurationException, match=r"No `test_dataloader\(\)` method defined"):
model = BoringModel()
model.test_dataloader = None
trainer.test(model)

# has test data but no test loop
with pytest.warns(UserWarning, match=r"you passed in a test_dataloader but have no test_step"):
# has test data but no test step
with pytest.raises(MisconfigurationException, match=r"No `test_step\(\)` method defined"):
model = BoringModel()
model.test_step = None
trainer.test(model)


def test_val_loop_config(tmpdir):
"""When either validation loop or validation data are missing."""
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)

# has val loop but no val data
with pytest.warns(UserWarning, match=r"you defined a validation_step but have no val_dataloader"):
model = BoringModel()
model.val_dataloader = None
trainer.validate(model)

# has val data but no val loop
with pytest.warns(UserWarning, match=r"you passed in a val_dataloader but have no validation_step"):
# has predict step but no predict data
with pytest.raises(MisconfigurationException, match=r"No `predict_dataloader\(\)` method defined"):
model = BoringModel()
model.validation_step = None
trainer.validate(model)
model.predict_dataloader = None
trainer.predict(model)


@pytest.mark.parametrize("datamodule", [False, True])
Expand Down Expand Up @@ -130,11 +131,6 @@ def predict_dataloader(self):
assert len(results) == 2
assert results[0][0].shape == torch.Size([1, 2])

model.predict_dataloader = None

with pytest.raises(MisconfigurationException, match="Dataloader not found for `Trainer.predict`"):
trainer.predict(model)


def test_trainer_manual_optimization_config(tmpdir):
"""Test error message when requesting Trainer features unsupported with manual optimization."""
Expand Down