forked from microsoft/DeepSpeed
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: Shaden Smith <[email protected]> Co-authored-by: Jeff Rasley <[email protected]>
- Loading branch information
1 parent
c0d5424
commit 2dea61f
Showing
13 changed files
with
172 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -39,6 +39,7 @@ collections: | |
- megatron.md | ||
- 1Cycle.md | ||
- lrrt.md | ||
- zero.md | ||
|
||
defaults: | ||
- scope: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
--- | ||
title: "ZeRO-Offload" | ||
--- | ||
We recommend that you read the tutorials on [Getting Started](/getting-started/) and [ZeRO](/zero/) before stepping through this tutorial. | ||
|
||
ZeRO-Offload is a ZeRO optimization that offloads the optimizer memory and computation from the GPU to the host CPU. ZeRO-Offload enables large models with up to 13 billion parameters to be efficiently trained on a single GPU. In this tutorial we will use ZeRO-Offload to train a 10-billion parameter GPT-2 model in DeepSpeed. Furthermore, *using ZeRO-Offload in a DeepSpeed model is quick and easy because all you need is to change a few configurations in the DeepSpeed configuration json*. No code changes are needed. | ||
|
||
## ZeRO-Offload Overview | ||
For large model training, optimizers such as [Adam](https://arxiv.org/abs/1412.6980), can consume a significant amount of GPU compute and memory. ZeRO-Offload reduces the GPU compute and memory requirements of such models by leveraging compute and memory resources on the host CPU to execute the optimizer. Furthermore, to prevent the optimizer from becoming a bottleneck, ZeRO-Offload uses DeepSpeed's highly optimized CPU implementation of Adam called [DeeSpeedCPUAdam](https://github.com/microsoft/DeepSpeed/tree/master/deepspeed/ops/adam). DeepSpeedCPUAdam is 5X--7X faster than the standard PyTorch implementation. To deep dive into the design and performance of ZeRO-Offload, please see our blog post [[XXXX]()]. | ||
|
||
## Training Environment | ||
For this tutorial, we will configure a 10 billion parameter GPT-2 model using the DeepSpeed [Megatron-LM](https://github.com/microsoft/DeepSpeedExamples/tree/master/Megatron-LM) GPT-2 code. We advise stepping through the Megatron-LM [tutorial](/megatron/) if you have not previously done so. We will use a single [NVIDIA Tesla V100-SXM3 Tensor Core GPU](https://www.nvidia.com/en-us/data-center/v100/) with 32GB RAM for this exercise. | ||
|
||
## Training a 10B parameter GPT-2 on 1 V100 GPU | ||
We need to make changes to the Megatron-LM launch script and to the DeepSpeed configuration json. | ||
|
||
### Megatron-LM GPT-2 launch script changes | ||
We need to apply two changes to the launch script for the DeepSpeed Megatron-LM GPT-2 model. The first change is to configure a 10B parameter GPT-2 model, which can be achieved by the following set of changes: | ||
|
||
```bash | ||
--model-parallel-size 1 \ | ||
--num-layers 50 \ | ||
--hidden-size 4096 \ | ||
--num-attention-heads 32 \ | ||
--batch-size 10 \ | ||
--d \ | ||
--deepspeed_config ds_zero_offload.config \ | ||
--cpu_optimizer \ | ||
``` | ||
|
||
Most of the flags in the changes above should be familiar if you have stepped through the Megatron-LM [tutorial](/megatron/), except for the **_--cpu_optimizer_**. This flag informs the model script to pass a CPU-based Adam optimizer, rather than a GPU-based one, to DeepSpeed as the client optimizer. It is very important that this flag be used when training with ZeRO-Offload to ensure correct operation of the DeepSpeed engine. | ||
|
||
Second, we need to apply the following changes to ensure that only one GPU is used for training. | ||
```bash | ||
deepspeed --num_nodes 1 --num_gpus 1 ... | ||
``` | ||
|
||
### DeepSpeed Configuration Changes | ||
ZeRO-Offload leverages much for ZeRO stage 2 mechanisms, and so the configuration changes to enable ZeRO-Offload is an extension of those required to enable ZeRO stage 2. The **zero_optimization** key to enable ZeRO-Offload is shown below: | ||
|
||
```json | ||
{ | ||
"zero_optimization": { | ||
"stage": 2, | ||
"cpu_offload": true, | ||
"contiguous_gradients": true, | ||
"overlap_comm": true | ||
} | ||
} | ||
``` | ||
|
||
As seen above, in addition to setting the _stage_ field to **2** (to enable ZeRO stage 2), we also need to set _cpu_offload_ flag to **true** enable ZeRO-Offload optimizations. In addition, we can set other ZeRO stage 2 optimization flags, such as _overlap_comm_ to tune ZeRO-Offload performance. With these changes we can now run the model. We share some screenshots of the training below. | ||
|
||
Here is a screenshot of the training log: | ||
|
||
![ZERO_OFFLOAD_DP1_10B_LOG](/assets/images/zero_offload_dp1_10B_log.png) | ||
|
||
Here is a screenshot of nvidia-smi showing that only GPU 0 is active during training: | ||
|
||
![ZERO_OFFLOAD_DP1_10B_SMI](/assets/images/zero_offload_dp1_10B_smi.png) | ||
|
||
Finally, here is a screenshot of htop showing host CPU and memory activity during optimizer computation: | ||
|
||
![ZERO_OFFLOAD_DP1_10B_SMI](/assets/images/zero_offload_dp1_10B_cpu.png) | ||
|
||
Congratulations! You have completed the ZeRO-Offload tutorial. | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
--- | ||
title: "Zero Redundancy Optimizer (ZeRO)" | ||
--- | ||
If you have not done so already, we advise that you read the DeepSpeed tutorials on [Getting Started](/getting-started/) and [Megatron-LM GPT-2](/megatron/) before stepping through this tutorial. | ||
|
||
In this tutorial, we will apply the ZeRO optimizer to the [Megatron-LM GPT-2](https://github.com/NVIDIA/Megatron-LM) model. ZeRO is a powerful set of memory optimization techniques that enable effective FP16 training of large models with billions of parameters, such as [GPT-2](https://openai.com/blog/better-language-models/) and [Turing-NLG 17B](https://www.microsoft.com/en-us/research/blog/turing-nlg-a-17-billion-parameter-language-model-by-microsoft/). Compared to the alternative model parallelism approaches for training large models, a key appeal of ZeRO is that no model code modifications are required. As this tutorial will demonstrate, *using ZeRO in a DeepSpeed model is quick and easy because all you need is to change a few configurations in the DeepSpeed configuration json*. No code changes are needed. | ||
|
||
## ZeRO Overview | ||
ZeRO leverages the aggregate computation and memory resources of data parallelism to reduce the memory and compute requirements of each device (GPU) used for model training. ZeRO reduces the memory consumption of each GPU by partitioning the various model training states (weights, gradients, and optimizer states) across the available devices (GPUs and CPUs) in the distributed training hardware. Concretely, ZeRO is being implemented as incremental stages of optimizations, where optimizations in earlier stages are available in the later stages. To deep dive into ZeRO, please see our [paper](https://arxiv.org/abs/1910.02054v3). | ||
|
||
* **Stage 1**: The optimizer states (e.g., for [Adam optimizer](https://arxiv.org/abs/1412.6980), 32-bit weights, and the first, and second moment estimates) are partitioned across the processes, so that each process updates only its partition. | ||
|
||
* **Stage 2**: The reduced 32-bit gradients for updating the model weights are also partitioned such that each process retains only the gradients corresponding to its portion of the optimizer states. | ||
|
||
## Training environment | ||
We use the DeepSpeed [Megatrom-LM](https://github.com/microsoft/DeepSpeedExamples/tree/master/Megatron-LM) GPT-2 code for this exercise. You can step through the Megatron-LM [tutorial](/megatron/) to familiarize yourself with the code. We will train the models in this tutorial on [NVIDIA Tesla V100-SXM3 Tensor Core GPUs](https://www.nvidia.com/en-us/data-center/v100/) with 32GB RAM. | ||
|
||
## Enabling ZeRO Optimization | ||
To enable ZeRO optimizations for a DeepSpeed model, we simply add the **_zero_optimization_** key to the DeepSpeed json configuration. A full description of configuration knobs of the **zero_optimization** key is available [here](/docs/config-json/#zero-optimizations-for-fp16-training). | ||
|
||
### Training a 1.5B Parameter GPT-2 model | ||
We demonstrate the benefits of ZeRO stage 1 by showing that it enables data parallel training of a 1.5 billion parameter GPT-2 model on eight V100 GPUs. We configure training to use a batch size of 1 per device to ensure that the memory consumption is primarily due to model parameters and optimizer states. We create this training scenario by applying the following modifications to the deepspeed launch script: | ||
|
||
```bash | ||
--model-parallel-size 1 \ | ||
--num-layers 48 \ | ||
--hidden-size 1600 \ | ||
--num-attention-heads 16 \ | ||
--batch-size 1 \ | ||
--d \ | ||
--deepspeed_config ds_zero_stage_1.config \ | ||
``` | ||
|
||
Training this model without ZeRO fails with an out-of-memory (OOM) error as shown below: | ||
![OOM_DP8_1.5B_model](/assets/images/oom_dp8_1.5B_log.png) | ||
|
||
A key reason why this model does not fit in GPU memory is that the Adam optimizer states for the model consume 18GB; a significant portion of the 32GB RAM. By using ZeRO stage 1 to partition the optimizer state among eight data parallel ranks, the per-device memory consumption can be reduced to 2.25GB, thus making the model trainable. To enable ZeRO stage 1, we simply update the DeepSpeed json config file as below: | ||
|
||
```json | ||
{ | ||
"zero_optimization": { | ||
"stage":1, | ||
"reduce_bucket_size": 500000000 | ||
} | ||
} | ||
``` | ||
As seen above, we set two fields in the **zero_optimization** key. Specifically we set the _stage_ field to 1, and the optional _reduce_bucket_size_ for gradient reduction to 50M. With ZeRO stage 1 enabled, the model can now train smoothly on 8 GPUs without running out of memory. Below we provide some screenshots of the model training: | ||
|
||
![ZERO1_DP8_1.5B_LOG](/assets/images/zero1_dp8_1.5B_log.png) | ||
|
||
![ZERO1_DP8_1.5B_SMI](/assets/images/zero1_dp8_1.5B_smi.png) | ||
|
||
From the nvidia-smi screenshot above we can see that that only GPUs 0--7 are being used for training the model. With ZeRO stage 1 we can further reduce the per-device memory consumption by increasing the data parallelism degree. These memory savings can be leveraged to either increase model size and/or batch size. In contrast, such benefits are not possible with data parallelism alone. | ||
|
||
### Training a 10B Parameter GPT-2 model | ||
ZeRO stage 2 optimizations further increases the size of models that can be trained using data parallelism. We show this training a model with 10B parameters using 32 V100 GPUs. First, we need to configure a 10B parameter model. This can be done by applying the following GPT-2 model configuration changes to the DeepSpeed launch script. | ||
|
||
```bash | ||
--model-parallel-size 1 \ | ||
--num-layers 50 \ | ||
--hidden-size 4096 \ | ||
--num-attention-heads 32 \ | ||
--batch-size 1 \ | ||
--d \ | ||
--deepspeed_config ds_zero_stage_2.config \ | ||
``` | ||
|
||
Next, we need to update the DeepSpeed json configuration, as shown below, to enable ZeRO stage 2 optimizations: | ||
|
||
```json | ||
{ | ||
"zero_optimization": { | ||
"stage":2, | ||
"contiguous_gradients": true, | ||
"overlap_comm": true, | ||
"reduce_scatter": true, | ||
"reduce_bucket_size": 50000000, | ||
"allgather_bucket_size": 500000000 | ||
} | ||
} | ||
``` | ||
|
||
In the above changes, we have set the _stage_ field to 2, and configured other optimization knobs that are available in ZeRO stage 2. For example, we have enabled _contiguous_gradients_ to reduce memory fragmenation during backward pass. A full description of these optimization knobs is available [here](/docs/config-json/#zero-optimizations-for-fp16-training). With these changes, we can now run the launch the training run. | ||
|
||
Here is a screenshot of the training log: | ||
|
||
![ZERO2_DP32_10B_LOG](/assets/images/zero2_dp32_10B_log.png) | ||
|
||
Here is a screenshot of nvidia-smi show GPU activity during training: | ||
|
||
![ZERO2_DP32_10B_SMI](/assets/images/zero2_dp32_10B_smi.png) | ||
|
||
Congratulations! You have completed the ZeRO tutorial. |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.