Fix root GPU property (#5908)
* Move root GPU to a property, remove the Horovod assignment since it is handled in the Horovod plugin, and ensure we mock correctly so the GPU accelerator is selected

* Add missing tests back
SeanNaren authored Feb 10, 2021
1 parent 8f3947b commit 50ecc4a
Showing 2 changed files with 10 additions and 3 deletions.
pytorch_lightning/accelerators/accelerator_connector.py (9 changes: 6 additions & 3 deletions)
@@ -113,7 +113,6 @@ def __init__(
             self.gpus = pick_multiple_gpus(gpus)
 
         self.parallel_device_ids = device_parser.parse_gpu_ids(self.gpus)
-        self.root_gpu = device_parser.determine_root_gpu_device(self.parallel_device_ids)
 
         self.set_distributed_mode()
         self.configure_slurm_ddp()
@@ -276,6 +275,10 @@ def parallel_devices(self):
             devices = [torch.device("cpu")] * self.num_processes
         return devices
 
+    @property
+    def root_gpu(self) -> int:
+        return self.accelerator.root_device.index
+
     @property
     def is_using_torchelastic(self):
         te_flags_passed = "WORLD_SIZE" in os.environ and ("GROUP_RANK" in os.environ or "NODE_RANK" in os.environ)
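
For context (not part of the diff): root_gpu is now computed lazily from the accelerator's root device instead of being stored eagerly in __init__. A minimal sketch of what the property returns, using a hypothetical stand-in accelerator (torch.device(...).index needs no CUDA hardware):

    import torch

    class FakeAccelerator:
        # hypothetical stand-in for the accelerator the connector builds
        root_device = torch.device("cuda", 1)

    class FakeConnector:
        accelerator = FakeAccelerator()

        @property
        def root_gpu(self) -> int:
            # same expression as the new property above
            return self.accelerator.root_device.index

    assert FakeConnector().root_gpu == 1  # cuda:1 -> ordinal 1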
@@ -375,7 +378,8 @@ def select_training_type_plugin(self):
         elif self.on_tpu:
             plugin = SingleTPUPlugin(self.tpu_id)
         else:
-            plugin = SingleDevicePlugin(device=torch.device(f"cuda:{self.root_gpu}" if self.on_gpu else "cpu"))
+            single_gpu_ordinal = device_parser.determine_root_gpu_device(self.parallel_device_ids)
+            plugin = SingleDevicePlugin(device=torch.device(f"cuda:{single_gpu_ordinal}" if self.on_gpu else "cpu"))
         return plugin
 
     def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> TrainingTypePlugin:
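
For context (not part of the diff): the single-device branch no longer reads self.root_gpu, which now requires a constructed accelerator, and instead resolves the ordinal directly from the parsed device ids. A hedged sketch of the helper's behavior, assuming this import path and the first-id semantics of determine_root_gpu_device:

    from pytorch_lightning.utilities import device_parser  # path assumed for this release

    # e.g. what Trainer(gpus=[2, 3]) parses to on a machine with >= 4 GPUs:
    ids = [2, 3]
    assert device_parser.determine_root_gpu_device(ids) == 2  # first id wins
    assert device_parser.determine_root_gpu_device(None) is None  # CPU case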
@@ -525,7 +529,6 @@ def _set_horovod_backend(self):
         if self.on_gpu:
             # Horovod assigns one local GPU per process
             self.parallel_device_ids = list(range(hvd.local_size()))
-            self.root_gpu = hvd.local_rank()
         else:
             self.num_processes = hvd.local_size()
 
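For context (not part of the diff): per the commit message, the per-process root GPU under Horovod is now the Horovod plugin's job rather than the connector's. A hedged sketch of the mapping the plugin is assumed to apply, using standard Horovod APIs:

    import horovod.torch as hvd

    hvd.init()
    # The connector still records one device id per local process:
    parallel_device_ids = list(range(hvd.local_size()))
    # Each process's root GPU ordinal equals its local rank, so the plugin
    # can recover it without the connector storing root_gpu:
    assert parallel_device_ids[hvd.local_rank()] == hvd.local_rank()
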
tests/models/test_gpu.py (4 changes: 4 additions & 0 deletions)
@@ -68,6 +68,10 @@ def mocked_device_count(monkeypatch):
     def device_count():
         return PRETEND_N_OF_GPUS
 
+    def is_available():
+        return True
+
+    monkeypatch.setattr(torch.cuda, 'is_available', is_available)
     monkeypatch.setattr(torch.cuda, 'device_count', device_count)
 
 
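For context (not part of the diff): without the is_available patch, the connector would see CUDA as unavailable and never select the GPU accelerator, so the new root_gpu property could not be exercised on CPU-only CI. A hypothetical test in the spirit of the ones this commit restores (assuming Trainer still exposes root_gpu):

    import pytest
    from pytorch_lightning import Trainer

    @pytest.mark.parametrize(['gpus', 'expected_root_gpu'], [(1, 0), ([0], 0), ([1], 1)])
    def test_root_gpu_with_mocked_gpus(mocked_device_count, gpus, expected_root_gpu):
        assert Trainer(gpus=gpus).root_gpu == expected_root_gpu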
