Skip to content

Commit

Permalink
update test
Browse files Browse the repository at this point in the history
  • Loading branch information
DuYicong515 committed Mar 1, 2022
1 parent a5ce5a7 commit 05ca9c5
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2153,16 +2153,16 @@ def test_dataloaders_are_not_loaded_if_disabled_through_limit_batches(running_st
[
({"strategy": None}, [0]),
({"num_processes": 1}, [0]),
({"gpus": 1}, [0]),
({"devices": 1}, [0]),
({"accelerator": "gpu", "devices": 1}, [0]),
({"strategy": "ddp", "devices": 1}, [0]),
({"strategy": "ddp", "gpus": 2}, [0, 1]),
({"strategy": "ddp", "num_processes": 2}, [0, 1]),
({"strategy": "ddp", "gpus": [0, 2]}, [0, 2]),
({"strategy": "ddp", "accelerator": "gpu", "devices": 2}, [0, 1]),
({"strategy": "ddp", "devices": 2}, [0, 1]),
({"strategy": "ddp", "accelerator": "gpu", "devices": [0, 2]}, [0, 2]),
],
)
def test_trainer_config_device_ids(monkeypatch, trainer_kwargs, expected_device_ids):
if trainer_kwargs.get("gpus") is not None:
if trainer_kwargs.get("accelerator") == "gpu":
monkeypatch.setattr(torch.cuda, "is_available", lambda: True)
monkeypatch.setattr(torch.cuda, "device_count", lambda: 4)
trainer = Trainer(**trainer_kwargs)
Expand Down

0 comments on commit 05ca9c5

Please sign in to comment.