Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
DuYicong515 committed Mar 1, 2022
1 parent 4fc502a commit a5ce5a7
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
15 changes: 9 additions & 6 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
SimpleProfiler,
XLAProfiler,
)
from pytorch_lightning.strategies import ParallelStrategy, SingleDeviceStrategy, Strategy
from pytorch_lightning.strategies import ParallelStrategy, Strategy
from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
from pytorch_lightning.trainer.configuration_validator import verify_loop_configurations
Expand Down Expand Up @@ -2028,11 +2028,14 @@ def num_nodes(self) -> int:

@property
def device_ids(self) -> List[int]:
if isinstance(self.strategy, ParallelStrategy):
return [torch._utils._get_device_index(device, allow_cpu=True) for device in self.strategy.parallel_devices]
elif isinstance(self.strategy, SingleDeviceStrategy):
return [torch._utils._get_device_index(self.strategy.root_device, allow_cpu=True)]
return []
devices = getattr(self.strategy, "parallel_devices", [self.strategy.root_device])
device_ids = []
for idx, device in enumerate(devices):
if isinstance(device, torch.device):
device_ids.append(device.index or idx)
elif isinstance(device, int):
device_ids.append(device)
return device_ids

@property
def num_devices(self) -> int:
Expand Down
4 changes: 2 additions & 2 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2151,7 +2151,7 @@ def test_dataloaders_are_not_loaded_if_disabled_through_limit_batches(running_st
@pytest.mark.parametrize(
["trainer_kwargs", "expected_device_ids"],
[
({"strategy": None}, []),
({"strategy": None}, [0]),
({"num_processes": 1}, [0]),
({"gpus": 1}, [0]),
({"devices": 1}, [0]),
Expand All @@ -2166,4 +2166,4 @@ def test_trainer_config_device_ids(monkeypatch, trainer_kwargs, expected_device_
monkeypatch.setattr(torch.cuda, "is_available", lambda: True)
monkeypatch.setattr(torch.cuda, "device_count", lambda: 4)
trainer = Trainer(**trainer_kwargs)
trainer.num_devices = expected_device_ids
assert trainer.device_ids == expected_device_ids

0 comments on commit a5ce5a7

Please sign in to comment.