diff --git a/DeepSpeedExamples b/DeepSpeedExamples
index a79272cc8b8f..fa1d1a71c486 160000
--- a/DeepSpeedExamples
+++ b/DeepSpeedExamples
@@ -1 +1 @@
-Subproject commit a79272cc8b8f0c5b66c803e581a1355341eacb77
+Subproject commit fa1d1a71c48623db8a091d9cf636a5fe3b8f43c7
diff --git a/README.md b/README.md
index 0ca3a22ab674..a3c78bb16a36 100755
--- a/README.md
+++ b/README.md
@@ -155,8 +155,7 @@ all repos using our CLA.
 This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
 For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact
-[opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or
-comments.
+[opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
 
 # Publications
 1. Samyam Rajbhandari, Jeff Rasley, Olatunji Ruwase, Yuxiong He. (2019) ZeRO: Memory Optimization Towards Training A Trillion Parameter Models. [ArXiv:1910.02054](https://arxiv.org/abs/1910.02054)
diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py
index 10ceab52a8dd..4a56aafbc539 100755
--- a/deepspeed/runtime/config.py
+++ b/deepspeed/runtime/config.py
@@ -30,6 +30,24 @@
 ADAM_W_MODE_PARAM = "adam_w_mode"
 
 
+def get_pld_enabled(param_dict):
+    if PROGRESSIVE_LAYER_DROP in param_dict.keys():
+        return get_scalar_param(param_dict[PROGRESSIVE_LAYER_DROP],
+                                PLD_ENABLED,
+                                PLD_ENABLED_DEFAULT)
+    else:
+        return False
+
+
+def get_pld_params(param_dict):
+    if PROGRESSIVE_LAYER_DROP in param_dict.keys():
+        pld_params = copy.copy(param_dict[PROGRESSIVE_LAYER_DROP])
+        pld_params.pop(PLD_ENABLED)
+        return pld_params
+    else:
+        return False
+
+
 def get_amp_enabled(param_dict):
     if AMP in param_dict.keys():
         return get_scalar_param(param_dict[AMP], AMP_ENABLED, AMP_ENABLED_DEFAULT)
@@ -542,6 +560,9 @@ def _initialize_params(self, param_dict):
         self.sparse_attention = get_sparse_attention(param_dict)
         self.pipeline = get_pipeline_config(param_dict)
 
+        self.pld_enabled = get_pld_enabled(param_dict)
+        self.pld_params = get_pld_params(param_dict)
+
     def _batch_assertion(self):
 
         train_batch = self.train_batch_size
diff --git a/deepspeed/runtime/constants.py b/deepspeed/runtime/constants.py
index 0bedc10ddac4..a731865714fe 100755
--- a/deepspeed/runtime/constants.py
+++ b/deepspeed/runtime/constants.py
@@ -291,3 +291,16 @@
 # Tensorboard job name
 TENSORBOARD_JOB_NAME = "job_name"
 TENSORBOARD_JOB_NAME_DEFAULT = "DeepSpeedJobName"
+
+# Progressive Layer Drop (PLD)
+PROGRESSIVE_LAYER_DROP = "progressive_layer_drop"
+
+# PLD enable signal
+PLD_ENABLED = "enabled"
+PLD_ENABLED_DEFAULT = False
+
+PLD_THETA = "theta"
+PLD_THETA_DEFAULT = 1.0
+
+PLD_GAMMA = "gamma"
+PLD_GAMMA_DEFAULT = 0.001
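Note (not part of the diff): the `config.py` and `constants.py` changes above define the whole PLD configuration surface. A minimal sketch of a config fragment that exercises it, written as a Python dict in the same style as the unit tests below; the concrete `theta`/`gamma` values here are illustrative:

```python
# Sketch of a DeepSpeed config fragment enabling PLD (illustrative values).
# Key names come from constants.py above: PROGRESSIVE_LAYER_DROP, PLD_ENABLED,
# PLD_THETA, PLD_GAMMA. As written in this diff, the engine reads "theta" and
# "gamma" directly from this block, so both should be supplied when enabled.
pld_config_fragment = {
    "progressive_layer_drop": {
        "enabled": True,   # PLD_ENABLED_DEFAULT is False
        "theta": 0.5,      # keep ratio the schedule decays toward
        "gamma": 0.001     # decay rate of the schedule
    }
}
```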
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 0932ef9e4998..b1c7ade2d12a 100755
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -26,13 +26,14 @@
 from deepspeed.runtime.dataloader import DeepSpeedDataLoader
 from deepspeed.runtime.constants import \
     ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \
-    TORCH_DISTRIBUTED_DEFAULT_PORT
+    TORCH_DISTRIBUTED_DEFAULT_PORT, PLD_THETA, PLD_GAMMA
 from deepspeed.runtime.zero.constants import \
     ZERO_OPTIMIZATION_OPTIMIZER_STATES, ZERO_OPTIMIZATION_GRADIENTS
 from deepspeed.runtime.csr_tensor import CSRTensor
 import deepspeed.runtime.lr_schedules as lr_schedules
 from deepspeed.utils import logger, log_dist
 from deepspeed.utils.timer import ThroughputTimer, SynchronizedWallClockTimer
+from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop
 from .utils import ensure_directory_exists
@@ -127,6 +128,7 @@ def __init__(self,
         self.loaded_checkpoint_mp_world_size = None
         self.loaded_checkpoint_dp_world_size = None
         self.enable_backward_allreduce = True
+        self.progressive_layer_drop = None
 
         if dist_init_required is None:
             dist_init_required = not dist.is_initialized()
@@ -192,10 +194,13 @@ def __init__(self,
             self.save_zero_checkpoint = False
             self._configure_checkpointing(dist_init_required)
 
+        if self.pld_enabled():
+            self.progressive_layer_drop = self._configure_progressive_layer_drop()
+
         if self.global_rank == 0:
-            self._config.print('DeepSpeedLight configuration')
+            self._config.print('DeepSpeedEngine configuration')
             if self.dump_state():
-                print_configuration(self, 'DeepSpeedLight')
+                print_configuration(self, 'DeepSpeedEngine')
 
     def _mpi_check(self, args, dist_init_required):
         if hasattr(args, 'deepspeed_mpi') and args.deepspeed_mpi:
@@ -236,6 +241,18 @@ def _mpi_check(self, args, dist_init_required):
             assert dist.get_world_size() == world_size, "MPI world size {} does not match torch world size {}".format(
                 world_size, dist.get_world_size())
 
+    def pld_enabled(self):
+        return self._config.pld_enabled
+
+    def pld_params(self):
+        return self._config.pld_params
+
+    def pld_theta(self):
+        return self.pld_params()[PLD_THETA]
+
+    def pld_gamma(self):
+        return self.pld_params()[PLD_GAMMA]
+
     def tensorboard_enabled(self):
         return self._config.tensorboard_enabled
 
@@ -666,6 +683,11 @@ def _configure_zero_optimizer(self, optimizer):
 
         return optimizer
 
+    def _configure_progressive_layer_drop(self):
+        pld = ProgressiveLayerDrop(theta=self.pld_theta(), gamma=self.pld_gamma())
+
+        return pld
+
     def deepspeed_io(self,
                      dataset,
                      batch_size=None,
@@ -751,6 +773,9 @@ def forward(self, *inputs, **kwargs):
             **kwargs: variable length keyword arguments
         """
+        if self.module.training and self.progressive_layer_drop:
+            kwargs.update(self.progressive_layer_drop.get_state())
+
         if self.wall_clock_breakdown():
             self.timers('forward_microstep').start()
             self.timers('forward').start()
@@ -931,6 +956,9 @@ def step(self):
 
         # Update the model when we reach gradient accumulation boundaries
         if self.is_gradient_accumulation_boundary():
+            if self.progressive_layer_drop:
+                self.progressive_layer_drop.update_state(self.global_steps)
+
             self._take_model_step()
 
         self.tput_timer.stop(report_progress)
@@ -1024,6 +1052,12 @@ def get_mom(self):
         else:
             return self._get_optimizer_param('betas')
 
+    def get_pld_theta(self):
+        if self.progressive_layer_drop:
+            return self.progressive_layer_drop.get_theta()
+        else:
+            return None
+
     def _report_progress(self, step):
         lr = self.get_lr()
         mom = self.get_mom()
diff --git a/deepspeed/runtime/progressive_layer_drop.py b/deepspeed/runtime/progressive_layer_drop.py
new file mode 100755
index 000000000000..770978a940a0
--- /dev/null
+++ b/deepspeed/runtime/progressive_layer_drop.py
@@ -0,0 +1,33 @@
+import numpy as np
+from deepspeed.utils import log_dist
+
+
+class ProgressiveLayerDrop(object):
+    r""" Progressive Layer Dropping (PLD) for model training.
+        This implements the PLD technique for compressed model training
+        from this paper: https://arxiv.org/pdf/2010.13369.pdf
+    Args:
+        theta (float): a hyper-parameter that controls the trade-off between training time and robustness.
+                       The lower the theta value, the faster the training speed. Default value: 0.5.
+        gamma (float): a hyper-parameter that controls how fast the drop ratio increases. Default value: 0.001.
+    """
+    def __init__(self, theta=0.5, gamma=0.001):
+        super().__init__()
+
+        self.theta = theta
+        self.gamma = gamma
+        self.current_theta = 1.0
+        log_dist(f'Enabled progressive layer dropping (theta = {self.theta})', ranks=[0])
+
+    def get_state(self):
+        kwargs = {'progressive_layer_drop': True, 'pld_theta': self.get_theta()}
+        return kwargs
+
+    def get_theta(self):
+        return self.current_theta
+
+    def update_state(self, global_step):
+        def _prob(x, gamma, p):
+            return (1. - p) * np.exp(-gamma * x) + p
+
+        self.current_theta = _prob(global_step, self.gamma, self.theta)
diff --git a/docs/_posts/2020-10-28-progressive-layer-dropping-news.md b/docs/_posts/2020-10-28-progressive-layer-dropping-news.md
index baa7ab7acfe5..5659cf818987 100755
--- a/docs/_posts/2020-10-28-progressive-layer-dropping-news.md
+++ b/docs/_posts/2020-10-28-progressive-layer-dropping-news.md
@@ -7,7 +7,7 @@ new_post: true
 date: 2020-10-29 00:00:00
 ---
 
-We introduce a new technology called progressive layer dropping (PLD) to speedup the pre-training of Transformer-based networks through efficient and robust compressed training. The pre-training step of Transformer networks often suffer from unbearable overall computational expenses. We analyze the training dynamics and stability of Transformer networks and propose PLD to sparsely update Transformer blocks following a progressive dropping schedule, which smoothly increases the layer dropping rate for each mini-batch as training evolves along both the temporal and the model depth dimension. PLD is able to allow the pre-training to be **2.5X faster** to get similar accuracy on downstream tasks and allows the training to be **24% faster** when training the same number of samples, not at the cost of excessive hardware resources. 
+We introduce a new technology called progressive layer dropping (PLD) to speedup the pre-training of Transformer-based networks through efficient and robust compressed training. The pre-training step of Transformer networks often suffer from unbearable overall computational expenses. We analyze the training dynamics and stability of Transformer networks and propose PLD to sparsely update Transformer blocks following a progressive dropping schedule, which smoothly increases the layer dropping rate for each mini-batch as training evolves along both the temporal and the model depth dimension. PLD is able to allow the pre-training to be **2.5X faster** to get similar accuracy on downstream tasks and allows the training to be **24% faster** when training the same number of samples, not at the cost of excessive hardware resources.
 
 * For detailed technology deep dive, see our [technical report](https://arxiv.org/pdf/2010.13369.pdf).
 * For more information on how to use PLD, see our [Progressive layer dropping tutorial](https://www.deepspeed.ai/tutorials/progressive_layer_dropping/).
diff --git a/docs/_tutorials/progressive_layer_dropping.md b/docs/_tutorials/progressive_layer_dropping.md
index 2ddbb69c58ad..4958717f8d09 100755
--- a/docs/_tutorials/progressive_layer_dropping.md
+++ b/docs/_tutorials/progressive_layer_dropping.md
@@ -18,7 +18,7 @@ already been modified to use DeepSpeed. The `ds_train_bert_progressive_layer_dr
 bash ds_train_bert_progressive_layer_drop_bsz4k_seq128.sh
 ```
 
-Most of the flags in the above script should be familiar if you have stepped through the BERT pre-training [tutorial](/tutorials/bert-pretraining/). To enable training with PLD, one needs to enable PLD in both the client script and in the DeepSpeed engine. To enable PLD in the client script, one needs to add the following command line flag to enable progressive layer dropping on Transformer blocks. 
+Most of the flags in the above script should be familiar if you have stepped through the BERT pre-training [tutorial](/tutorials/bert-pretraining/). To enable training with PLD, one needs to enable PLD in both the client script and in the DeepSpeed engine. To enable PLD in the client script, one needs to add the following command line flag to enable progressive layer dropping on Transformer blocks.
 
 ```bash
 --progressive_layer_drop
 ```
diff --git a/docs/index.md b/docs/index.md
index 3c3713724edc..13d1ff89873b 100755
--- a/docs/index.md
+++ b/docs/index.md
@@ -12,9 +12,9 @@ efficient, and effective.
 
 Minimal Code Change
 
 DeepSpeed delivers extreme-scale model training for everyone, from data scientists training on massive supercomputers to those training on low-end clusters or even on a single GPU:
-* Extreme scale: Using current generation of GPU clusters with hundreds of devices, 3D parallelism of DeepSpeed can efficiently train deep learning models with trillions of parameters. 
+* Extreme scale: Using current generation of GPU clusters with hundreds of devices, 3D parallelism of DeepSpeed can efficiently train deep learning models with trillions of parameters.
 * Extremely memory efficient: With just a single GPU, ZeRO-Offload of DeepSpeed can train models with over 10B parameters, 10x bigger than the state of arts, democratizing multi-billion-parameter model training such that many deep learning scientists can explore bigger and better models.
-* Extremely long sequence length: Sparse attention of DeepSpeed powers an order-of-magnitude longer input sequence and obtains up to 6x faster execution comparing with dense transformers. 
+* Extremely long sequence length: Sparse attention of DeepSpeed powers an order-of-magnitude longer input sequence and obtains up to 6x faster execution comparing with dense transformers.
 * Extremely communication efficient: 3D parallelism improves communication efficiency allows users to train multi-billion-parameter models 2–7x faster on clusters with limited network bandwidth. 1-bit Adam reduces communication volume by up to 5x while achieving similar convergence efficiency to Adam, allowing for scaling to different types of GPU clusters and networks.
 Early adopters of DeepSpeed have already produced
diff --git a/tests/unit/simple_model.py b/tests/unit/simple_model.py
old mode 100644
new mode 100755
index a8a383a3aac9..0dc8711bddcd
--- a/tests/unit/simple_model.py
+++ b/tests/unit/simple_model.py
@@ -101,6 +101,17 @@ def step(self, closure=None):
         return loss
 
 
+class PLD_SimpleModel(SimpleModel):
+    def __init__(self, hidden_dim, empty_grad=False, rank=0):
+        super(PLD_SimpleModel, self).__init__(hidden_dim, empty_grad, rank)
+
+    def forward(self, x, y, **kwargs):
+        pld = kwargs.get('progressive_layer_drop', False)
+        theta = kwargs.get('pld_theta', 1.0)
+        hidden_dim = super(PLD_SimpleModel, self).forward(x, y)
+        return hidden_dim
+
+
 def random_dataloader(model, total_samples, hidden_dim, device, dtype=torch.half):
     batch_size = model.train_micro_batch_size_per_gpu()
     train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=dtype)
diff --git a/tests/unit/test_pld.py b/tests/unit/test_pld.py
new file mode 100755
index 000000000000..784aeff0338f
--- /dev/null
+++ b/tests/unit/test_pld.py
@@ -0,0 +1,117 @@
+import numpy as np
+import deepspeed
+import pytest
+from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop
+from common import distributed_test
+from simple_model import SimpleModel, PLD_SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict
+
+
+@pytest.mark.parametrize('theta', [0, 0.1, 0.9, 1.0])
+def test_pld_schedule(tmpdir, theta):
+    gamma = 0.001
+
+    pld_scheduler = ProgressiveLayerDrop(theta, gamma)
+    for i in range(10):
+        pld_scheduler.update_state(i)
+        expected_theta = (1. - theta) * np.exp(-gamma * i) + theta
+        actual_theta = pld_scheduler.get_theta()
+        assert expected_theta == actual_theta
+
+
+@pytest.mark.parametrize('theta', [0, 0.1, 0.9, 1.0])
+def test_pld_model(tmpdir, theta):
+    gamma = 0.001
+    config_dict = {
+        "train_batch_size": 1,
+        "steps_per_print": 1,
+        "optimizer": {
+            "type": 'Adam',
+            "params": {
+                "lr": 0.0001
+            }
+        },
+        "fp16": {
+            "enabled": True
+        },
+        "progressive_layer_drop": {
+            "enabled": True,
+            "theta": theta,
+            "gamma": gamma
+        }
+    }
+
+    args = args_from_dict(tmpdir, config_dict)
+    hidden_dim = 10
+
+    model = PLD_SimpleModel(hidden_dim, empty_grad=False)
+
+    @distributed_test(world_size=[1])
+    def _test_pld_model(args, model, hidden_dim, theta, gamma):
+        model, _, _, _ = deepspeed.initialize(args=args,
+                                              model=model,
+                                              model_parameters=model.parameters())
+
+        data_loader = random_dataloader(model=model,
+                                        total_samples=50,
+                                        hidden_dim=hidden_dim,
+                                        device=model.device)
+
+        for i, batch in enumerate(data_loader):
+            loss = model(batch[0], batch[1])
+            model.backward(loss)
+            model.step()
+
+            expected_theta = (1. - theta) * np.exp(-gamma * i) + theta
+            actual_theta = model.get_pld_theta()
+            assert expected_theta == actual_theta
+
+    _test_pld_model(args=args,
+                    model=model,
+                    hidden_dim=hidden_dim,
+                    theta=theta,
+                    gamma=gamma)
+
+
+def test_non_pld_model(tmpdir):
+    gamma = 0.001
+    theta = 0.5
+    config_dict = {
+        "train_batch_size": 1,
+        "steps_per_print": 1,
+        "optimizer": {
+            "type": 'Adam',
+            "params": {
+                "lr": 0.0001
+            }
+        },
+        "fp16": {
+            "enabled": True
+        },
+        "progressive_layer_drop": {
+            "enabled": True,
+            "theta": theta,
+            "gamma": gamma
+        }
+    }
+
+    args = args_from_dict(tmpdir, config_dict)
+    hidden_dim = 10
+
+    model = SimpleModel(hidden_dim, empty_grad=False)
+
+    @distributed_test(world_size=[1])
+    def _test_non_pld_model(args, model, hidden_dim):
+        model, _, _, _ = deepspeed.initialize(args=args,
+                                              model=model,
+                                              model_parameters=model.parameters())
+
+        data_loader = random_dataloader(model=model,
+                                        total_samples=1,
+                                        hidden_dim=hidden_dim,
+                                        device=model.device)
+
+        for i, batch in enumerate(data_loader):
+            with pytest.raises(TypeError):
+                loss = model(batch[0], batch[1])
+
+    _test_non_pld_model(args=args, model=model, hidden_dim=hidden_dim)
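Note (not part of the diff): the closed form asserted in `test_pld_schedule` and `test_pld_model` decays the keep ratio from 1.0 toward `theta`. A quick sanity check with the defaults of `ProgressiveLayerDrop` above (`theta=0.5`, `gamma=0.001`):

```python
import numpy as np

theta, gamma = 0.5, 0.001
for step in (0, 1000, 10000):
    # Same closed form the tests assert: (1 - theta) * exp(-gamma * step) + theta
    print(step, (1. - theta) * np.exp(-gamma * step) + theta)
# step 0     -> 1.0     (no dropping at the start of training)
# step 1000  -> ~0.684
# step 10000 -> ~0.500  (approaches theta)
```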