Update setup logic in training type plugins [1 / n] #9994

Merged (21 commits) on Oct 19, 2021
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -202,6 +202,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- LightningLite:
* Added `PrecisionPlugin.forward_context`, making it the default implementation for all `{train,val,test,predict}_step_context()` methods ([#9988](https://github.com/PyTorchLightning/pytorch-lightning/pull/9988))
* Added `DDPSpawnPlugin.spawn()` for spawning new processes of a given function ([#10018](https://github.com/PyTorchLightning/pytorch-lightning/pull/10018))
+  * Added `TrainingTypePlugin.{_setup_model, _setup_optimizer}` methods ([#9994](https://github.com/PyTorchLightning/pytorch-lightning/pull/9994))

### Changed

9 changes: 6 additions & 3 deletions pytorch_lightning/plugins/training_type/ddp.py
@@ -27,6 +27,7 @@
import numpy as np
import torch
import torch.distributed
+from torch.nn import Module
from torch.nn.parallel.distributed import DistributedDataParallel

import pytorch_lightning as pl
@@ -181,6 +182,10 @@ def setup_environment(self) -> None:

self.setup_distributed()

+    def _setup_model(self, model: Module) -> DistributedDataParallel:
+        """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
+        return DistributedDataParallel(module=model, device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs)
+
def _call_children_scripts(self):
# bookkeeping of spawned processes
self._check_can_spawn_children()
@@ -355,9 +360,7 @@ def _reinit_optimizers_with_post_localSGD(self, warmup_steps: int):

def configure_ddp(self) -> None:
self.pre_configure_ddp()
-        self._model = DistributedDataParallel(
-            LightningDistributedModule(self.model), device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs
-        )
+        self._model = self._setup_model(LightningDistributedModule(self.model))
self._register_ddp_hooks()

def determine_ddp_device_ids(self):
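For reference, the snippet below is a minimal, standalone sketch of the wrapping that the new `_setup_model` hook performs in `DDPPlugin`. It is not the plugin's own code: the single-process CPU `gloo` group, the `MASTER_ADDR`/`MASTER_PORT` values, and the `Linear` stand-in for `LightningDistributedModule(self.model)` are assumptions made so the example runs anywhere.

```python
# Minimal sketch only: mirrors the DistributedDataParallel wrapping done by the
# new `_setup_model` hook, using a single-process CPU process group.
import os

import torch.distributed as dist
from torch.nn import Linear
from torch.nn.parallel import DistributedDataParallel

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")  # assumption: local single-process group
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

module = Linear(4, 4)  # stand-in for LightningDistributedModule(self.model)
# The real hook also passes device_ids=self.determine_ddp_device_ids() and **self._ddp_kwargs;
# device_ids stays None for a CPU module.
ddp_module = DistributedDataParallel(module)
print(type(ddp_module).__name__)  # DistributedDataParallel

dist.destroy_process_group()
```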
9 changes: 6 additions & 3 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -21,6 +21,7 @@
import torch
import torch.distributed
import torch.multiprocessing as mp
+from torch.nn import Module
from torch.nn.parallel.distributed import DistributedDataParallel

import pytorch_lightning as pl
@@ -147,6 +148,10 @@ def setup(self) -> None:
smp = mp.get_context("spawn")
self.mp_queue = smp.SimpleQueue()

+    def _setup_model(self, model: Module) -> DistributedDataParallel:
+        """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
+        return DistributedDataParallel(module=model, device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs)
+
def set_world_ranks(self, process_idx: int = 0) -> None:
self._local_rank = process_idx
if self.cluster_environment is None:
@@ -263,9 +268,7 @@ def _register_ddp_hooks(self) -> None:

def configure_ddp(self) -> None:
self.pre_configure_ddp()
-        self._model = DistributedDataParallel(
-            LightningDistributedModule(self.model), device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs
-        )
+        self._model = self._setup_model(LightningDistributedModule(self.model))
self._register_ddp_hooks()

def determine_ddp_device_ids(self):
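The spawn plugin gains the same `_setup_model` hook. For context, the sketch below illustrates the spawn bookkeeping visible in `setup()` above: a `"spawn"` multiprocessing context whose `SimpleQueue` lets worker processes report results back to the parent. The `_worker` function and `nprocs=2` are illustrative assumptions, not Lightning code.

```python
# Illustrative sketch of the spawn-context + SimpleQueue pattern used above.
import torch.multiprocessing as mp


def _worker(rank: int, queue) -> None:
    # each spawned process reports its rank back through the shared queue
    queue.put(rank)


if __name__ == "__main__":
    smp = mp.get_context("spawn")
    mp_queue = smp.SimpleQueue()
    mp.spawn(_worker, args=(mp_queue,), nprocs=2)  # joins by default
    print(sorted(mp_queue.get() for _ in range(2)))  # [0, 1]
```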
pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -13,11 +13,12 @@
# limitations under the License.
import contextlib
from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, Union
+from typing import Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.nn import Module
+from torch.optim import Optimizer
from torch.utils.data import DataLoader

import pytorch_lightning as pl
@@ -60,6 +61,29 @@ def setup_environment(self) -> None:
def setup(self) -> None:
"""Called by the accelerator to finish setup."""

+    def _setup_models_and_optimizers(
+        self, models: List[Module], optimizers: List[Optimizer]
+    ) -> Tuple[List[Module], List[Optimizer]]:
+        """Setup multiple models and multiple optimizers together.
+
+        The returned objects are expected to be in the same order they were passed in. The default implementation will
+        call :meth:`_setup_model` and :meth:`_setup_optimizer` on the input lists.
+        """
+        # TODO (@awaelchli): standardize this across all plugins in Lightning and Lite. Related refactor: #7324
+        models = [self._setup_model(model) for model in models]
+        optimizers = [self._setup_optimizer(optimizer) for optimizer in optimizers]
+        return models, optimizers
+
+    def _setup_model(self, model: Module) -> Module:
+        """Performs setup for the model, e.g., by wrapping it by another class."""
+        # TODO (@awaelchli): standardize this across all plugins in Lightning and Lite. Related refactor: #7324
+        return model
+
+    def _setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
+        """Performs setup for the optimizer, e.g., by wrapping it by another class."""
+        # TODO (@awaelchli): standardize this across all plugins in Lightning and Lite. Related refactor: #7324
+        return optimizer
+
@property
@abstractmethod
def on_gpu(self) -> bool:
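To make the new base-class contract concrete, here is a small, self-contained sketch. `ToyPlugin` is a hypothetical stand-in that mirrors the default implementations added above: `_setup_models_and_optimizers` applies the per-object hooks and preserves the input order, while subclasses are expected to override `_setup_model` / `_setup_optimizer` (as `DDPPlugin` and `DDPSpawnPlugin` do in this PR).

```python
# Hypothetical stand-in mirroring the new TrainingTypePlugin defaults.
from typing import List, Tuple

from torch.nn import Linear, Module
from torch.optim import SGD, Optimizer


class ToyPlugin:
    def _setup_models_and_optimizers(
        self, models: List[Module], optimizers: List[Optimizer]
    ) -> Tuple[List[Module], List[Optimizer]]:
        # apply the per-object hooks; the returned lists keep the input order
        models = [self._setup_model(m) for m in models]
        optimizers = [self._setup_optimizer(o) for o in optimizers]
        return models, optimizers

    def _setup_model(self, model: Module) -> Module:
        return model  # a real plugin might wrap the model (e.g., in DistributedDataParallel)

    def _setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
        return optimizer  # a real plugin might wrap or re-create the optimizer


model = Linear(2, 2)
optimizer = SGD(model.parameters(), lr=0.1)
models, optimizers = ToyPlugin()._setup_models_and_optimizers([model], [optimizer])
assert models[0] is model and optimizers[0] is optimizer  # defaults are pass-through
```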
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -1528,7 +1528,7 @@ def local_rank(self) -> int:

@property
def node_rank(self) -> int:
-        # some training types define a local rank
+        # some training types define a node rank
return getattr(self.training_type_plugin, "node_rank", 0)

@property