diff --git a/pytorch_lightning/accelerators/base_backend.py b/pytorch_lightning/accelerators/base_backend.py
index 0eb8f6b0be9b7..247a57ec230df 100644
--- a/pytorch_lightning/accelerators/base_backend.py
+++ b/pytorch_lightning/accelerators/base_backend.py
@@ -162,3 +162,6 @@ def _clip_gradients(self, optimizer):
 
     def on_train_epoch_end(self):
         pass
+
+    def early_stopping_should_stop(self, pl_module):
+        return self.trainer.should_stop
diff --git a/pytorch_lightning/accelerators/ddp2_backend.py b/pytorch_lightning/accelerators/ddp2_backend.py
index 8eabaed3b3c49..401844f32c60b 100644
--- a/pytorch_lightning/accelerators/ddp2_backend.py
+++ b/pytorch_lightning/accelerators/ddp2_backend.py
@@ -23,6 +23,7 @@
 from pytorch_lightning.core.step_result import Result
 from pytorch_lightning.accelerators.base_backend import Accelerator
 import torch.distributed as torch_distrib
+import torch.distributed as dist
 
 try:
     from hydra.utils import to_absolute_path, get_original_cwd
@@ -199,3 +200,10 @@ def test_step_end(self, output):
 
     def barrier(self, name: str = None):
         torch_distrib.barrier()
+
+    def early_stopping_should_stop(self, pl_module):
+        stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device)
+        dist.all_reduce(stop, op=dist.reduce_op.SUM)
+        dist.barrier()
+        should_stop = stop == self.trainer.world_size
+        return should_stop
diff --git a/pytorch_lightning/accelerators/ddp_backend.py b/pytorch_lightning/accelerators/ddp_backend.py
index 2064b24fcbcc4..677f8e31dcb8b 100644
--- a/pytorch_lightning/accelerators/ddp_backend.py
+++ b/pytorch_lightning/accelerators/ddp_backend.py
@@ -27,6 +27,7 @@
 from pytorch_lightning.utilities.distributed import rank_zero_only, find_free_network_port
 from pytorch_lightning.accelerators.base_backend import Accelerator
 import torch.distributed as torch_distrib
+import torch.distributed as dist
 
 try:
     from hydra.utils import to_absolute_path, get_original_cwd
@@ -278,3 +279,10 @@ def _check_can_spawn_children(self):
 
     def barrier(self, name: str = None):
         torch_distrib.barrier()
+
+    def early_stopping_should_stop(self, pl_module):
+        stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device)
+        dist.all_reduce(stop, op=dist.reduce_op.SUM)
+        dist.barrier()
+        should_stop = stop == self.trainer.world_size
+        return should_stop
diff --git a/pytorch_lightning/accelerators/ddp_spawn_backend.py b/pytorch_lightning/accelerators/ddp_spawn_backend.py
index c3bbc603422ef..1638a0b976f26 100644
--- a/pytorch_lightning/accelerators/ddp_spawn_backend.py
+++ b/pytorch_lightning/accelerators/ddp_spawn_backend.py
@@ -21,6 +21,7 @@
 from pytorch_lightning.utilities.distributed import rank_zero_only, find_free_network_port
 from pytorch_lightning.accelerators.base_backend import Accelerator
 import torch.distributed as torch_distrib
+import torch.distributed as dist
 
 try:
     from apex import amp
@@ -195,3 +196,10 @@ def test_step(self, args):
 
     def barrier(self, name: str = None):
         torch_distrib.barrier()
+
+    def early_stopping_should_stop(self, pl_module):
+        stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device)
+        dist.all_reduce(stop, op=dist.reduce_op.SUM)
+        dist.barrier()
+        should_stop = stop == self.trainer.world_size
+        return should_stop
diff --git a/pytorch_lightning/accelerators/tpu_backend.py b/pytorch_lightning/accelerators/tpu_backend.py
index 2be4be31bead1..2ac69d1d518d9 100644
--- a/pytorch_lightning/accelerators/tpu_backend.py
+++ b/pytorch_lightning/accelerators/tpu_backend.py
@@ -253,3 +253,10 @@ def clip_gradients(self, optimizer):
 
     def barrier(self, name: str = None):
         torch_xla.core.xla_model.rendezvous(f"pl.Trainer.{name}")
+
+    def early_stopping_should_stop(self, pl_module):
+        stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device, dtype=torch.int32)
+        stop = xm.mesh_reduce("stop_signal", stop, sum)
+        torch_xla.core.xla_model.rendezvous("pl.EarlyStoppingCallback.stop_distributed_training_check")
+        should_stop = int(stop.item()) == self.trainer.world_size
+        return should_stop
diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 9bb7f82676c34..eebebc3bc62ce 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -21,7 +21,6 @@
 """
 import numpy as np
 import torch
-import torch.distributed as dist
 
 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
@@ -216,22 +215,8 @@ def _run_early_stopping_check(self, trainer, pl_module):
             trainer.should_stop = True
 
         # stop every ddp process if any world process decides to stop
-        self._stop_distributed_training(trainer, pl_module)
-
-    def _stop_distributed_training(self, trainer, pl_module):
-
-        # in ddp make sure all processes stop when one is flagged
-        if trainer.use_ddp or trainer.use_ddp2:
-            stop = torch.tensor(int(trainer.should_stop), device=pl_module.device)
-            dist.all_reduce(stop, op=dist.reduce_op.SUM)
-            dist.barrier()
-            trainer.should_stop = stop == trainer.world_size
-
-        if trainer.use_tpu:
-            stop = torch.tensor(int(trainer.should_stop), device=pl_module.device, dtype=torch.int32)
-            stop = xm.mesh_reduce("stop_signal", stop, sum)
-            torch_xla.core.xla_model.rendezvous("pl.EarlyStoppingCallback.stop_distributed_training_check")
-            trainer.should_stop = int(stop.item()) == trainer.world_size
+        should_stop = trainer.accelerator_backend.early_stopping_should_stop(pl_module)
+        trainer.should_stop = should_stop
 
     def on_train_end(self, trainer, pl_module):
         if self.stopped_epoch > 0 and self.verbose > 0:
diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py
index 1bc530f5ec125..7d8ac83c92fc7 100644
--- a/pytorch_lightning/trainer/training_io.py
+++ b/pytorch_lightning/trainer/training_io.py
@@ -117,28 +117,14 @@
 from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS
+from pytorch_lightning.accelerators.base_backend import Accelerator
 
-try:
-    import torch_xla
-    import torch_xla.core.xla_model as xm
-    import torch_xla.distributed.xla_multiprocessing as xmp
-except ImportError:
-    XLA_AVAILABLE = False
-else:
-    XLA_AVAILABLE = True
 
 try:
     from apex import amp
 except ImportError:
     amp = None
 
-try:
-    import horovod.torch as hvd
-except (ModuleNotFoundError, ImportError):
-    HOROVOD_AVAILABLE = False
-else:
-    HOROVOD_AVAILABLE = True
-
 try:
     from omegaconf import Container
 except ImportError:
@@ -171,6 +157,7 @@ class TrainerIOMixin(ABC):
     scaler: ...
     use_tpu: bool
     amp_backend: AMPType
+    accelerator_backend: Accelerator
 
     def get_model(self):
         is_dp_module = isinstance(self.model, (LightningDistributedDataParallel, LightningDataParallel))
@@ -202,19 +189,8 @@ def restore_weights(self, model: LightningModule):
         if self.resume_from_checkpoint is not None:
             self.restore(self.resume_from_checkpoint, on_gpu=self.on_gpu)
 
-        # wait for all models to restore weights
-        if self.use_ddp or self.use_ddp2:
-            # wait for all processes to catch up
-            torch_distrib.barrier()
-
-        # wait for all models to restore weights
-        if self.on_tpu and XLA_AVAILABLE:
-            # wait for all processes to catch up
-            torch_xla.core.xla_model.rendezvous("pl.TrainerIOMixin.restore_weights")
-
-        elif self.use_horovod:
-            # wait for all processes to catch up
-            hvd.join()
+        # wait for all to catch up
+        self.accelerator_backend.barrier('TrainerIOMixin.restore_weights')
 
         # clear cache after restore
         if self.on_gpu:
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index b663a6636ab24..eea4609fca0db 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -30,22 +30,6 @@
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.core.memory import ModelSummary
 
-try:
-    import torch_xla
-    import torch_xla.core.xla_model as xm
-    import torch_xla.distributed.xla_multiprocessing as xmp
-except ImportError:
-    XLA_AVAILABLE = False
-else:
-    XLA_AVAILABLE = True
-
-try:
-    import horovod.torch as hvd
-except (ModuleNotFoundError, ImportError):
-    HOROVOD_AVAILABLE = False
-else:
-    HOROVOD_AVAILABLE = True
-
 
 
 class TrainLoop:
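Reviewer note, not part of the diff: the pattern above moves the distributed stop decision out of `EarlyStopping` and behind two accelerator hooks, `barrier` and `early_stopping_should_stop`. The sketch below shows what a DDP-style backend has to provide for those hooks; the class name and everything omitted from it are hypothetical, and it uses `dist.ReduceOp.SUM` (the current spelling) where the diff keeps the older `dist.reduce_op.SUM` alias.

```python
import torch
import torch.distributed as dist

from pytorch_lightning.accelerators.base_backend import Accelerator


class SketchDDPAccelerator(Accelerator):
    """Illustrative subclass; only the two hooks touched by this PR are shown."""

    def barrier(self, name: str = None):
        # block until every process has reached this point
        dist.barrier()

    def early_stopping_should_stop(self, pl_module):
        # each rank contributes its local flag; the all-reduced sum equals
        # world_size only when every rank has decided to stop
        stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device)
        dist.all_reduce(stop, op=dist.ReduceOp.SUM)
        dist.barrier()
        return stop.item() == self.trainer.world_size
```

The base `Accelerator.early_stopping_should_stop` simply returns `self.trainer.should_stop`, so single-process backends keep the previous behaviour, while the DDP/DDP2/spawn and TPU overrides reproduce the consensus logic that used to live in `EarlyStopping._stop_distributed_training`.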