ref: move specific accelerator code x/n (#3457)
* ref: organize args x/n

* ref: move specific accelerator code x/n

* ref: move specific accelerator code x/n

* ref: move specific accelerator code x/n
williamFalcon authored Sep 11, 2020
1 parent bd5f53c commit ef20310
Showing 8 changed files with 40 additions and 61 deletions.
3 changes: 3 additions & 0 deletions pytorch_lightning/accelerators/base_backend.py
@@ -162,3 +162,6 @@ def _clip_gradients(self, optimizer):
 
     def on_train_epoch_end(self):
         pass
+
+    def early_stopping_should_stop(self, pl_module):
+        return self.trainer.should_stop
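This default makes the hook trivially correct for single-process backends: the trainer's local flag already is the global answer. A minimal sketch of the hook-and-override dispatch this commit sets up, using simplified stand-ins for the real Trainer and Accelerator classes (the scaffolding below is illustrative, not the actual Lightning code):

class Trainer:
    def __init__(self):
        self.should_stop = False

class Accelerator:
    # simplified stand-in for pytorch_lightning.accelerators.base_backend.Accelerator
    def __init__(self, trainer):
        self.trainer = trainer

    def early_stopping_should_stop(self, pl_module):
        # single-process default: no cross-process consensus is needed
        return self.trainer.should_stop

trainer = Trainer()
backend = Accelerator(trainer)
trainer.should_stop = True
assert backend.early_stopping_should_stop(pl_module=None)

The distributed backends below override this method to reach an actual cross-process consensus.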
8 changes: 8 additions & 0 deletions pytorch_lightning/accelerators/ddp2_backend.py
@@ -23,6 +23,7 @@
 from pytorch_lightning.core.step_result import Result
 from pytorch_lightning.accelerators.base_backend import Accelerator
 import torch.distributed as torch_distrib
+import torch.distributed as dist
 
 try:
     from hydra.utils import to_absolute_path, get_original_cwd
@@ -199,3 +200,10 @@ def test_step_end(self, output):
 
     def barrier(self, name: str = None):
         torch_distrib.barrier()
+
+    def early_stopping_should_stop(self, pl_module):
+        stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device)
+        dist.all_reduce(stop, op=dist.reduce_op.SUM)
+        dist.barrier()
+        should_stop = stop == self.trainer.world_size
+        return should_stop
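Each rank contributes a 0/1 vote, the votes are summed across the world with all_reduce, and training stops only when the total equals world_size, i.e. the summed-vote test implements an all-ranks-agree rule. A self-contained sketch of the same consensus, runnable on CPU with the gloo backend; the script scaffolding and port choice are illustrative:

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # each rank votes 0 or 1; here only rank 0 wants to stop
    stop = torch.tensor(int(rank == 0))

    # sum the votes in place; afterwards every rank holds the same total
    dist.all_reduce(stop, op=dist.ReduceOp.SUM)
    dist.barrier()

    # unanimity rule: stop only if every rank voted to stop
    should_stop = stop.item() == world_size
    print(f"rank {rank}: votes={stop.item()}, stop={should_stop}")
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)

The sketch spells the reduction as dist.ReduceOp.SUM; the dist.reduce_op used in the diff is an older alias of the same enum.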
8 changes: 8 additions & 0 deletions pytorch_lightning/accelerators/ddp_backend.py
@@ -27,6 +27,7 @@
 from pytorch_lightning.utilities.distributed import rank_zero_only, find_free_network_port
 from pytorch_lightning.accelerators.base_backend import Accelerator
 import torch.distributed as torch_distrib
+import torch.distributed as dist
 
 try:
     from hydra.utils import to_absolute_path, get_original_cwd
@@ -278,3 +279,10 @@ def _check_can_spawn_children(self):
 
     def barrier(self, name: str = None):
         torch_distrib.barrier()
+
+    def early_stopping_should_stop(self, pl_module):
+        stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device)
+        dist.all_reduce(stop, op=dist.reduce_op.SUM)
+        dist.barrier()
+        should_stop = stop == self.trainer.world_size
+        return should_stop
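The same consensus block is duplicated verbatim here and in ddp2_backend.py and ddp_spawn_backend.py. One subtlety carried over from the callback: stop == self.trainer.world_size compares a zero-dimensional tensor with an int, so should_stop is returned as a tensor rather than a Python bool. That is harmless in boolean contexts, but the TPU variant below shows the more explicit int(stop.item()) == world_size form.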
8 changes: 8 additions & 0 deletions pytorch_lightning/accelerators/ddp_spawn_backend.py
@@ -21,6 +21,7 @@
 from pytorch_lightning.utilities.distributed import rank_zero_only, find_free_network_port
 from pytorch_lightning.accelerators.base_backend import Accelerator
 import torch.distributed as torch_distrib
+import torch.distributed as dist
 
 try:
     from apex import amp
@@ -195,3 +196,10 @@ def test_step(self, args):
 
     def barrier(self, name: str = None):
         torch_distrib.barrier()
+
+    def early_stopping_should_stop(self, pl_module):
+        stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device)
+        dist.all_reduce(stop, op=dist.reduce_op.SUM)
+        dist.barrier()
+        should_stop = stop == self.trainer.world_size
+        return should_stop
7 changes: 7 additions & 0 deletions pytorch_lightning/accelerators/tpu_backend.py
@@ -253,3 +253,10 @@ def clip_gradients(self, optimizer):
 
     def barrier(self, name: str = None):
         torch_xla.core.xla_model.rendezvous(f"pl.Trainer.{name}")
+
+    def early_stopping_should_stop(self, pl_module):
+        stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device, dtype=torch.int32)
+        stop = xm.mesh_reduce("stop_signal", stop, sum)
+        torch_xla.core.xla_model.rendezvous("pl.EarlyStoppingCallback.stop_distributed_training_check")
+        should_stop = int(stop.item()) == self.trainer.world_size
+        return should_stop
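TPU cores do not form a torch.distributed process group, so the vote is aggregated with xm.mesh_reduce, which gathers each core's copy of the tagged value and applies a plain Python reduction (here sum), while rendezvous stands in for the barrier. A sketch of the same logic in isolation, assuming a torch_xla environment; the helper name and rendezvous tag are hypothetical:

import torch
import torch_xla.core.xla_model as xm

def tpu_consensus_should_stop(local_should_stop: bool, world_size: int) -> bool:
    stop = torch.tensor(int(local_should_stop), device=xm.xla_device(), dtype=torch.int32)
    # gather every core's vote under the "stop_signal" tag and sum them
    stop = xm.mesh_reduce("stop_signal", stop, sum)
    # block until all cores reach this point
    xm.rendezvous("stop_check")
    return int(stop) == world_size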
19 changes: 2 additions & 17 deletions pytorch_lightning/callbacks/early_stopping.py
@@ -21,7 +21,6 @@
 """
 import numpy as np
 import torch
-import torch.distributed as dist
 
 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
@@ -216,22 +215,8 @@ def _run_early_stopping_check(self, trainer, pl_module):
             trainer.should_stop = True
 
         # stop every ddp process if any world process decides to stop
-        self._stop_distributed_training(trainer, pl_module)
-
-    def _stop_distributed_training(self, trainer, pl_module):
-
-        # in ddp make sure all processes stop when one is flagged
-        if trainer.use_ddp or trainer.use_ddp2:
-            stop = torch.tensor(int(trainer.should_stop), device=pl_module.device)
-            dist.all_reduce(stop, op=dist.reduce_op.SUM)
-            dist.barrier()
-            trainer.should_stop = stop == trainer.world_size
-
-        if trainer.use_tpu:
-            stop = torch.tensor(int(trainer.should_stop), device=pl_module.device, dtype=torch.int32)
-            stop = xm.mesh_reduce("stop_signal", stop, sum)
-            torch_xla.core.xla_model.rendezvous("pl.EarlyStoppingCallback.stop_distributed_training_check")
-            trainer.should_stop = int(stop.item()) == trainer.world_size
+        should_stop = trainer.accelerator_backend.early_stopping_should_stop(pl_module)
+        trainer.should_stop = should_stop
 
     def on_train_end(self, trainer, pl_module):
         if self.stopped_epoch > 0 and self.verbose > 0:
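The callback no longer branches on trainer.use_ddp / use_ddp2 / use_tpu; it asks whichever accelerator backend is active. The payoff is that supporting another backend means overriding one method instead of editing the callback. A hypothetical example of such an override (HorovodBackend and its hvd-based reduction are illustrative; this commit does not add one):

import horovod.torch as hvd
import torch
from pytorch_lightning.accelerators.base_backend import Accelerator

class HorovodBackend(Accelerator):
    def early_stopping_should_stop(self, pl_module):
        # same unanimity rule, expressed with Horovod's allreduce
        stop = torch.tensor(int(self.trainer.should_stop))
        total = hvd.allreduce(stop, op=hvd.Sum)
        return int(total.item()) == self.trainer.world_size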
32 changes: 4 additions & 28 deletions pytorch_lightning/trainer/training_io.py
@@ -117,28 +117,14 @@
 from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS
+from pytorch_lightning.accelerators.base_backend import Accelerator
 
 try:
-    import torch_xla
-    import torch_xla.core.xla_model as xm
-    import torch_xla.distributed.xla_multiprocessing as xmp
-except ImportError:
-    XLA_AVAILABLE = False
-else:
-    XLA_AVAILABLE = True
-
-try:
     from apex import amp
 except ImportError:
     amp = None
 
-try:
-    import horovod.torch as hvd
-except (ModuleNotFoundError, ImportError):
-    HOROVOD_AVAILABLE = False
-else:
-    HOROVOD_AVAILABLE = True
 
 try:
     from omegaconf import Container
 except ImportError:
@@ -171,6 +157,7 @@ class TrainerIOMixin(ABC):
     scaler: ...
     use_tpu: bool
     amp_backend: AMPType
+    accelerator_backend: Accelerator
 
     def get_model(self):
         is_dp_module = isinstance(self.model, (LightningDistributedDataParallel, LightningDataParallel))
@@ -202,19 +189,8 @@ def restore_weights(self, model: LightningModule):
         if self.resume_from_checkpoint is not None:
             self.restore(self.resume_from_checkpoint, on_gpu=self.on_gpu)
 
-        # wait for all models to restore weights
-        if self.use_ddp or self.use_ddp2:
-            # wait for all processes to catch up
-            torch_distrib.barrier()
-
-        # wait for all models to restore weights
-        if self.on_tpu and XLA_AVAILABLE:
-            # wait for all processes to catch up
-            torch_xla.core.xla_model.rendezvous("pl.TrainerIOMixin.restore_weights")
-
-        elif self.use_horovod:
-            # wait for all processes to catch up
-            hvd.join()
+        # wait for all to catch up
+        self.accelerator_backend.barrier('TrainerIOMixin.restore_weights')
 
         # clear cache after restore
         if self.on_gpu:
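restore_weights previously open-coded three synchronization mechanisms (torch_distrib.barrier, XLA rendezvous, hvd.join); it now makes one call against the accelerator interface. A sketch of how the overrides line up behind that single call, using stand-in classes (the real overrides live in the respective backend modules):

class Accelerator:
    def barrier(self, name: str = None):
        pass  # single-process default: nothing to wait for

class DDPLikeBackend(Accelerator):
    def barrier(self, name: str = None):
        import torch.distributed as torch_distrib
        torch_distrib.barrier()  # block until every rank arrives

class TPUBackend(Accelerator):
    def barrier(self, name: str = None):
        import torch_xla.core.xla_model as xm
        xm.rendezvous(f"pl.Trainer.{name}")  # named rendezvous across cores

# the trainer side then reduces to a single line:
#     self.accelerator_backend.barrier('TrainerIOMixin.restore_weights')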
16 changes: 0 additions & 16 deletions pytorch_lightning/trainer/training_loop.py
@@ -30,22 +30,6 @@
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.core.memory import ModelSummary
 
-try:
-    import torch_xla
-    import torch_xla.core.xla_model as xm
-    import torch_xla.distributed.xla_multiprocessing as xmp
-except ImportError:
-    XLA_AVAILABLE = False
-else:
-    XLA_AVAILABLE = True
-
-try:
-    import horovod.torch as hvd
-except (ModuleNotFoundError, ImportError):
-    HOROVOD_AVAILABLE = False
-else:
-    HOROVOD_AVAILABLE = True
-
 
 class TrainLoop:
 
