PLD release (microsoft#513)
* Progressive layer dropping docs (microsoft#499)

* test

* Adding tutorial and news page for pld

* updating the tutorial and posts of PLD

* update the finetune tutorial

* Update PLD tutorial (microsoft#512)

* Update installation instructions

* Format fix

* ZeRO tutorial

* Format fixes

* ZeRO-Offload

* ZeRO and ZeRO-Offload tutorials

* Update navigation page

* Format fixes

* Add yuxhe feedback

* Fix blog post link

* Fix OneBit-Adam link
Tweak scheduler example

* Fix date link

* Add DeepSpeed_Adam

* Add PLD tutorial to navigation

Co-authored-by: Shaden Smith <[email protected]>
Co-authored-by: Jeff Rasley <[email protected]>

* updating the pld docs

* DeepSpeed implementation of PLD (microsoft#508)

* DeepSpeed implementation of PLD

* Format fixes

* Formatting fixes

* Fix broken url

* Address PR feedback

* Bump DSE

Co-authored-by: Minjia Zhang <[email protected]>
Co-authored-by: Shaden Smith <[email protected]>
Co-authored-by: Jeff Rasley <[email protected]>
Co-authored-by: Minjia Zhang <[email protected]>
5 people authored Nov 10, 2020
1 parent e082d47 commit be1147c
Showing 11 changed files with 238 additions and 10 deletions.
3 changes: 1 addition & 2 deletions README.md
@@ -155,8 +155,7 @@ all repos using our CLA.
This project has adopted the [Microsoft Open Source Code of
Conduct](https://opensource.microsoft.com/codeofconduct/). For more information see the
[Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact
[[email protected]](mailto:[email protected]) with any additional questions or
comments.
[[email protected]](mailto:[email protected]) with any additional questions or comments.

# Publications
1. Samyam Rajbhandari, Jeff Rasley, Olatunji Ruwase, Yuxiong He. (2019) ZeRO: Memory Optimization Towards Training A Trillion Parameter Models. [ArXiv:1910.02054](https://arxiv.org/abs/1910.02054)
21 changes: 21 additions & 0 deletions deepspeed/runtime/config.py
@@ -30,6 +30,24 @@
ADAM_W_MODE_PARAM = "adam_w_mode"


def get_pld_enabled(param_dict):
if PROGRESSIVE_LAYER_DROP in param_dict.keys():
return get_scalar_param(param_dict[PROGRESSIVE_LAYER_DROP],
PLD_ENABLED,
PLD_ENABLED_DEFAULT)
else:
return False


def get_pld_params(param_dict):
if PROGRESSIVE_LAYER_DROP in param_dict.keys():
pld_params = copy.copy(param_dict[PROGRESSIVE_LAYER_DROP])
pld_params.pop(PLD_ENABLED)
return pld_params
else:
return False


def get_amp_enabled(param_dict):
if AMP in param_dict.keys():
return get_scalar_param(param_dict[AMP], AMP_ENABLED, AMP_ENABLED_DEFAULT)
@@ -542,6 +560,9 @@ def _initialize_params(self, param_dict):
self.sparse_attention = get_sparse_attention(param_dict)
self.pipeline = get_pipeline_config(param_dict)

self.pld_enabled = get_pld_enabled(param_dict)
self.pld_params = get_pld_params(param_dict)

def _batch_assertion(self):

train_batch = self.train_batch_size
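
As a quick illustration of the new config helpers (a sketch, not part of the diff, assuming a DeepSpeed installation that includes them), this is what they return for a sample `param_dict`:

```python
from deepspeed.runtime.config import get_pld_enabled, get_pld_params

sample = {
    "progressive_layer_drop": {
        "enabled": True,
        "theta": 0.5,
        "gamma": 0.001
    }
}

print(get_pld_enabled(sample))  # True
print(get_pld_params(sample))   # {'theta': 0.5, 'gamma': 0.001}; 'enabled' is popped
print(get_pld_enabled({}))      # False: PLD stays off unless the section is present
```
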
13 changes: 13 additions & 0 deletions deepspeed/runtime/constants.py
@@ -291,3 +291,16 @@
# Tensorboard job name
TENSORBOARD_JOB_NAME = "job_name"
TENSORBOARD_JOB_NAME_DEFAULT = "DeepSpeedJobName"

# Progressive Layer Drop (PLD)
PROGRESSIVE_LAYER_DROP = "progressive_layer_drop"

# PLD enable signal
PLD_ENABLED = "enabled"
PLD_ENABLED_DEFAULT = False

PLD_THETA = "theta"
PLD_THETA_DEFAULT = 1.0

PLD_GAMMA = "gamma"
PLD_GAMMA_DEFAULT = 0.001
40 changes: 37 additions & 3 deletions deepspeed/runtime/engine.py
@@ -26,13 +26,14 @@
from deepspeed.runtime.dataloader import DeepSpeedDataLoader
from deepspeed.runtime.constants import \
ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \
TORCH_DISTRIBUTED_DEFAULT_PORT
TORCH_DISTRIBUTED_DEFAULT_PORT, PLD_THETA, PLD_GAMMA
from deepspeed.runtime.zero.constants import \
ZERO_OPTIMIZATION_OPTIMIZER_STATES, ZERO_OPTIMIZATION_GRADIENTS
from deepspeed.runtime.csr_tensor import CSRTensor
import deepspeed.runtime.lr_schedules as lr_schedules
from deepspeed.utils import logger, log_dist
from deepspeed.utils.timer import ThroughputTimer, SynchronizedWallClockTimer
from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop

from .utils import ensure_directory_exists

@@ -127,6 +128,7 @@ def __init__(self,
self.loaded_checkpoint_mp_world_size = None
self.loaded_checkpoint_dp_world_size = None
self.enable_backward_allreduce = True
self.progressive_layer_drop = None

if dist_init_required is None:
dist_init_required = not dist.is_initialized()
@@ -192,10 +194,13 @@ def __init__(self,
self.save_zero_checkpoint = False
self._configure_checkpointing(dist_init_required)

if self.pld_enabled():
self.progressive_layer_drop = self._configure_progressive_layer_drop()

if self.global_rank == 0:
self._config.print('DeepSpeedLight configuration')
self._config.print('DeepSpeedEngine configuration')
if self.dump_state():
print_configuration(self, 'DeepSpeedLight')
print_configuration(self, 'DeepSpeedEngine')

def _mpi_check(self, args, dist_init_required):
if hasattr(args, 'deepspeed_mpi') and args.deepspeed_mpi:
@@ -236,6 +241,18 @@ def _mpi_check(self, args, dist_init_required):
assert dist.get_world_size() == world_size, "MPI world size {} does not match torch world size {}".format(
world_size, dist.get_world_size())

def pld_enabled(self):
return self._config.pld_enabled

def pld_params(self):
return self._config.pld_params

def pld_theta(self):
return self.pld_params()[PLD_THETA]

def pld_gamma(self):
return self.pld_params()[PLD_GAMMA]

def tensorboard_enabled(self):
return self._config.tensorboard_enabled

@@ -666,6 +683,11 @@ def _configure_zero_optimizer(self, optimizer):

return optimizer

def _configure_progressive_layer_drop(self):
pld = ProgressiveLayerDrop(theta=self.pld_theta(), gamma=self.pld_gamma())

return pld

def deepspeed_io(self,
dataset,
batch_size=None,
@@ -751,6 +773,9 @@ def forward(self, *inputs, **kwargs):
**kwargs: variable length keyword arguments
"""

if self.module.training and self.progressive_layer_drop:
kwargs.update(self.progressive_layer_drop.get_state())

if self.wall_clock_breakdown():
self.timers('forward_microstep').start()
self.timers('forward').start()
@@ -931,6 +956,9 @@ def step(self):

# Update the model when we reach gradient accumulation boundaries
if self.is_gradient_accumulation_boundary():
if self.progressive_layer_drop:
self.progressive_layer_drop.update_state(self.global_steps)

self._take_model_step()

self.tput_timer.stop(report_progress)
@@ -1024,6 +1052,12 @@ def get_mom(self):
else:
return self._get_optimizer_param('betas')

def get_pld_theta(self):
if self.progressive_layer_drop:
return self.progressive_layer_drop.get_theta()
else:
return None

def _report_progress(self, step):
lr = self.get_lr()
mom = self.get_mom()
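
Taken together, these hooks mean that when PLD is enabled the engine injects `progressive_layer_drop=True` and the current `pld_theta` into every training-mode `forward()` call, and refreshes the schedule at each gradient-accumulation boundary. A model only has to accept those keyword arguments. The hypothetical block stack below sketches one way to consume them; the skip policy is illustrative only, and the BERT example in DeepSpeedExamples implements its own depth-dependent rule.

```python
import torch


class ToyPLDStack(torch.nn.Module):
    """Hypothetical residual stack that honors the kwargs injected by
    DeepSpeedEngine.forward() when progressive layer drop is enabled."""

    def __init__(self, hidden_dim, num_layers=4):
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [torch.nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers)])

    def forward(self, x, progressive_layer_drop=False, pld_theta=1.0):
        for layer in self.layers:
            if progressive_layer_drop and self.training:
                # Illustrative policy: keep this block with probability pld_theta.
                if torch.rand(1).item() > pld_theta:
                    continue  # drop the block; only the identity path remains
            x = x + layer(x)
        return x
```

With `progressive_layer_drop=False`, or in eval mode, every block runs; a `pld_theta` of 1.0 (the scheduler's starting value) has the same effect.
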
33 changes: 33 additions & 0 deletions deepspeed/runtime/progressive_layer_drop.py
@@ -0,0 +1,33 @@
import numpy as np
from deepspeed.utils import log_dist


class ProgressiveLayerDrop(object):
r""" Progressive Layer Dropping (PLD) for model training.
This implements the PLD technique for compressed model training
from this paper: https://arxiv.org/pdf/2010.13369.pdf
Args:
theta (float): a hyper-parameter that controls the trade-off between training time and robustness.
The lower the theta value, the faster the training speed. Default value: 0.5.
gamma (float): a hyper-parameter that controls how fast the drop ratio increases. Default value: 0.001.
"""
def __init__(self, theta=0.5, gamma=0.001):
super().__init__()

self.theta = theta
self.gamma = gamma
self.current_theta = 1.0
log_dist(f'Enabled progressive layer dropping (theta = {self.theta})', ranks=[0])

def get_state(self):
kwargs = {'progressive_layer_drop': True, 'pld_theta': self.get_theta()}
return kwargs

def get_theta(self):
return self.current_theta

def update_state(self, global_step):
def _prob(x, gamma, p):
return (1. - p) * np.exp(-gamma * x) + p

self.current_theta = _prob(global_step, self.gamma, self.theta)
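
A quick standalone trace (assuming this module is importable from an installed DeepSpeed) shows the schedule decaying from 1.0 toward the configured `theta`:

```python
from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop

pld = ProgressiveLayerDrop(theta=0.5, gamma=0.001)
for step in (0, 100, 1000, 10000):
    pld.update_state(step)
    # theta(t) = (1 - theta) * exp(-gamma * t) + theta
    print(step, round(pld.get_theta(), 4))
# prints roughly: 0 1.0, 100 0.9524, 1000 0.6839, 10000 0.5
```
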
2 changes: 1 addition & 1 deletion docs/_posts/2020-10-28-progressive-layer-dropping-news.md
@@ -7,7 +7,7 @@ new_post: true
date: 2020-10-29 00:00:00
---

We introduce a new technology called progressive layer dropping (PLD) to speedup the pre-training of Transformer-based networks through efficient and robust compressed training. The pre-training step of Transformer networks often suffer from unbearable overall computational expenses. We analyze the training dynamics and stability of Transformer networks and propose PLD to sparsely update Transformer blocks following a progressive dropping schedule, which smoothly increases the layer dropping rate for each mini-batch as training evolves along both the temporal and the model depth dimension. PLD is able to allow the pre-training to be **2.5X faster** to get similar accuracy on downstream tasks and allows the training to be **24% faster** when training the same number of samples, not at the cost of excessive hardware resources.
We introduce a new technology called progressive layer dropping (PLD) to speedup the pre-training of Transformer-based networks through efficient and robust compressed training. The pre-training step of Transformer networks often suffer from unbearable overall computational expenses. We analyze the training dynamics and stability of Transformer networks and propose PLD to sparsely update Transformer blocks following a progressive dropping schedule, which smoothly increases the layer dropping rate for each mini-batch as training evolves along both the temporal and the model depth dimension. PLD is able to allow the pre-training to be **2.5X faster** to get similar accuracy on downstream tasks and allows the training to be **24% faster** when training the same number of samples, not at the cost of excessive hardware resources.

* For detailed technology deep dive, see our [technical report](https://arxiv.org/pdf/2010.13369.pdf).
* For more information on how to use PLD, see our [Progressive layer dropping tutorial](https://www.deepspeed.ai/tutorials/progressive_layer_dropping/).
2 changes: 1 addition & 1 deletion docs/_tutorials/progressive_layer_dropping.md
@@ -18,7 +18,7 @@ already been modified to use DeepSpeed. The `ds_train_bert_progressive_layer_dr
```bash
bash ds_train_bert_progressive_layer_drop_bsz4k_seq128.sh
```

Most of the flags in the above script should be familiar if you have stepped through the BERT pre-training [tutorial](/tutorials/bert-pretraining/). To enable training with PLD, one needs to enable PLD in both the client script and in the DeepSpeed engine. To enable PLD in the client script, one needs to add the following command line flag to enable progressive layer dropping on Transformer blocks.
Most of the flags in the above script should be familiar if you have stepped through the BERT pre-training [tutorial](/tutorials/bert-pretraining/). To enable training with PLD, one needs to enable PLD in both the client script and in the DeepSpeed engine. To enable PLD in the client script, one needs to add the following command line flag to enable progressive layer dropping on Transformer blocks.

```bash
--progressive_layer_drop
```
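
On the engine side, the DeepSpeed configuration needs a `progressive_layer_drop` section using the keys added in `constants.py`. The unit tests in this commit construct it as a Python dict; a minimal sketch with illustrative values:

```python
config_dict = {
    "train_batch_size": 1,
    "optimizer": {"type": "Adam", "params": {"lr": 0.0001}},
    "fp16": {"enabled": True},
    "progressive_layer_drop": {
        "enabled": True,   # turn PLD on in the engine
        "theta": 0.5,      # lower bound the keep ratio decays toward
        "gamma": 0.001     # decay rate of the schedule
    }
}
```

The same keys go into the JSON config file passed to DeepSpeed via `--deepspeed_config` when launching from a script.
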
4 changes: 2 additions & 2 deletions docs/index.md
@@ -12,9 +12,9 @@ efficient, and effective.
<p align="center"><i><b>Minimal Code Change</b></i></p>

DeepSpeed delivers extreme-scale model training for everyone, from data scientists training on massive supercomputers to those training on low-end clusters or even on a single GPU:
* Extreme scale: Using current generation of GPU clusters with hundreds of devices, 3D parallelism of DeepSpeed can efficiently train deep learning models with trillions of parameters.
* Extreme scale: Using current generation of GPU clusters with hundreds of devices, 3D parallelism of DeepSpeed can efficiently train deep learning models with trillions of parameters.
* Extremely memory efficient: With just a single GPU, ZeRO-Offload of DeepSpeed can train models with over 10B parameters, 10x bigger than the state of arts, democratizing multi-billion-parameter model training such that many deep learning scientists can explore bigger and better models.
* Extremely long sequence length: Sparse attention of DeepSpeed powers an order-of-magnitude longer input sequence and obtains up to 6x faster execution comparing with dense transformers.
* Extremely long sequence length: Sparse attention of DeepSpeed powers an order-of-magnitude longer input sequence and obtains up to 6x faster execution comparing with dense transformers.
* Extremely communication efficient: 3D parallelism improves communication efficiency allows users to train multi-billion-parameter models 2–7x faster on clusters with limited network bandwidth. 1-bit Adam reduces communication volume by up to 5x while achieving similar convergence efficiency to Adam, allowing for scaling to different types of GPU clusters and networks.

Early adopters of DeepSpeed have already produced
11 changes: 11 additions & 0 deletions tests/unit/simple_model.py
100644 → 100755
@@ -101,6 +101,17 @@ def step(self, closure=None):
return loss


class PLD_SimpleModel(SimpleModel):
def __init__(self, hidden_dim, empty_grad=False, rank=0):
super(PLD_SimpleModel, self).__init__(hidden_dim, empty_grad, rank)

def forward(self, x, y, **kwargs):
pld = kwargs.get('progressive_layer_drop', False)
theta = kwargs.get('pld_theta', 1.0)
hidden_dim = super(PLD_SimpleModel, self).forward(x, y)
return hidden_dim


def random_dataloader(model, total_samples, hidden_dim, device, dtype=torch.half):
batch_size = model.train_micro_batch_size_per_gpu()
train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=dtype)
117 changes: 117 additions & 0 deletions tests/unit/test_pld.py
@@ -0,0 +1,117 @@
import numpy as np
import deepspeed
import pytest
from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop
from common import distributed_test
from simple_model import SimpleModel, PLD_SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict


@pytest.mark.parametrize('theta', [0, 0.1, 0.9, 1.0])
def test_pld_schedule(tmpdir, theta):
gamma = 0.001

pld_scheduler = ProgressiveLayerDrop(theta, gamma)
for i in range(10):
pld_scheduler.update_state(i)
expected_theta = (1. - theta) * np.exp(-gamma * i) + theta
actual_theta = pld_scheduler.get_theta()
assert expected_theta == actual_theta


@pytest.mark.parametrize('theta', [0, 0.1, 0.9, 1.0])
def test_pld_model(tmpdir, theta):
gamma = 0.001
config_dict = {
"train_batch_size": 1,
"steps_per_print": 1,
"optimizer": {
"type": 'Adam',
"params": {
"lr": 0.0001
}
},
"fp16": {
"enabled": True
},
"progressive_layer_drop": {
"enabled": True,
"theta": theta,
"gamma": gamma
}
}

args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10

model = PLD_SimpleModel(hidden_dim, empty_grad=False)

@distributed_test(world_size=[1])
def _test_pld_model(args, model, hidden_dim, theta, gamma):
model, _, _, _ = deepspeed.initialize(args=args,
model=model,
model_parameters=model.parameters())

data_loader = random_dataloader(model=model,
total_samples=50,
hidden_dim=hidden_dim,
device=model.device)

for i, batch in enumerate(data_loader):
loss = model(batch[0], batch[1])
model.backward(loss)
model.step()

expected_theta = (1. - theta) * np.exp(-gamma * i) + theta
actual_theta = model.get_pld_theta()
assert expected_theta == actual_theta

_test_pld_model(args=args,
model=model,
hidden_dim=hidden_dim,
theta=theta,
gamma=gamma)


def test_non_pld_model(tmpdir):
gamma = 0.001
theta = 0.5
config_dict = {
"train_batch_size": 1,
"steps_per_print": 1,
"optimizer": {
"type": 'Adam',
"params": {
"lr": 0.0001
}
},
"fp16": {
"enabled": True
},
"progressive_layer_drop": {
"enabled": True,
"theta": theta,
"gamma": gamma
}
}

args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10

model = SimpleModel(hidden_dim, empty_grad=False)

@distributed_test(world_size=[1])
def _test_non_pld_model(args, model, hidden_dim):
model, _, _, _ = deepspeed.initialize(args=args,
model=model,
model_parameters=model.parameters())

data_loader = random_dataloader(model=model,
total_samples=1,
hidden_dim=hidden_dim,
device=model.device)

for i, batch in enumerate(data_loader):
with pytest.raises(TypeError):
loss = model(batch[0], batch[1])

_test_non_pld_model(args=args, model=model, hidden_dim=hidden_dim)
