Improve code quality in AcceleratorConnector._configure_slurm_ddp (#…
awaelchli authored Nov 17, 2021
1 parent 0fa07da commit 1ff35ed
Showing 4 changed files with 36 additions and 40 deletions.
5 changes: 5 additions & 0 deletions pytorch_lightning/plugins/environments/slurm_environment.py
@@ -28,6 +28,11 @@ class SLURMEnvironment(ClusterEnvironment):
     def creates_processes_externally(self) -> bool:
         return True

+    @staticmethod
+    def detect() -> bool:
+        """Returns ``True`` if the current process was launched on a SLURM cluster."""
+        return "SLURM_NTASKS" in os.environ
+
     @property
     def main_address(self) -> str:
         # figure out the root node addr
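For context: the new `detect()` helper keys purely off the `SLURM_NTASKS` environment variable that SLURM exports to launched tasks, so callers can ask the environment class whether they are inside a SLURM job instead of poking at `os.environ` directly. A minimal usage sketch (standalone, not part of the commit; it assumes the class is re-exported from `pytorch_lightning.plugins.environments`, as the file path above suggests):

    from pytorch_lightning.plugins.environments import SLURMEnvironment

    if SLURMEnvironment.detect():
        # SLURM_NTASKS is present, so this process was started inside a SLURM allocation
        print("Launched by SLURM")
    else:
        print("Not a SLURM job")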
58 changes: 25 additions & 33 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -134,7 +134,6 @@ def __init__(
         self.precision = precision
         self.amp_type = amp_type.lower() if isinstance(amp_type, str) else None
         self.amp_level = amp_level
-        self._is_slurm_managing_tasks = False

         self._precision_plugin: Optional[PrecisionPlugin] = None
         self._training_type_plugin: Optional[TrainingTypePlugin] = None
@@ -167,7 +166,6 @@ def __init__(
         self.handle_given_plugins()
         self._set_distrib_type_if_training_type_plugin_passed()

-        self._configure_slurm_ddp()
         self._cluster_environment = self.select_cluster_environment()

         self.update_device_type_if_ipu_plugin()
@@ -703,15 +701,15 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:
                 cluster_environment=self.select_cluster_environment(), parallel_devices=self.parallel_devices
             )
         elif self.use_ddp:
-            use_slurm_ddp = self.use_ddp and self._is_slurm_managing_tasks
+            use_slurm_ddp = self.use_ddp and self._is_slurm_managing_tasks()
             use_torchelastic_ddp = self.use_ddp and TorchElasticEnvironment.is_using_torchelastic()
             use_kubeflow_ddp = self.use_ddp and KubeflowEnvironment.is_using_kubeflow()
             use_ddp_spawn = self._distrib_type == _StrategyType.DDP_SPAWN
             use_ddp_cpu_spawn = use_ddp_spawn and self.use_cpu
             use_tpu_spawn = self.use_tpu and self._distrib_type == _StrategyType.TPU_SPAWN
             use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and TorchElasticEnvironment.is_using_torchelastic()
             use_ddp_cpu_kubeflow = use_ddp_cpu_spawn and KubeflowEnvironment.is_using_kubeflow()
-            use_ddp_cpu_slurm = use_ddp_cpu_spawn and self._is_slurm_managing_tasks
+            use_ddp_cpu_slurm = use_ddp_cpu_spawn and self._is_slurm_managing_tasks()
             use_ddp_sharded = self._distrib_type == _StrategyType.DDP_SHARDED
             use_ddp_sharded_spawn = self._distrib_type == _StrategyType.DDP_SHARDED_SPAWN
             use_ddp_fully_sharded = self._distrib_type == _StrategyType.DDP_FULLY_SHARDED
@@ -807,8 +805,9 @@ def select_accelerator(self) -> Accelerator:
     def select_cluster_environment(self) -> ClusterEnvironment:
         if self._cluster_environment is not None:
             return self._cluster_environment
-        if self._is_slurm_managing_tasks:
+        if self._is_slurm_managing_tasks():
             env = SLURMEnvironment()
+            rank_zero_info("Multiprocessing is handled by SLURM.")
         elif TorchElasticEnvironment.is_using_torchelastic():
             env = TorchElasticEnvironment()
         elif KubeflowEnvironment.is_using_kubeflow():
@@ -990,34 +989,6 @@ def update_device_type_if_training_type_plugin_passed(self) -> None:
         elif self.has_gpu:
             self._device_type = DeviceType.GPU

-    def _configure_slurm_ddp(self):
-        # extract SLURM flag vars
-        # whenever we have the correct number of tasks, we let slurm manage processes
-        # otherwise we launch the required number of processes
-        if self.use_ddp or self.use_ddp2:
-            num_requested_gpus = self.num_gpus * self.num_nodes
-            num_slurm_tasks = 0
-            try:
-                num_slurm_tasks = int(os.environ["SLURM_NTASKS"])
-                self._is_slurm_managing_tasks = num_slurm_tasks == num_requested_gpus
-
-                # enable slurm cpu
-                if num_requested_gpus == 0:
-                    self._is_slurm_managing_tasks = num_slurm_tasks == self.num_processes
-
-                # in interactive mode we don't manage tasks
-                job_name = os.environ["SLURM_JOB_NAME"]
-                if job_name == "bash":
-                    self._is_slurm_managing_tasks = False
-
-            except Exception:
-                # likely not on slurm, so set the slurm managed flag to false
-                self._is_slurm_managing_tasks = False
-
-            # notify user the that slurm is managing tasks
-            if self._is_slurm_managing_tasks:
-                rank_zero_info("Multi-processing is handled by Slurm.")
-
     def _set_distrib_type_if_training_type_plugin_passed(self):
         # This is required as when `TrainingTypePlugin` instance is passed to either `strategy`
         # or `plugins` flag, `AcceleratorConnector.set_distributed_mode` is not required to be
@@ -1026,3 +997,24 @@ def _set_distrib_type_if_training_type_plugin_passed(self):
             return
         if self._training_type_plugin is not None:
             self._distrib_type = getattr(self._training_type_plugin, "distributed_backend", None)
+
+    def _is_slurm_managing_tasks(self) -> bool:
+        """Returns whether we let SLURM manage the processes or not.
+
+        Returns ``True`` if and only if these conditions match:
+
+            - A SLURM cluster is detected
+            - A distributed plugin is being used
+            - The process is not launching in interactive mode
+            - The number of tasks in SLURM matches the requested number of devices and nodes in the Trainer
+        """
+        if (
+            (not self.use_ddp and not self.use_ddp2)
+            or not SLURMEnvironment.detect()
+            or os.environ.get("SLURM_JOB_NAME") == "bash"  # in interactive mode we don't manage tasks
+        ):
+            return False
+
+        total_requested_devices = (self.num_gpus or self.num_processes) * self.num_nodes
+        num_slurm_tasks = int(os.environ["SLURM_NTASKS"], 0)
+        return num_slurm_tasks == total_requested_devices
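Taken together, the connector-side change replaces the stateful `_is_slurm_managing_tasks` attribute and the `_configure_slurm_ddp()` setter with a single pure method: SLURM manages the processes only when a DDP/DDP2 plugin is requested, a SLURM allocation is detected, the job is not an interactive `bash` session, and `SLURM_NTASKS` equals the number of requested devices times the number of nodes. A hedged, standalone sketch of that arithmetic with made-up Trainer values (the variable names mirror the method above, but nothing here is copied from the repository):

    import os

    # Hypothetical request: Trainer(gpus=4, num_nodes=2) -> expect one SLURM task per GPU
    num_gpus, num_processes, num_nodes = 4, 1, 2

    total_requested_devices = (num_gpus or num_processes) * num_nodes  # 4 * 2 = 8
    num_slurm_tasks = int(os.environ.get("SLURM_NTASKS", "0"))

    # True only for an allocation such as: srun --nodes=2 --ntasks-per-node=4 python train.py
    slurm_managing_tasks = num_slurm_tasks == total_requested_devices
    print(slurm_managing_tasks)

On a CPU-only run `num_gpus` is 0, so the `or` falls back to `num_processes`; this is the same behaviour the deleted "enable slurm cpu" branch expressed with a separate `if`.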
12 changes: 6 additions & 6 deletions tests/accelerators/test_accelerator_connector.py
@@ -103,7 +103,7 @@ def test_accelerator_choice_ddp_spawn(cuda_available_mock, device_count_mock):
 def test_accelerator_choice_ddp_slurm(set_device_mock, device_count_mock, setup_distributed_mock):
     class CB(Callback):
         def on_fit_start(self, trainer, pl_module):
-            assert trainer._accelerator_connector._is_slurm_managing_tasks
+            assert trainer._accelerator_connector._is_slurm_managing_tasks()
             assert isinstance(trainer.accelerator, GPUAccelerator)
             assert isinstance(trainer.training_type_plugin, DDPPlugin)
             assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
@@ -136,7 +136,7 @@ def on_fit_start(self, trainer, pl_module):
 def test_accelerator_choice_ddp2_slurm(set_device_mock, device_count_mock, setup_distributed_mock):
     class CB(Callback):
         def on_fit_start(self, trainer, pl_module):
-            assert trainer._accelerator_connector._is_slurm_managing_tasks
+            assert trainer._accelerator_connector._is_slurm_managing_tasks()
             assert isinstance(trainer.accelerator, GPUAccelerator)
             assert isinstance(trainer.training_type_plugin, DDP2Plugin)
             assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
@@ -323,7 +323,7 @@ def on_fit_start(self, trainer, pl_module):
 def test_accelerator_choice_ddp_cpu_slurm(device_count_mock, setup_distributed_mock):
     class CB(Callback):
         def on_fit_start(self, trainer, pl_module):
-            assert trainer._accelerator_connector._is_slurm_managing_tasks
+            assert trainer._accelerator_connector._is_slurm_managing_tasks()
             assert isinstance(trainer.accelerator, CPUAccelerator)
             assert isinstance(trainer.training_type_plugin, DDPPlugin)
             assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
@@ -791,7 +791,7 @@ def test_strategy_choice_ddp_spawn(cuda_available_mock, device_count_mock):
 def test_strategy_choice_ddp_slurm(setup_distributed_mock, strategy):
     class CB(Callback):
         def on_fit_start(self, trainer, pl_module):
-            assert trainer._accelerator_connector._is_slurm_managing_tasks
+            assert trainer._accelerator_connector._is_slurm_managing_tasks()
             assert isinstance(trainer.accelerator, GPUAccelerator)
             assert isinstance(trainer.training_type_plugin, DDPPlugin)
             assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
@@ -824,7 +824,7 @@ def on_fit_start(self, trainer, pl_module):
 def test_strategy_choice_ddp2_slurm(set_device_mock, device_count_mock, setup_distributed_mock, strategy):
     class CB(Callback):
         def on_fit_start(self, trainer, pl_module):
-            assert trainer._accelerator_connector._is_slurm_managing_tasks
+            assert trainer._accelerator_connector._is_slurm_managing_tasks()
             assert isinstance(trainer.accelerator, GPUAccelerator)
             assert isinstance(trainer.training_type_plugin, DDP2Plugin)
             assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
@@ -1008,7 +1008,7 @@ def on_fit_start(self, trainer, pl_module):
 def test_strategy_choice_ddp_cpu_slurm(device_count_mock, setup_distributed_mock, strategy):
     class CB(Callback):
         def on_fit_start(self, trainer, pl_module):
-            assert trainer._accelerator_connector._is_slurm_managing_tasks
+            assert trainer._accelerator_connector._is_slurm_managing_tasks()
             assert isinstance(trainer.accelerator, CPUAccelerator)
             assert isinstance(trainer.training_type_plugin, DDPPlugin)
             assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
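The test updates above only turn the attribute access into a method call; the decorators and fixtures that fake the cluster sit outside the visible hunks. For orientation, tests of this kind typically patch the SLURM environment variables so that `_is_slurm_managing_tasks()` sees a consistent allocation. A hedged sketch, with illustrative values that are not copied from the file:

    import os
    from unittest import mock

    from pytorch_lightning.plugins.environments import SLURMEnvironment


    @mock.patch.dict(
        os.environ,
        {
            "SLURM_NTASKS": "2",            # must equal requested devices * nodes
            "SLURM_JOB_NAME": "SOME_NAME",  # anything but "bash", i.e. not interactive
            "SLURM_NODEID": "0",
            "SLURM_LOCALID": "0",
            "SLURM_PROCID": "0",
        },
    )
    def test_detects_fake_slurm_allocation():
        assert SLURMEnvironment.detect()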
1 change: 0 additions & 1 deletion tests/models/test_restore.py
@@ -500,7 +500,6 @@ def test_dp_resume(tmpdir):

     # fit model
     trainer = Trainer(**trainer_options)
-    trainer._is_slurm_managing_tasks = True
     trainer.fit(model, datamodule=dm)

     # track epoch before saving. Increment since we finished the current epoch, don't want to rerun
