Use non-deprecated options in tests (#9949)
carmocca authored Oct 15, 2021
1 parent db4e770 commit e973bcb
Showing 5 changed files with 16 additions and 15 deletions.
8 changes: 4 additions & 4 deletions tests/callbacks/test_device_stats_monitor.py
@@ -46,7 +46,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
gpus=1,
callbacks=[device_stats],
logger=DebugLogger(tmpdir),
-checkpoint_callback=False,
+enable_checkpointing=False,
enable_progress_bar=False,
)

@@ -75,7 +75,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
gpus=1,
callbacks=[device_stats],
logger=DebugLogger(tmpdir),
-checkpoint_callback=False,
+enable_checkpointing=False,
enable_progress_bar=False,
)

@@ -104,7 +104,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
log_every_n_steps=1,
callbacks=[device_stats],
logger=DebugLogger(tmpdir),
-checkpoint_callback=False,
+enable_checkpointing=False,
enable_progress_bar=False,
)

@@ -122,7 +122,7 @@ def test_device_stats_monitor_no_logger(tmpdir):
callbacks=[device_stats],
max_epochs=1,
logger=False,
-checkpoint_callback=False,
+enable_checkpointing=False,
enable_progress_bar=False,
)

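For context, a minimal sketch (not part of the commit; the callback and flag names are taken from the tests above) of a Trainer configured with the non-deprecated flags these tests now use:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import DeviceStatsMonitor

# `enable_checkpointing` replaces the deprecated `checkpoint_callback` argument.
trainer = Trainer(
    max_epochs=1,
    callbacks=[DeviceStatsMonitor()],
    logger=False,
    enable_checkpointing=False,
    enable_progress_bar=False,
)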
7 changes: 4 additions & 3 deletions tests/callbacks/test_model_summary.py
@@ -32,8 +32,8 @@ def test_model_summary_callback_present_trainer():


def test_model_summary_callback_with_weights_summary_none():

-trainer = Trainer(weights_summary=None)
+with pytest.deprecated_call(match=r"weights_summary=None\)` is deprecated"):
+    trainer = Trainer(weights_summary=None)
assert not any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)

trainer = Trainer(enable_model_summary=False)
@@ -42,7 +42,8 @@ def test_model_summary_callback_with_weights_summary_none():
trainer = Trainer(enable_model_summary=False, weights_summary="full")
assert not any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)

-trainer = Trainer(enable_model_summary=True, weights_summary=None)
+with pytest.deprecated_call(match=r"weights_summary=None\)` is deprecated"):
+    trainer = Trainer(enable_model_summary=True, weights_summary=None)
assert not any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)


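The wrapping above is the usual pytest pattern for asserting that a deprecated argument still works but warns. A minimal standalone sketch (assuming a PyTorch Lightning version where `weights_summary` is deprecated in favour of `enable_model_summary`):

import pytest
from pytorch_lightning import Trainer

def test_weights_summary_none_warns():
    # `pytest.deprecated_call` fails the test unless a DeprecationWarning
    # (or PendingDeprecationWarning) matching the pattern is raised inside the block.
    with pytest.deprecated_call(match=r"weights_summary=None\)` is deprecated"):
        Trainer(weights_summary=None)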
2 changes: 1 addition & 1 deletion tests/models/test_restore.py
@@ -153,7 +153,7 @@ def configure_optimizers(self):
state_dict = torch.load(resume_ckpt)

trainer_args.update(
{"max_epochs": 3, "resume_from_checkpoint": resume_ckpt, "checkpoint_callback": False, "callbacks": []}
{"max_epochs": 3, "resume_from_checkpoint": resume_ckpt, "enable_checkpointing": False, "callbacks": []}
)

class CustomClassifModel(CustomClassifModel):
12 changes: 6 additions & 6 deletions tests/trainer/test_trainer.py
@@ -468,7 +468,7 @@ def test_trainer_max_steps_and_epochs(tmpdir):
"max_epochs": 3,
"max_steps": num_train_samples + 10,
"logger": False,
"weights_summary": None,
"enable_model_summary": False,
"enable_progress_bar": False,
}
trainer = Trainer(**trainer_kwargs)
@@ -555,7 +555,7 @@ def test_trainer_min_steps_and_epochs(tmpdir):
# define less min steps than 1 epoch
"min_steps": num_train_samples // 2,
"logger": False,
"weights_summary": None,
"enable_model_summary": False,
"enable_progress_bar": False,
}
trainer = Trainer(**trainer_kwargs)
@@ -723,9 +723,9 @@ def predict_step(self, batch, *_):
assert getattr(trainer, path_attr) == ckpt_path


@pytest.mark.parametrize("checkpoint_callback", (False, True))
@pytest.mark.parametrize("enable_checkpointing", (False, True))
@pytest.mark.parametrize("fn", ("validate", "test", "predict"))
-def test_tested_checkpoint_path_best(tmpdir, checkpoint_callback, fn):
+def test_tested_checkpoint_path_best(tmpdir, enable_checkpointing, fn):
class TestModel(BoringModel):
def validation_step(self, batch, batch_idx):
self.log("foo", -batch_idx)
@@ -746,15 +746,15 @@ def predict_step(self, batch, *_):
limit_predict_batches=1,
enable_progress_bar=False,
default_root_dir=tmpdir,
-checkpoint_callback=checkpoint_callback,
+enable_checkpointing=enable_checkpointing,
)
trainer.fit(model)

trainer_fn = getattr(trainer, fn)
path_attr = f"{fn}{'d' if fn == 'validate' else 'ed'}_ckpt_path"
assert getattr(trainer, path_attr) is None

-if checkpoint_callback:
+if enable_checkpointing:
trainer_fn(ckpt_path="best")
assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path

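A condensed sketch of the parametrized pattern used in `test_tested_checkpoint_path_best` above (simplified; assumes the suite's `tests.helpers.BoringModel` test model is available):

import pytest
from pytorch_lightning import Trainer
from tests.helpers import BoringModel

@pytest.mark.parametrize("enable_checkpointing", (False, True))
def test_best_ckpt_path(tmpdir, enable_checkpointing):
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=1,
        limit_val_batches=1,
        enable_progress_bar=False,
        enable_checkpointing=enable_checkpointing,
    )
    trainer.fit(BoringModel())
    if enable_checkpointing:
        # `ckpt_path="best"` resolves to the best checkpoint tracked by the
        # ModelCheckpoint callback, which only exists when checkpointing is enabled.
        trainer.validate(ckpt_path="best")
        assert trainer.validated_ckpt_path == trainer.checkpoint_callback.best_model_path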
2 changes: 1 addition & 1 deletion tests/utilities/test_cli.py
@@ -384,7 +384,7 @@ def test_lightning_cli_config_and_subclass_mode(tmpdir):
"fit": {
"model": {"class_path": "tests.helpers.BoringModel"},
"data": {"class_path": "tests.helpers.BoringDataModule", "init_args": {"data_dir": str(tmpdir)}},
"trainer": {"default_root_dir": str(tmpdir), "max_epochs": 1, "weights_summary": None},
"trainer": {"default_root_dir": str(tmpdir), "max_epochs": 1, "enable_model_summary": False},
}
}
config_path = tmpdir / "config.yaml"
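A minimal sketch of writing such a LightningCLI config file outside the test (the file name is illustrative; only the `enable_model_summary` key mirrors the change above):

import yaml

config = {
    "fit": {
        "model": {"class_path": "tests.helpers.BoringModel"},
        # `enable_model_summary` replaces the deprecated `weights_summary` key.
        "trainer": {"max_epochs": 1, "enable_model_summary": False},
    }
}
with open("config.yaml", "w") as f:
    yaml.dump(config, f)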
