Skip to content

Commit

Permalink
Skip reconciliate_processes if used within a cluster environment that…
Browse files Browse the repository at this point in the history
… creates processes externally (#9389)

* [RFC] Skip reconciliate_processes if used within a cluster environment that creates processes externally
  • Loading branch information
ananthsub authored and carmocca committed Sep 21, 2021
1 parent 9f95a92 commit 0a8625b
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 12 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


## [1.4.8] - 2021-09-21

- Fixed error reporting in DDP process reconciliation when processes are launched by an external agent ([#9389](https://github.com/PyTorchLightning/pytorch-lightning/pull/9389))
- Added `PL_RECONCILE_PROCESS` environment variable to enable process reconciliation regardless of cluster environment settings ([#9389](https://github.com/PyTorchLightning/pytorch-lightning/pull/9389))


## [1.4.7] - 2021-09-14

- Fixed logging of nan parameters ([#9364](https://github.com/PyTorchLightning/pytorch-lightning/pull/9364))
Expand Down Expand Up @@ -34,6 +40,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed signature of `Timer.on_train_epoch_end` and `StochasticWeightAveraging.on_train_epoch_end` to prevent unwanted deprecation warnings ([#9347](https://github.com/PyTorchLightning/pytorch-lightning/pull/9347))


- Fixed error reporting in DDP process reconciliation when processes are launched by an external agent ([#9389](https://github.com/PyTorchLightning/pytorch-lightning/pull/9389))


## [1.4.5] - 2021-08-31

- Fixed reduction using `self.log(sync_dict=True, reduce_fx={mean,max})` ([#9142](https://github.com/PyTorchLightning/pytorch-lightning/pull/9142))
Expand Down
29 changes: 22 additions & 7 deletions pytorch_lightning/plugins/training_type/ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def __init__(
self._ddp_comm_wrapper = ddp_comm_wrapper
self._pids: Optional[List[int]] = None
self._sync_dir: Optional[str] = None
self._rank_0_has_called_call_children_scripts: bool = False
self.set_world_ranks()

@property
Expand Down Expand Up @@ -235,6 +236,8 @@ def _call_children_scripts(self):
delay = np.random.uniform(1, 5, 1)[0]
sleep(delay)

self._rank_0_has_called_call_children_scripts = True

def setup_distributed(self):
reset_seed()

Expand Down Expand Up @@ -331,7 +334,9 @@ def init_ddp_connection(self, global_rank: Optional[int] = None, world_size: Opt

def pre_dispatch(self):
# share ddp pids to all processes
self._share_information_to_prevent_deadlock()
self._rank_0_has_called_call_children_scripts = self.broadcast(self._rank_0_has_called_call_children_scripts)
if self._should_run_deadlock_detection():
self._share_information_to_prevent_deadlock()

# move the model to the correct device
self.model_to_device()
Expand Down Expand Up @@ -405,7 +410,16 @@ def register_plugins(cls, plugin_registry: Dict) -> None:
find_unused_parameters=False,
)

def _share_information_to_prevent_deadlock(self):
def _should_run_deadlock_detection(self) -> bool:
    """Decide whether this plugin should reconcile processes itself when an error occurs.

    Reconciliation is forced on when the ``PL_RECONCILE_PROCESS`` environment variable equals ``"1"``
    (an unset variable is treated as ``"0"``), regardless of the cluster environment. Otherwise the
    result is the value of ``self._rank_0_has_called_call_children_scripts``; when that flag is falsy,
    process termination is left to the scheduler / parent process, external to Lightning.
    """
    forced_by_env = os.getenv("PL_RECONCILE_PROCESS", "0") == "1"
    if forced_by_env:
        return True
    # No explicit opt-in through the environment: fall back to the flag stored on the plugin.
    return self._rank_0_has_called_call_children_scripts

def _share_information_to_prevent_deadlock(self) -> None:
self._share_pids()

# there should be a unique sync_dir per nodes.
Expand All @@ -421,19 +435,20 @@ def _share_information_to_prevent_deadlock(self):

self._sync_dir = sync_dirs[self.node_rank]

def _share_pids(self):
"""
Make all DDP processes aware of all processes pids.
"""
def _share_pids(self) -> None:
"""Make all DDP processes aware of all processes pids."""
self.barrier()
pids = self.all_gather(torch.tensor(os.getpid(), device=self.root_device))
pids = pids.cpu().numpy().tolist()
self._pids = pids if isinstance(pids, list) else [pids]

def reconciliate_processes(self, trace: str):
def reconciliate_processes(self, trace: str) -> None:
if self.world_size < 2:
return

if not self._should_run_deadlock_detection():
return

sync_dir = self._sync_dir

if not sync_dir:
Expand Down
2 changes: 1 addition & 1 deletion tests/plugins/environments/torch_elastic_deadlock.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pytorch_lightning.utilities.exceptions import DeadlockDetectedException
from tests.helpers.boring_model import BoringModel

if os.getenv("PL_RUNNING_SPECIAL_TESTS", "0") == "1":
if os.getenv("PL_RUNNING_SPECIAL_TESTS", "0") == "1" and os.getenv("PL_RECONCILE_PROCESS", "0") == "1":

class CustomException(Exception):
pass
Expand Down
2 changes: 1 addition & 1 deletion tests/special_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ fi

# TODO: enable when CI uses torch>=1.9
# test deadlock is properly handled with TorchElastic.
# LOGS=$(PL_RUNNING_SPECIAL_TESTS=1 python -m torch.distributed.run --nproc_per_node=2 --max_restarts 0 -m coverage run --source pytorch_lightning -a tests/plugins/environments/torch_elastic_deadlock.py | grep "SUCCEEDED")
# LOGS=$(PL_RUNNING_SPECIAL_TESTS=1 PL_RECONCILE_PROCESS=1 python -m torch.distributed.run --nproc_per_node=2 --max_restarts 0 -m coverage run --source pytorch_lightning -a tests/plugins/environments/torch_elastic_deadlock.py | grep "SUCCEEDED")
# if [ -z "$LOGS" ]; then
# exit 1
# fi
Expand Down
7 changes: 4 additions & 3 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1818,13 +1818,14 @@ def test_exception_when_lightning_module_is_not_set_on_trainer():
trainer.predict()


class CustomException(Exception):
pass


@RunIf(min_gpus=2, special=True)
def test_ddp_terminate_when_deadlock_is_detected(tmpdir):
"""Test that DDP kills the remaining processes when only one rank is throwing an exception."""

class CustomException(Exception):
pass

class TestModel(BoringModel):
def training_step(self, batch, batch_idx):
if batch_idx == 1 and self.trainer.is_global_zero:
Expand Down

0 comments on commit 0a8625b

Please sign in to comment.