internal distributed_backend
rohitgr7 committed Oct 19, 2021
1 parent e7c1e75 commit 328f9e5
Showing 1 changed file with 37 additions and 33 deletions.
70 changes: 37 additions & 33 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -112,7 +112,8 @@ def __init__(
         self._accelerator_type = None

         self.strategy = strategy.lower() if isinstance(strategy, str) else strategy
-        self.accelerator = accelerator
+        # TODO: Rename this to something else once all the distributed flags are moved to strategy
+        self.distributed_backend = accelerator

         self._init_deterministic(deterministic)

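Note (illustration, not part of the diff): this hunk is an internal rename only; the public `Trainer(accelerator=...)` argument is unchanged and is simply stored on the connector under the new attribute name. A minimal sketch, assuming a CPU-only machine:

    from pytorch_lightning import Trainer

    # The constructor argument keeps its name; only the connector's internal
    # attribute switches from `self.accelerator` to `self.distributed_backend`.
    trainer = Trainer(accelerator="cpu", devices=1)
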
@@ -202,7 +203,7 @@ def _init_deterministic(self, deterministic: bool) -> None:
             os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

     def select_accelerator_type(self) -> None:
-        if self.accelerator == "auto":
+        if self.distributed_backend == "auto":
             if self.has_tpu:
                 self._accelerator_type = DeviceType.TPU
             elif self.has_ipu:
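For context (illustration, not part of the commit): with `accelerator="auto"` the selection above falls through TPU, then IPU, then GPU, and finally CPU. Assuming a CUDA machine with no TPUs or IPUs:

    from pytorch_lightning import Trainer

    # On a GPU machine without TPUs/IPUs, "auto" resolves to the GPU accelerator,
    # roughly equivalent to passing accelerator="gpu" explicitly.
    trainer = Trainer(accelerator="auto", devices=1)
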
@@ -212,34 +213,34 @@ def select_accelerator_type(self) -> None:
             else:
                 self._set_devices_to_cpu_num_processes()
                 self._accelerator_type = DeviceType.CPU
-        elif self.accelerator == DeviceType.TPU:
+        elif self.distributed_backend == DeviceType.TPU:
             if not self.has_tpu:
                 msg = "TPUs are not available" if not _TPU_AVAILABLE else "you didn't pass `tpu_cores` to `Trainer`"
                 raise MisconfigurationException(f"You passed `accelerator='tpu'`, but {msg}.")
             self._accelerator_type = DeviceType.TPU
-        elif self.accelerator == DeviceType.IPU:
+        elif self.distributed_backend == DeviceType.IPU:
             if not self.has_ipu:
                 msg = "IPUs are not available" if not _IPU_AVAILABLE else "you didn't pass `ipus` to `Trainer`"
                 raise MisconfigurationException(f"You passed `accelerator='ipu'`, but {msg}.")
             self._accelerator_type = DeviceType.IPU
-        elif self.accelerator == DeviceType.GPU:
+        elif self.distributed_backend == DeviceType.GPU:
             if not self.has_gpu:
                 msg = "you didn't pass `gpus` to `Trainer`" if torch.cuda.is_available() else "GPUs are not available"
                 raise MisconfigurationException(f"You passed `accelerator='gpu'`, but {msg}.")
             self._accelerator_type = DeviceType.GPU
-        elif self.accelerator == DeviceType.CPU:
+        elif self.distributed_backend == DeviceType.CPU:
             self._set_devices_to_cpu_num_processes()
             self._accelerator_type = DeviceType.CPU

-        if self.accelerator in self.accelerator_types:
-            self.accelerator = None
+        if self.distributed_backend in self.accelerator_types:
+            self.distributed_backend = None

     def _validate_accelerator_and_devices(self) -> None:
-        if self.accelerator not in self.accelerator_types and self.devices is not None:
+        if self.distributed_backend not in self.accelerator_types and self.devices is not None:
             raise MisconfigurationException(
                 f"You passed `devices={self.devices}` but haven't specified"
                 " `accelerator=('auto'|'tpu'|'gpu'|'ipu'|'cpu')` for the devices mapping,"
-                f" got `accelerator={self.accelerator!r}`."
+                f" got `accelerator={self.distributed_backend!r}`."
             )

     def _validate_accelerator_type(self) -> None:
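The validation above ties the new `devices` flag to an explicit `accelerator` choice. A hedged sketch of how that surfaces at the `Trainer` API (not part of the commit; assumes a machine with at least two GPUs):

    from pytorch_lightning import Trainer
    from pytorch_lightning.utilities.exceptions import MisconfigurationException

    # `devices` only has a meaning once an accelerator type (or "auto") is given.
    Trainer(accelerator="gpu", devices=2)  # ok on a 2-GPU machine

    try:
        Trainer(devices=2)  # no accelerator -> nothing to map `devices` onto
    except MisconfigurationException as err:
        print(err)  # "You passed `devices=2` but haven't specified `accelerator=...`"
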
@@ -255,16 +256,16 @@ def _warn_if_devices_flag_ignored(self) -> None:
         if self.devices is None:
             return
         devices_warning = f"The flag `devices={self.devices}` will be ignored, as you have set"
-        if self.accelerator in ("auto", DeviceType.TPU):
+        if self.distributed_backend in ("auto", DeviceType.TPU):
             if self.tpu_cores is not None:
                 rank_zero_warn(f"{devices_warning} `tpu_cores={self.tpu_cores}`")
-        elif self.accelerator in ("auto", DeviceType.IPU):
+        elif self.distributed_backend in ("auto", DeviceType.IPU):
             if self.ipus is not None:
                 rank_zero_warn(f"{devices_warning} `ipus={self.ipus}`")
-        elif self.accelerator in ("auto", DeviceType.GPU):
+        elif self.distributed_backend in ("auto", DeviceType.GPU):
             if self.gpus is not None:
                 rank_zero_warn(f"{devices_warning} `gpus={self.gpus}`")
-        elif self.accelerator in ("auto", DeviceType.CPU):
+        elif self.distributed_backend in ("auto", DeviceType.CPU):
             if self.num_processes != 1:
                 rank_zero_warn(f"{devices_warning} `num_processes={self.num_processes}`")

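These branches only warn: when a legacy device-count flag such as `gpus` is already set, `devices` is dropped. For instance (illustrative, assuming a multi-GPU machine):

    from pytorch_lightning import Trainer

    # `gpus=2` wins; a warning like
    # "The flag `devices=4` will be ignored, as you have set `gpus=2`" is emitted.
    trainer = Trainer(accelerator="gpu", gpus=2, devices=4)
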
@@ -281,15 +282,15 @@ def _set_devices_if_none(self) -> None:
             self.devices = self.num_processes

     def _handle_accelerator_and_strategy(self) -> None:
-        if self.accelerator is not None and self.accelerator in list(DistributedType):
+        if self.distributed_backend is not None and self.distributed_backend in list(DistributedType):
             rank_zero_deprecation(
-                f"Passing `Trainer(accelerator={self.accelerator!r})` has been deprecated"
-                f" in v1.5 and will be removed in v1.7. Use `Trainer(strategy={self.accelerator!r})` instead."
+                f"Passing `Trainer(accelerator={self.distributed_backend!r})` has been deprecated"
+                f" in v1.5 and will be removed in v1.7. Use `Trainer(strategy={self.distributed_backend!r})` instead."
             )
             if self.strategy is not None:
                 raise MisconfigurationException(
                     f"You have passed `Trainer(strategy={self.strategy!r})` but have"
-                    f" also passed `Trainer(accelerator={self.accelerator!r})`."
+                    f" also passed `Trainer(accelerator={self.distributed_backend!r})`."
                     f" HINT: Use just `Trainer(strategy={self.strategy!r})` instead."
                 )

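This is the user-facing side of the rename: passing a distributed mode through `accelerator=` is deprecated in v1.5 in favour of `strategy=`, and combining the two is rejected. A sketch (illustrative, assuming a machine with at least two GPUs):

    from pytorch_lightning import Trainer

    # Deprecated since v1.5, to be removed in v1.7: distributed mode via `accelerator=`.
    trainer = Trainer(accelerator="ddp", gpus=2)  # emits the deprecation warning above

    # Preferred spelling:
    trainer = Trainer(strategy="ddp", accelerator="gpu", devices=2)

    # Passing both raises a MisconfigurationException, as enforced above:
    # Trainer(strategy="ddp", accelerator="ddp_spawn")
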
@@ -635,8 +636,11 @@ def select_precision_plugin(self) -> PrecisionPlugin:
             return ApexMixedPrecisionPlugin(self.amp_level)

     def select_training_type_plugin(self) -> TrainingTypePlugin:
-        if isinstance(self.accelerator, Accelerator) and self.accelerator.training_type_plugin is not None:
-            plugin = self.accelerator.training_type_plugin
+        if (
+            isinstance(self.distributed_backend, Accelerator)
+            and self.distributed_backend.training_type_plugin is not None
+        ):
+            plugin = self.distributed_backend.training_type_plugin
         elif self.use_ddp2:
             plugin = DDP2Plugin(parallel_devices=self.parallel_devices, cluster_environment=self.cluster_environment)
         elif self.use_ddp and self.use_deepspeed:
@@ -718,15 +722,15 @@ def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> TrainingTypePlugin:
         return training_type

     def select_accelerator(self) -> Accelerator:
-        if isinstance(self.accelerator, Accelerator):
+        if isinstance(self.distributed_backend, Accelerator):
             # custom accelerator from user
             if self._precision_plugin is not None or self._training_type_plugin is not None:
                 # plugins also specified by user
                 rank_zero_warn(
                     "Specified `Precision` and `TrainingType` plugins will be ignored,"
                     " since an `Accelerator` instance was provided."
                 )
-            return self.accelerator
+            return self.distributed_backend

         if self.use_gpu:
             acc_cls = GPUAccelerator
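An `Accelerator` instance can still be passed directly; it already carries its own plugins, so any separately configured `Precision`/`TrainingType` plugins are ignored with the warning above. A hedged sketch (the plugin constructors and their no-argument defaults below are assumptions about the v1.5-era plugin API, not taken from this diff):

    from pytorch_lightning import Trainer
    from pytorch_lightning.accelerators import GPUAccelerator
    from pytorch_lightning.plugins import DDPPlugin, PrecisionPlugin

    # A fully assembled Accelerator; select_accelerator() returns it as-is.
    accelerator = GPUAccelerator(
        precision_plugin=PrecisionPlugin(),
        training_type_plugin=DDPPlugin(),
    )
    trainer = Trainer(accelerator=accelerator, gpus=2)
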
@@ -766,32 +770,32 @@ def set_distributed_mode(self, strategy: Optional[str] = None):
             return

         if strategy is not None and strategy in TrainingTypePluginsRegistry:
-            self.accelerator = TrainingTypePluginsRegistry[strategy]["distributed_backend"]
+            self.distributed_backend = TrainingTypePluginsRegistry[strategy]["distributed_backend"]
         elif strategy is not None:
-            self.accelerator = strategy
+            self.distributed_backend = strategy

-        if isinstance(self.accelerator, Accelerator):
+        if isinstance(self.distributed_backend, Accelerator):
             return

         is_cpu_accelerator_type = self._accelerator_type and self._accelerator_type == DeviceType.CPU
-        _use_cpu = is_cpu_accelerator_type or self.accelerator and "cpu" in self.accelerator
+        _use_cpu = is_cpu_accelerator_type or self.distributed_backend and "cpu" in self.distributed_backend

-        if self.accelerator is None:
+        if self.distributed_backend is None:
             if self.has_horovodrun():
                 self._set_horovod_backend()
             elif self.num_gpus == 0 and self.num_nodes > 1:
                 self._distrib_type = DistributedType.DDP
             elif self.num_gpus == 0 and self.num_processes > 1:
-                self.accelerator = DistributedType.DDP_SPAWN
+                self.distributed_backend = DistributedType.DDP_SPAWN
             elif self.num_gpus > 1 and not _use_cpu:
                 rank_zero_warn(
                     "You requested multiple GPUs but did not specify a backend, e.g."
                     ' `Trainer(strategy="dp"|"ddp"|"ddp2")`. Setting `strategy="ddp_spawn"` for you.'
                 )
-                self.accelerator = DistributedType.DDP_SPAWN
+                self.distributed_backend = DistributedType.DDP_SPAWN

         # special case with DDP on CPUs
-        if self.accelerator == DistributedType.DDP_CPU:
+        if self.distributed_backend == DistributedType.DDP_CPU:
             if _TPU_AVAILABLE:
                 raise MisconfigurationException(
                     "`accelerator='ddp_cpu'` is not supported on TPU machines. "
@@ -816,8 +820,8 @@ def set_distributed_mode(self, strategy: Optional[str] = None):
             self._distrib_type = DistributedType.TPU_SPAWN
         elif self.has_ipu and not _use_cpu:
             self._device_type = DeviceType.IPU
-        elif self.accelerator and self._distrib_type is None:
-            self._distrib_type = DistributedType(self.accelerator)
+        elif self.distributed_backend and self._distrib_type is None:
+            self._distrib_type = DistributedType(self.distributed_backend)

         if self.num_gpus > 0 and not _use_cpu:
             self._device_type = DeviceType.GPU
@@ -850,7 +854,7 @@ def set_distributed_mode(self, strategy: Optional[str] = None):
             self.num_processes = self.num_nodes

         # Horovod is an extra case...
-        if self.accelerator == DistributedType.HOROVOD:
+        if self.distributed_backend == DistributedType.HOROVOD:
             self._set_horovod_backend()

         using_valid_distributed = self.use_ddp or self.use_ddp2
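Stepping back to the multi-GPU fallback in `set_distributed_mode` above (illustrative, assuming a 2-GPU machine): requesting several GPUs without naming a backend makes the connector warn and pick `ddp_spawn`.

    from pytorch_lightning import Trainer

    # No strategy given with multiple GPUs: warns
    # 'Setting `strategy="ddp_spawn"` for you.' and falls back to ddp_spawn.
    trainer = Trainer(gpus=2)

    # Equivalent, explicit spelling:
    trainer = Trainer(strategy="ddp_spawn", gpus=2)
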
