Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update setup logic in training type plugins (deepspeed) [2 / n] #10009

Merged
merged 10 commits into from
Oct 19, 2021
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Added `PrecisionPlugin.forward_context`, making it the default implementation for all `{train,val,test,predict}_step_context()` methods ([#9988](https://github.com/PyTorchLightning/pytorch-lightning/pull/9988))
* Added `DDPSpawnPlugin.spawn()` for spawning new processes of a given function ([#10018](https://github.com/PyTorchLightning/pytorch-lightning/pull/10018))
* Added `TrainingTypePlugin.{_setup_model, _setup_optimizer}` methods ([#9994](https://github.com/PyTorchLightning/pytorch-lightning/pull/9994))
* Implemented `setup_models_and_optimizers` for DeepSpeed ([#10009](https://github.com/PyTorchLightning/pytorch-lightning/pull/10009))

### Changed

Expand Down
61 changes: 49 additions & 12 deletions pytorch_lightning/plugins/training_type/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple, Union

import torch
from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

import pytorch_lightning as pl
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
Expand Down Expand Up @@ -377,6 +379,48 @@ def pre_dispatch(self):
self.init_deepspeed()
self.barrier()

def _setup_models_and_optimizers(
self, models: List[Module], optimizers: List[Optimizer]
) -> Tuple[List[Module], List[Optimizer]]:
"""Setup multiple models and multiple optimizers together.

Currently only one model paired with a single optimizer is supported.

Return:
A list with one model wrapped into a :class:`deepspeed.DeepSpeedEngine` and list with a single
deepspeed optimizer.
"""
if not (len(models) == len(optimizers) == 1):
raise ValueError(
f"Currently only one model and one optimizer is supported with DeepSpeed."
f" Got {len(models)} models and {len(optimizers)} optimizers instead."
)

self.config["train_micro_batch_size_per_gpu"] = 1
carmocca marked this conversation as resolved.
Show resolved Hide resolved
self._model, optimizer = self._setup_model_and_optimizer(models[0], optimizers[0])
self._set_deepspeed_activation_checkpointing()
return [self._model], [optimizer]

def _setup_model_and_optimizer(
self, model: Module, optimizer: Optimizer, lr_scheduler: Optional[_LRScheduler] = None
):
"""Initialize one model and one optimizer with an optional learning rate scheduler.

This calls
:func:`deepspeed.initialize` internally.
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
"""
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
deepspeed_engine, deepspeed_optimizer, _, _ = deepspeed.initialize(
args=argparse.Namespace(device_rank=self.root_device.index),
config=self.config,
model=model,
model_parameters=model_parameters, # type: ignore
optimizer=optimizer,
lr_scheduler=lr_scheduler,
dist_init_required=False,
)
return deepspeed_engine, deepspeed_optimizer

def init_deepspeed(self):
# check that `configure_gradient_clipping` hook isn't overriden since deepspeed handles
# gradient clipping internally
Expand Down Expand Up @@ -441,18 +485,7 @@ def _initialize_deepspeed_train(self, model):
optimizer, lr_scheduler, _ = self._init_optimizers()

scheduler = lr_scheduler["scheduler"]

model_parameters = filter(lambda p: p.requires_grad, self.model.parameters())
model, deepspeed_optimizer, _, deepspeed_scheduler = deepspeed.initialize(
args=argparse.Namespace(device_rank=self.root_device.index),
config=self.config,
model=model,
model_parameters=model_parameters,
optimizer=optimizer,
lr_scheduler=scheduler,
dist_init_required=False,
)

model, deepspeed_optimizer = self._setup_model_and_optimizer(model, optimizer, scheduler)
self._set_deepspeed_activation_checkpointing()

# although we set these here, deepspeed manages the specific optimizer logic
Expand Down Expand Up @@ -568,6 +601,10 @@ def _format_config(self):
self._format_precision_config()

def _format_batch_size_and_grad_accum_config(self):
# todo: using lite, we do not support these variables within the config
if self.lightning_module is None:
return
awaelchli marked this conversation as resolved.
Show resolved Hide resolved

if "gradient_accumulation_steps" in self.config:
raise MisconfigurationException(
"Do not set `gradient_accumulation_steps` in the DeepSpeed config"
Expand Down