Update setup logic in training type plugins [1 / n] #9994

Merged · 21 commits · Oct 19, 2021
Changes from 2 commits
9 changes: 6 additions & 3 deletions pytorch_lightning/plugins/training_type/ddp.py
@@ -27,7 +27,9 @@
 import numpy as np
 import torch
 import torch.distributed
+from torch.nn import Module
 from torch.nn.parallel.distributed import DistributedDataParallel
+from torch.utils.data import DataLoader, DistributedSampler

 import pytorch_lightning as pl
 from pytorch_lightning.core.optimizer import LightningOptimizer
@@ -181,6 +183,9 @@ def setup_environment(self) -> None:

         self.setup_distributed()

+    def setup_model(self, model: Module) -> Module:
+        return DistributedDataParallel(module=model, device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs)
+
     def _call_children_scripts(self):
         # bookkeeping of spawned processes
         self._check_can_spawn_children()
@@ -355,9 +360,7 @@ def _reinit_optimizers_with_post_localSGD(self, warmup_steps: int):

     def configure_ddp(self) -> None:
         self.pre_configure_ddp()
-        self._model = DistributedDataParallel(
-            LightningDistributedModule(self.model), device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs
-        )
+        self._model = self.setup_model(LightningDistributedModule(self.model))
         self._register_ddp_hooks()

     def determine_ddp_device_ids(self):
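For reviewers trying the new hook out: a minimal sketch of how a subclass could now customize the wrapping by overriding only `setup_model` (the `CustomDDPPlugin` class below is illustrative, not part of this PR):

```python
import torch
from torch.nn import Module

from pytorch_lightning.plugins import DDPPlugin


class CustomDDPPlugin(DDPPlugin):
    """Hypothetical subclass: adjust the module right before the default DDP wrapping."""

    def setup_model(self, model: Module) -> Module:
        # Convert BatchNorm layers to SyncBatchNorm, then reuse the plugin's
        # default DistributedDataParallel wrapping added in this PR.
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        return super().setup_model(model)
```

Because `configure_ddp()` now routes through `setup_model()`, an override like this is the only change needed; previously the `DistributedDataParallel(...)` call would have had to be duplicated inside `configure_ddp()`.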
8 changes: 5 additions & 3 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -21,6 +21,7 @@
 import torch
 import torch.distributed
 import torch.multiprocessing as mp
+from torch.nn import Module
 from torch.nn.parallel.distributed import DistributedDataParallel

 import pytorch_lightning as pl
@@ -147,6 +148,9 @@ def setup(self) -> None:
         smp = mp.get_context("spawn")
         self.mp_queue = smp.SimpleQueue()

+    def setup_model(self, model: Module) -> Module:
+        return DistributedDataParallel(module=model, device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs)
+
     def set_world_ranks(self, process_idx: int = 0) -> None:
         self._local_rank = process_idx
         if self.cluster_environment is None:
@@ -256,9 +260,7 @@ def _register_ddp_hooks(self) -> None:

     def configure_ddp(self) -> None:
         self.pre_configure_ddp()
-        self._model = DistributedDataParallel(
-            LightningDistributedModule(self.model), device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs
-        )
+        self._model = self.setup_model(LightningDistributedModule(self.model))
         self._register_ddp_hooks()

     def determine_ddp_device_ids(self):
10 changes: 7 additions & 3 deletions pytorch_lightning/plugins/training_type/dp.py
@@ -11,10 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List, Optional
+from typing import List, Optional, Sequence

 import torch
-from torch.nn import DataParallel
+from torch.nn import DataParallel, Module
+from torch.optim import Optimizer

 from pytorch_lightning.overrides.data_parallel import LightningParallelModule
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
@@ -54,7 +55,10 @@ def world_size(self) -> int:
     def setup(self) -> None:
         # model needs to be moved to the device before it is wrapped
         self.model_to_device()
-        self._model = DataParallel(LightningParallelModule(self._model), self.parallel_devices)
+        self._model = self.setup_model(LightningParallelModule(self._model))
+
+    def setup_model(self, model: Module) -> Module:
+        return DataParallel(module=model, device_ids=self.parallel_devices)

     def reduce(self, collection: _METRIC_COLLECTION, *args, **kwargs) -> _METRIC_COLLECTION:
         """Reduces a collection of tensors from all processes. It can be applied to just a single tensor.
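For context on the DP side, this is roughly what the new `setup()` → `setup_model()` sequence reduces to at runtime; a standalone sketch assuming two visible GPUs (the `Linear` model is a placeholder, not part of the PR):

```python
import torch
from torch.nn import DataParallel, Linear

# plugin.parallel_devices is a list of torch.device objects
parallel_devices = [torch.device("cuda", 0), torch.device("cuda", 1)]

# model_to_device() moves the module to the root device before wrapping
model = Linear(32, 2).to(parallel_devices[0])

# equivalent of the new DataParallelPlugin.setup_model()
wrapped = DataParallel(module=model, device_ids=parallel_devices)
```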
pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -13,11 +13,12 @@
 # limitations under the License.
 import contextlib
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, Union
+from typing import Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional, Sequence, Tuple, Union

 import torch
 from torch import Tensor
 from torch.nn import Module
+from torch.optim import Optimizer
 from torch.utils.data import DataLoader

 import pytorch_lightning as pl
@@ -60,6 +61,19 @@ def setup_environment(self) -> None:
     def setup(self) -> None:
         """Called by the accelerator to finish setup."""

+    def setup_models_and_optimizers(
+        self, models: List[Module], optimizers: List[Optimizer]
+    ) -> Tuple[List[Module], List[Optimizer]]:
+        models = [self.setup_model(model) for model in models]
+        optimizers = [self.setup_optimizer(optimizer) for optimizer in optimizers]
+        return models, optimizers
+
+    def setup_model(self, model: Module) -> Module:
+        return model
+
+    def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
+        return optimizer
+
     @property
     @abstractmethod
     def on_gpu(self) -> bool:
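The new base-class contract in one self-contained sketch; `ToyPlugin` below is a stand-in for the base plugin, shown only to illustrate the default pass-through behaviour added in this diff:

```python
from typing import List, Tuple

from torch.nn import Linear, Module
from torch.optim import SGD, Optimizer


class ToyPlugin:
    """Mirrors the default hooks added to the base training type plugin in this PR."""

    def setup_models_and_optimizers(
        self, models: List[Module], optimizers: List[Optimizer]
    ) -> Tuple[List[Module], List[Optimizer]]:
        models = [self.setup_model(m) for m in models]
        optimizers = [self.setup_optimizer(o) for o in optimizers]
        return models, optimizers

    def setup_model(self, model: Module) -> Module:
        return model  # subclasses wrap the module here (DDP, DataParallel, ...)

    def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
        return optimizer  # subclasses may rebuild or wrap the optimizer here


model = Linear(32, 2)
[model], [optimizer] = ToyPlugin().setup_models_and_optimizers([model], [SGD(model.parameters(), lr=0.1)])
```

The plugins touched in this PR only override `setup_model()`; `setup_optimizer()` stays a pass-through here, presumably so later PRs in this "[1 / n]" series can let optimizer-wrapping plugins override it as well.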