diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7ec2c531a0ce9..b3bd21cbaca1b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -207,6 +207,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
     * Added `DDPSpawnPlugin.spawn()` for spawning new processes of a given function ([#10018](https://github.com/PyTorchLightning/pytorch-lightning/pull/10018))
     * Added `TrainingTypePlugin.{_setup_model, _setup_optimizer}` methods ([#9994](https://github.com/PyTorchLightning/pytorch-lightning/pull/9994))
     * Implemented `DataParallelPlugin._setup_model` ([#10010](https://github.com/PyTorchLightning/pytorch-lightning/pull/10010))
+    * Implemented `DeepSpeedPlugin._setup_models_and_optimizers` ([#10009](https://github.com/PyTorchLightning/pytorch-lightning/pull/10009))
 
 
 ### Changed
diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py
index e2e8c316f48d1..019fd41d5d1cc 100644
--- a/pytorch_lightning/plugins/training_type/deepspeed.py
+++ b/pytorch_lightning/plugins/training_type/deepspeed.py
@@ -22,7 +22,9 @@
 from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple, Union
 
 import torch
+from torch.nn import Module
 from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
 
 import pytorch_lightning as pl
 from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
@@ -377,6 +379,50 @@ def pre_dispatch(self):
         self.init_deepspeed()
         self.barrier()
 
+    def _setup_models_and_optimizers(
+        self, models: List[Module], optimizers: List[Optimizer]
+    ) -> Tuple[List[Module], List[Optimizer]]:
+        """Setup multiple models and multiple optimizers together.
+
+        Currently only one model paired with a single optimizer is supported.
+
+        Return:
+            A list with one model wrapped into a :class:`deepspeed.DeepSpeedEngine` and list with a single
+            deepspeed optimizer.
+        """
+        if not (len(models) == len(optimizers) == 1):
+            raise ValueError(
+                f"Currently only one model and one optimizer is supported with DeepSpeed."
+                f" Got {len(models)} models and {len(optimizers)} optimizers instead."
+            )
+
+        # train_micro_batch_size_per_gpu is used for throughput logging purposes
+        # normally we set this to the batch size, but it is not available here unless the user provides it
+        # as part of the config
+        self.config.setdefault("train_micro_batch_size_per_gpu", 1)
+        self._model, optimizer = self._setup_model_and_optimizer(models[0], optimizers[0])
+        self._set_deepspeed_activation_checkpointing()
+        return [self._model], [optimizer]
+
+    def _setup_model_and_optimizer(
+        self, model: Module, optimizer: Optimizer, lr_scheduler: Optional[_LRScheduler] = None
+    ):
+        """Initialize one model and one optimizer with an optional learning rate scheduler.
+
+        This calls :func:`deepspeed.initialize` internally.
+        """
+        model_parameters = filter(lambda p: p.requires_grad, model.parameters())
+        deepspeed_engine, deepspeed_optimizer, _, _ = deepspeed.initialize(
+            args=argparse.Namespace(device_rank=self.root_device.index),
+            config=self.config,
+            model=model,
+            model_parameters=model_parameters,  # type: ignore
+            optimizer=optimizer,
+            lr_scheduler=lr_scheduler,
+            dist_init_required=False,
+        )
+        return deepspeed_engine, deepspeed_optimizer
+
     def init_deepspeed(self):
         # check that `configure_gradient_clipping` hook isn't overriden since deepspeed handles
         # gradient clipping internally
@@ -441,18 +487,7 @@ def _initialize_deepspeed_train(self, model):
             optimizer, lr_scheduler, _ = self._init_optimizers()
 
         scheduler = lr_scheduler["scheduler"]
-
-        model_parameters = filter(lambda p: p.requires_grad, self.model.parameters())
-        model, deepspeed_optimizer, _, deepspeed_scheduler = deepspeed.initialize(
-            args=argparse.Namespace(device_rank=self.root_device.index),
-            config=self.config,
-            model=model,
-            model_parameters=model_parameters,
-            optimizer=optimizer,
-            lr_scheduler=scheduler,
-            dist_init_required=False,
-        )
-
+        model, deepspeed_optimizer = self._setup_model_and_optimizer(model, optimizer, scheduler)
         self._set_deepspeed_activation_checkpointing()
 
         # although we set these here, deepspeed manages the specific optimizer logic
@@ -568,6 +603,10 @@ def _format_config(self):
         self._format_precision_config()
 
     def _format_batch_size_and_grad_accum_config(self):
+        # todo: using lite, we do not support these variables within the config
+        if self.lightning_module is None:
+            return
+
         if "gradient_accumulation_steps" in self.config:
             raise MisconfigurationException(
                 "Do not set `gradient_accumulation_steps` in the DeepSpeed config"