Skip to content

Commit

Permalink
Merge branch 'master' into chualan/fix-19658
Browse files Browse the repository at this point in the history
  • Loading branch information
lantiga authored Dec 11, 2024
2 parents 5be022e + 60289d7 commit c20c173
Showing 1 changed file with 16 additions and 5 deletions.
21 changes: 16 additions & 5 deletions tests/tests_pytorch/models/test_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def on_before_zero_grad(self, optimizer):

model = CurrentTestModel()

trainer = Trainer(default_root_dir=tmp_path, max_steps=max_steps, max_epochs=2)
trainer = Trainer(devices=1, default_root_dir=tmp_path, max_steps=max_steps, max_epochs=2)
assert model.on_before_zero_grad_called == 0
trainer.fit(model)
assert max_steps == model.on_before_zero_grad_called
Expand Down Expand Up @@ -406,7 +406,7 @@ def prepare_data(self): ...
@pytest.mark.parametrize(
"kwargs",
[
{},
{"devices": 1},
# these precision plugins modify the optimization flow, so testing them explicitly
pytest.param({"accelerator": "gpu", "devices": 1, "precision": "16-mixed"}, marks=RunIf(min_cuda_gpus=1)),
pytest.param(
Expand Down Expand Up @@ -528,6 +528,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmp_path):
# initial training to get a checkpoint
model = BoringModel()
trainer = Trainer(
devices=1,
default_root_dir=tmp_path,
max_epochs=1,
limit_train_batches=2,
Expand All @@ -543,6 +544,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmp_path):
callback = HookedCallback(called)
# already performed 1 step, resume and do 2 more
trainer = Trainer(
devices=1,
default_root_dir=tmp_path,
max_epochs=2,
limit_train_batches=2,
Expand Down Expand Up @@ -605,6 +607,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_steps(tmp_path):
# initial training to get a checkpoint
model = BoringModel()
trainer = Trainer(
devices=1,
default_root_dir=tmp_path,
max_steps=1,
limit_val_batches=0,
Expand All @@ -624,6 +627,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_steps(tmp_path):
train_batches = 2
steps_after_reload = 1 + train_batches
trainer = Trainer(
devices=1,
default_root_dir=tmp_path,
max_steps=steps_after_reload,
limit_val_batches=0,
Expand Down Expand Up @@ -690,6 +694,7 @@ def test_trainer_model_hook_system_eval(tmp_path, override_on_x_model_train, bat
assert is_overridden(f"on_{noun}_model_train", model) == override_on_x_model_train
callback = HookedCallback(called)
trainer = Trainer(
devices=1,
default_root_dir=tmp_path,
max_epochs=1,
limit_val_batches=batches,
Expand Down Expand Up @@ -731,7 +736,11 @@ def test_trainer_model_hook_system_predict(tmp_path):
callback = HookedCallback(called)
batches = 2
trainer = Trainer(
default_root_dir=tmp_path, limit_predict_batches=batches, enable_progress_bar=False, callbacks=[callback]
devices=1,
default_root_dir=tmp_path,
limit_predict_batches=batches,
enable_progress_bar=False,
callbacks=[callback],
)
trainer.predict(model)
expected = [
Expand Down Expand Up @@ -797,7 +806,7 @@ def predict_dataloader(self):

model = CustomBoringModel()

trainer = Trainer(default_root_dir=tmp_path, fast_dev_run=5)
trainer = Trainer(devices=1, default_root_dir=tmp_path, fast_dev_run=5)

trainer.fit(model)
trainer.test(model)
Expand All @@ -812,6 +821,7 @@ def test_trainer_datamodule_hook_system(tmp_path):
model = BoringModel()
batches = 2
trainer = Trainer(
devices=1,
default_root_dir=tmp_path,
max_epochs=1,
limit_train_batches=batches,
Expand Down Expand Up @@ -887,7 +897,7 @@ class CustomHookedModel(HookedModel):
assert is_overridden("configure_model", model) == override_configure_model

datamodule = CustomHookedDataModule(ldm_called)
trainer = Trainer()
trainer = Trainer(devices=1)
trainer.strategy.connect(model)
trainer._data_connector.attach_data(model, datamodule=datamodule)
ckpt_path = str(tmp_path / "file.ckpt")
Expand Down Expand Up @@ -960,6 +970,7 @@ def predict_step(self, *args, **kwargs):

model = MixedTrainModeModule()
trainer = Trainer(
devices=1,
default_root_dir=tmp_path,
max_epochs=1,
val_check_interval=1,
Expand Down

0 comments on commit c20c173

Please sign in to comment.