Skip reconciliate_processes if used within a cluster environment that creates processes externally #9389

Merged
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -378,6 +378,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed collision of user argument when using ShardedDDP ([#9512](https://github.com/PyTorchLightning/pytorch-lightning/pull/9512))


+- Fixed error reporting in DDP process reconciliation when processes are launched by an external agent ([#9389](https://github.com/PyTorchLightning/pytorch-lightning/pull/9389))
+
+
 ## [1.4.5] - 2021-08-31

 - Fixed reduction using `self.log(sync_dict=True, reduce_fx={mean,max})` ([#9142](https://github.com/PyTorchLightning/pytorch-lightning/pull/9142))
17 changes: 16 additions & 1 deletion pytorch_lightning/plugins/training_type/ddp.py
@@ -126,6 +126,7 @@ def __init__(
         self._model_averaging_period = model_averaging_period
         self._pids: Optional[List[int]] = None
         self._sync_dir: Optional[str] = None
+        self._has_called_call_children_scripts: bool = False
         self.set_world_ranks()

     @property
@@ -252,6 +253,8 @@ def _call_children_scripts(self):
             delay = np.random.uniform(1, 5, 1)[0]
             sleep(delay)

+        self._has_called_call_children_scripts = True
+
     def setup_distributed(self):
         reset_seed()
@@ -454,7 +457,14 @@ def register_plugins(cls, plugin_registry: Dict) -> None:
             find_unused_parameters=False,
         )

-    def _share_information_to_prevent_deadlock(self):
+    def _share_information_to_prevent_deadlock(self) -> None:
+        self._has_called_call_children_scripts = self.broadcast(self._has_called_call_children_scripts)
+
+        # Skip setting up the debug info used for process reconciliation when the
+        # processes are managed by a scheduler or a parent process external to Lightning.
+        if not self._has_called_call_children_scripts:
+            return
+
         self._share_pids()

         # there should be a unique sync_dir per node.
@@ -481,6 +491,11 @@ def reconciliate_processes(self, trace: str):
         if self.world_size < 2:
             return

+        # If the cluster environment created the processes, let the scheduler or
+        # parent process handle their termination externally to Lightning.
+        if not self._has_called_call_children_scripts:
+            return
+
         sync_dir = self._sync_dir

         if not sync_dir:
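
For reviewers, the reason the early return is safe on every rank: `self.broadcast(...)` makes all ranks adopt rank 0's value of the flag, so they all take the same branch. A minimal standalone sketch of that pattern, assuming `torch.distributed.broadcast_object_list` as the underlying primitive (the script and the `reconciliation_enabled` helper are illustrative, not part of this PR):

```python
# guard_sketch.py -- launch with: python -m torch.distributed.run --nproc_per_node=2 guard_sketch.py
import torch.distributed as dist


def reconciliation_enabled(has_called_call_children_scripts: bool) -> bool:
    """Every rank adopts rank 0's flag, mirroring `self.broadcast` in DDPPlugin."""
    obj = [has_called_call_children_scripts]
    dist.broadcast_object_list(obj, src=0)
    return obj[0]


if __name__ == "__main__":
    dist.init_process_group("gloo")  # RANK / WORLD_SIZE come from the external launcher
    # When an external agent creates the processes, `_call_children_scripts` never
    # runs, the flag stays False on rank 0, and reconciliation is skipped everywhere.
    print(f"rank {dist.get_rank()}: reconcile = {reconciliation_enabled(False)}")
```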
37 changes: 0 additions & 37 deletions tests/plugins/environments/torch_elastic_deadlock.py

This file was deleted.

8 changes: 0 additions & 8 deletions tests/special_tests.sh
@@ -78,14 +78,6 @@ if [ $? -eq 0 ]; then
   report+="Ran\ttests/utilities/test_warnings.py\n"
 fi

-# TODO: enable when CI uses torch>=1.9
-# test deadlock is properly handled with TorchElastic.
-# LOGS=$(PL_RUNNING_SPECIAL_TESTS=1 python -m torch.distributed.run --nproc_per_node=2 --max_restarts 0 -m coverage run --source pytorch_lightning -a tests/plugins/environments/torch_elastic_deadlock.py | grep "SUCCEEDED")
-# if [ -z "$LOGS" ]; then
-#   exit 1
-# fi
-# report+="Ran\ttests/plugins/environments/torch_elastic_deadlock.py\n"
-
 # test that a user can manually launch individual processes
 args="--trainer.gpus 2 --trainer.accelerator ddp --trainer.max_epochs=1 --trainer.limit_train_batches=1 --trainer.limit_val_batches=1 --trainer.limit_test_batches=1"
 MASTER_ADDR="localhost" MASTER_PORT=1234 LOCAL_RANK=1 python pl_examples/basic_examples/simple_image_classifier.py ${args} &
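
The retained block above is exactly the scenario the new guard targets: each process is created outside Lightning, so `_call_children_scripts` never runs. A hedged sketch of such a manually launched training script (`train.py` and the `BoringModel` import are illustrative stand-ins, not the actual `simple_image_classifier.py`):

```python
# train.py -- started once per process by an external launcher, e.g.:
#   MASTER_ADDR=localhost MASTER_PORT=1234 LOCAL_RANK=0 python train.py &
#   MASTER_ADDR=localhost MASTER_PORT=1234 LOCAL_RANK=1 python train.py
from pytorch_lightning import Trainer
from tests.helpers.boring_model import BoringModel  # illustrative stand-in model

if __name__ == "__main__":
    # Because LOCAL_RANK is already set by the launcher, Lightning attaches to the
    # existing process instead of spawning children; with this PR, a crash in one
    # rank is left to the external launcher to clean up rather than reconciled here.
    trainer = Trainer(gpus=2, accelerator="ddp", max_epochs=1, limit_train_batches=1)
    trainer.fit(BoringModel())
```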
7 changes: 4 additions & 3 deletions tests/trainer/test_trainer.py
@@ -1823,13 +1823,14 @@ def test_exception_when_lightning_module_is_not_set_on_trainer():
         trainer.predict()


+class CustomException(Exception):
+    pass
+
+
 @RunIf(min_gpus=2, special=True)
 def test_ddp_terminate_when_deadlock_is_detected(tmpdir):
     """Test that DDP kills the remaining processes when only one rank is throwing an exception."""

-    class CustomException(Exception):
-        pass
-
     class TestModel(BoringModel):
         def training_step(self, batch, batch_idx):
             if batch_idx == 1 and self.trainer.is_global_zero:
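
For readers following along, the truncated test raises `CustomException` on global rank zero only; the expectation is that Lightning tears down the healthy rank instead of letting it hang in its next collective. A sketch of the assertion such a test ends with, assuming `DeadlockDetectedException` from `pytorch_lightning.utilities.exceptions` (hedged; the exact `Trainer` arguments in the real test may differ, and the helper name is hypothetical):

```python
import pytest

from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import DeadlockDetectedException


def check_deadlock_is_reconciled(tmpdir, model):
    # `model` would be an instance of the TestModel defined in the diff above.
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, gpus=2, accelerator="ddp")
    # Rank 0 raises CustomException; reconciliation converts the would-be hang on
    # the other rank into a DeadlockDetectedException carrying the original trace.
    with pytest.raises(DeadlockDetectedException, match="CustomException"):
        trainer.fit(model)
```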