Skip to content

Commit

Permalink
[docs] Update deepspeed docs, add some more information and link to s…
Browse files Browse the repository at this point in the history
…treamlit (#8691)
  • Loading branch information
Sean Naren authored Aug 3, 2021
1 parent a1be621 commit 49d03f8
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 10 deletions.
46 changes: 38 additions & 8 deletions docs/source/advanced/advanced_gpu.rst
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
Advanced GPU Optimized Training
===============================
Advanced GPU Optimized Training & Model Parallelism
===================================================

When training large models, fitting larger batch sizes, or trying to increase throughput using multi-GPU compute, Lightning provides advanced optimized distributed training plugins to support these cases and offer substantial improvements in memory usage.

In many cases these plugins are some flavour of model parallelism however we only introduce concepts at a high level to get you started. Refer to the `FairScale documentation <https://fairscale.readthedocs.io/en/latest/deep_dive/oss_sdp_fsdp.html>`__ for more information about model parallelism.

Note that some of the extreme memory saving configurations will affect the speed of training. This Speed/Memory trade-off in most cases can be adjusted.

Some of these memory-efficient plugins rely on offloading onto other forms of memory, such as CPU RAM or NVMe. This means you can even see memory benefits on a **single GPU**, using a plugin such as :ref:`deepspeed-zero-stage-3-offload`.
Expand All @@ -14,6 +16,8 @@ If you would like to stick with PyTorch DDP, see :ref:`ddp-optimizations`.

Unlike PyTorch's DistributedDataParallel (DDP) where the maximum trainable model size and batch size do not change with respect to the number of GPUs, memory-optimized plugins can accommodate bigger models and larger batches as more GPUs are used. This means as you scale up the number of GPUs, you can reach the number of model parameters you'd like to train.

There are many considerations when choosing a plugin as described below. In addition, check out the visualization of various plugin benchmarks using `minGPT <https://github.com/SeanNaren/minGPT>`__ `here <https://share.streamlit.io/seannaren/mingpt/streamlit/app.py>`__.

Pre-training vs Fine-tuning
"""""""""""""""""""""""""""

Expand Down Expand Up @@ -216,6 +220,9 @@ If you run into an issue with the install or later in training, ensure that the

DeepSpeed currently only supports single optimizer, single scheduler within the training loop.

When saving a checkpoint we rely on DeepSpeed which saves a directory containing the model and various components.


.. _deepspeed-zero-stage-2:

DeepSpeed ZeRO Stage 2
Expand All @@ -224,9 +231,6 @@ DeepSpeed ZeRO Stage 2
By default, we enable `DeepSpeed ZeRO Stage 2 <https://www.deepspeed.ai/tutorials/zero/#zero-overview>`_, which partitions your optimizer states (Stage 1) and your gradients (Stage 2) across your GPUs to reduce memory. In most cases, this is more efficient or at parity with DDP, primarily due to the optimized custom communications written by the DeepSpeed team.
As a result, benefits can also be seen on a single GPU. Do note that the default bucket sizes allocate around ``3.6GB`` of VRAM to use during distributed communications, which can be tweaked when instantiating the plugin described in a few sections below.

.. note::
To use ZeRO, you must use ``precision=16``.

.. code-block:: python
from pytorch_lightning import Trainer
Expand All @@ -247,9 +251,6 @@ DeepSpeed ZeRO Stage 2 Offload

Below we show an example of running `ZeRO-Offload <https://www.deepspeed.ai/tutorials/zero-offload/>`_. ZeRO-Offload leverages the host CPU to offload optimizer memory/computation, reducing the overall memory consumption.

.. note::
To use ZeRO-Offload, you must use ``precision=16``.

.. code-block:: python
from pytorch_lightning import Trainer
Expand Down Expand Up @@ -332,6 +333,10 @@ Below we describe how to enable all of these to see benefit. **With all these im

Also please have a look at our :ref:`deepspeed-zero-stage-3-tips` which contains a lot of helpful information when configuring your own models.

.. note::

When saving a model using DeepSpeed and Stage 3, model states and optimizer states will be saved in separate sharded states (based on the world size). See :ref:`deepspeed-zero-stage-3-single-file` to obtain a single checkpoint file.

.. code-block:: python
from pytorch_lightning import Trainer
Expand Down Expand Up @@ -399,6 +404,10 @@ DeepSpeed ZeRO Stage 3 Offload

DeepSpeed ZeRO Stage 3 Offloads optimizer state, gradients to the host CPU to reduce memory usage as ZeRO Stage 2 does, however additionally allows you to offload the parameters as well for even more memory saving.

.. note::

When saving a model using DeepSpeed and Stage 3, model states and optimizer states will be saved in separate sharded states (based on the world size). See :ref:`deepspeed-zero-stage-3-single-file` to obtain a single checkpoint file.

.. code-block:: python
from pytorch_lightning import Trainer
Expand Down Expand Up @@ -516,6 +525,27 @@ Here is some helpful information when setting up DeepSpeed ZeRO Stage 3 with Lig
* When offloading to the CPU, make sure to bump up the batch size as GPU memory will be freed
* We also support sharded checkpointing. By passing ``save_full_weights=False`` to the ``DeepSpeedPlugin``, we'll save shards of the model which allows you to save extremely large models. However to load the model and run test/validation/predict you must use the Trainer object.

.. _deepspeed-zero-stage-3-single-file:

Collating Single File Checkpoint for DeepSpeed ZeRO Stage 3
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""

After training using ZeRO Stage 3, you'll notice that your checkpoints are a directory of sharded model and optimizer states. If you'd like to collate a single file from the checkpoint directory please use the below command, which handles all the Lightning states additionally when collating the file.

.. code-block:: python
from pytorch_lightning.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict
# lightning deepspeed has saved a directory instead of a file
save_path = "lightning_logs/version_0/checkpoints/epoch=0-step=0.ckpt/"
output_path = "lightning_model.pt"
convert_zero_checkpoint_to_fp32_state_dict(save_path, output_path)
.. warning::

This single file checkpoint does not include the optimizer/lr-scheduler states. This means we cannot restore training via the `resume_from_checkpoint` Trainer argument. Ensure to keep the sharded checkpoint directory if this is required.

Custom DeepSpeed Config
"""""""""""""""""""""""

Expand Down
5 changes: 3 additions & 2 deletions pytorch_lightning/plugins/training_type/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,11 +679,12 @@ def save_checkpoint(self, checkpoint: Dict, filepath: str) -> None:
filepath: write-target file's path
"""
if self.zero_stage_3 and self._multi_device and self.is_global_zero:
# todo (sean): Add link to docs once docs are merged.
warning_cache.warn(
"When saving the DeepSpeed Stage 3 checkpoint, "
"each worker will save a shard of the checkpoint within a directory. "
"If a single file is required after training, see <TODO> for instructions."
"If a single file is required after training, "
"see https://pytorch-lightning.readthedocs.io/en/latest/advanced/advanced_gpu.html#"
"deepspeed-zero-stage-3-single-file for instructions."
)
# Use deepspeed's internal checkpointing function to handle partitioned weights across processes
# dump states as a checkpoint dictionary object
Expand Down

0 comments on commit 49d03f8

Please sign in to comment.