Skip to content

Commit

Permalink
Merge branch 'config_val' of github.com:daniellepintz/pytorch-lightni…
Browse files Browse the repository at this point in the history
…ng into config_val
  • Loading branch information
daniellepintz committed Oct 18, 2021
2 parents 72c3001 + a5eb6d9 commit 2eebf4c
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 8 deletions.
8 changes: 6 additions & 2 deletions pytorch_lightning/trainer/configuration_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,13 @@ def __verify_eval_loop_configuration(model: "pl.LightningModule", stage: str) ->
has_step = is_overridden(step_name, model)

if has_loader and not has_step:
raise MisconfigurationException(f"You defined `{loader_name}` but have not defined `{step_name}` on your Lightning Module.")
raise MisconfigurationException(
f"You defined `{loader_name}` but have not defined `{step_name}` on your Lightning Module."
)
if has_step and not has_loader:
raise MisconfigurationException(f"You defined `{step_name}` but have not defined `{loader_name}` on your Lightning Module.")
raise MisconfigurationException(
f"You defined `{step_name}` but have not defined `{loader_name}` on your Lightning Module."
)

# ----------------------------------------------
# verify model does not have
Expand Down
30 changes: 24 additions & 6 deletions tests/trainer/test_config_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,19 @@ def test_fit_val_loop_config(tmpdir):
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)

# has val data but not val loop
with pytest.raises(MisconfigurationException, match=r"You defined `val_dataloader` but have not defined `validation_step` on your Lightning Module."):
with pytest.raises(
MisconfigurationException,
match=r"You defined `val_dataloader` but have not defined `validation_step` on your Lightning Module.",
):
model = BoringModel()
model.validation_step = None
trainer.fit(model)

# has val loop but no val data
with pytest.raises(MisconfigurationException, match=r"You defined `validation_step` but have not defined `val_dataloader` on your Lightning Module."):
with pytest.raises(
MisconfigurationException,
match=r"You defined `validation_step` but have not defined `val_dataloader` on your Lightning Module.",
):
model = BoringModel()
model.val_dataloader = None
trainer.fit(model)
Expand All @@ -69,13 +75,19 @@ def test_test_loop_config(tmpdir):
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)

# has test loop but no test data
with pytest.raises(MisconfigurationException, match=r"You defined `test_step` but have not defined `test_dataloader` on your Lightning Module"):
with pytest.raises(
MisconfigurationException,
match=r"You defined `test_step` but have not defined `test_dataloader` on your Lightning Module",
):
model = BoringModel()
model.test_dataloader = None
trainer.test(model)

# has test data but no test loop
with pytest.raises(MisconfigurationException, match=r"You defined `test_dataloader` but have not defined `test_step` on your Lightning Module"):
with pytest.raises(
MisconfigurationException,
match=r"You defined `test_dataloader` but have not defined `test_step` on your Lightning Module",
):
model = BoringModel()
model.test_step = None
trainer.test(model)
Expand All @@ -86,13 +98,19 @@ def test_val_loop_config(tmpdir):
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)

# has val loop but no val data
with pytest.raises(MisconfigurationException, match=r"You defined `validation_step` but have not defined `val_dataloader` on your Lightning Module"):
with pytest.raises(
MisconfigurationException,
match=r"You defined `validation_step` but have not defined `val_dataloader` on your Lightning Module",
):
model = BoringModel()
model.val_dataloader = None
trainer.validate(model)

# has val data but no val loop
with pytest.raises(MisconfigurationException, match=r"You defined `val_dataloader` but have not defined `validation_step` on your Lightning Module"):
with pytest.raises(
MisconfigurationException,
match=r"You defined `val_dataloader` but have not defined `validation_step` on your Lightning Module",
):
model = BoringModel()
model.validation_step = None
trainer.validate(model)
Expand Down

0 comments on commit 2eebf4c

Please sign in to comment.