Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Nemo Distributed Checkpoint User Guide #11489

Merged
merged 8 commits into from
Dec 6, 2024
178 changes: 174 additions & 4 deletions docs/source/checkpoints/dist_ckpt.rst
Original file line number Diff line number Diff line change
@@ -1,20 +1,125 @@
Distributed Checkpoints
NeMo Distributed Checkpoint User Guide
=======================

This guide provides details about the distributed checkpoints format from Megatron Core.
This guide provides details about the distributed checkpoints best practices from NeMo Megatron Core.


Introduction
------------
--------------

Megatron Core is an open-source, PyTorch-based library that provides a collection of GPU optimization techniques including various parallelisms (data, tensor, pipeline, context, and expert parallelism). NeMo Framework is an end-to-end LLM training framework that builds on top of the Megatron Core library.

In large-scale training, checkpoints are used to periodically save intermediate model states (including model weights, optimizer states, and other necessary metadata). This allows for easy recovery if the training process is interrupted.

NeMo Distributed Checkpoint, part of the Megatron Core library, refers to saving the state of a distributed training job across multiple GPUs or nodes. This approach aims to reduce memory overhead and improve GPU utilization. It also provides users with the flexibility to resume training using different parallelism strategies.

**Megatron Core Library**

Model parallel training requires parallelism-aware checkpointing.
Megatron Core provides a checkpointing library capable of handling all types of parallelisms used in LLM training.
Although the distributed checkpointing library is targeted at the Megatron Core model, it can also be used with other models, as long as proper integration is implemented.

The library provides two main entrypoints: ``dist_checkpointing.save`` and ``dist_checkpointing.load`` which are meant to replace the ``torch.save`` and ``torch.load`` in the regular checkpointing flow.
Apart from that, it provides a mechanism to define how different types of local tensors should be combined and split in the global checkpoint.


Mechanism
--------------
The NeMo Distributed Checkpoint enables saving and loading models from multiple ranks in parallel. It employs a novel strategy called Fully Parallel Saving (FPS) to partition the optimizer states, gradients, and model parameters across all GPU ranks. When saving the checkpoint of a distributed optimizer, each DP rank holds its shard of the optimizer state and independently writes its shard to the shared storage (grad buffer).

When loading the checkpoint, each DP rank reads its corresponding checkpoint file (shard) to recover. If different parallelism strategies are needed (e.g., tensor parallelism, pipeline parallelism), each rank can also access other checkpoint files to transfer data to the correct locations.

NeMo allows users to resume training from a checkpoint saved with different tensor and pipeline parallelism degrees, offering the flexibility to adjust training configurations as needed.

The following figure illustrates fully parallel saving in NeMo Framework, utilizing data-parallel replicas for writing across nodes.

.. image:: https://github.com/NVIDIA/NeMo/releases/download/v2.0.0/asset-nemo-dist-ckpt-explain-0.png


*Figure 1. Fully parallel saving in NeMo Framework uses the data-parallel replicas for parallel writing across nodes*

The following figure illustrates asynchronous saving in NeMo Framework, where checkpoints are saved in the background while training continues. Asynchronous parallel saving allows model parameters to be copied to the CPU first before persisting the checkpoint to stable storage in the background. This process minimizes interruptions to the main training, thereby speeding up the distributed checkpointing process.

.. image:: https://github.com/NVIDIA/NeMo/releases/download/v2.0.0/asset-nemo-dist-ckpt-explain-1.png


*Figure 2. Asynchronous saving in NeMo Framework saves checkpoint at the background in parallel with training*


Parameter Tuning
--------------

You can configure distributed checkpoints in NeMo pre-training and fine-tuning jobs.

In the `NeMo 1.0 YAML config file <https://docs.nvidia.com/nemo-framework/user-guide/latest/nemo-2.0/migration/checkpointing.html>`__ or `NeMo 2.0 MegatronStrategy <https://docs.nvidia.com/nemo-framework/user-guide/latest/nemo-2.0/migration/checkpointing.html>`__, you can enable and tune these parameters.

The latest NeMo version is Nemo 2.0 (NGC container ``nvcr.io/nvidia/nemo:24.09``).


Best Practices
^^^^^^^^^^^

Here are best practices for configuring distributed checkpoints in NeMo:

.. code-block:: python

dist_ckpt_format: 'torch_dist'

dist_ckpt_load_on_device: True

dist_ckpt_parallel_save: True

dist_ckpt_parallel_save_within_dp: False

dist_ckpt_parallel_load: True

dist_ckpt_torch_dist_multiproc: 2

dist_ckpt_assume_constant_structure: False

dist_ckpt_parallel_dist_opt: True

dist_ckpt_load_strictness: null


Here's a summary of the checkpoint format options and related parameters:

dist_ckpt_format
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Checkpoint format used for saving. Options are ``torch_dist`` and ``zarr``. PyTorch Distributed (``torch_dist``) is the recommended format. The saving format can differ from the format used for resuming a job. The loading format is auto-detected.

dist_ckpt_load_on_device
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Determines whether to load checkpoint weights directly on GPU or CPU. If True, weights are loaded on GPU. This currently affects only the ``zarr`` format.

dist_ckpt_parallel_save
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Each worker writes its own part of the distributed checkpoint, meaning each DP rank saves its checkpoint shard independently. This applies to model weights or a non-distributed optimizer state. Distributed optimizer parallelization is controlled by the ``dist_ckpt_parallel_dist_opt`` flag (see below).

dist_ckpt_parallel_save_within_dp
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Controls whether NCCL parallelizes the save within the Data Parallel domain. If False, saving is parallelized across the entire world size (number of nodes * number of GPUs). If True, saving is parallelized only within the Data Parallel domain. Setting this to True can reduce latency, but may cause NCCL errors in some setups.

dist_ckpt_parallel_load
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Each worker loads part of the distributed checkpoint and exchanges it with NCCL, meaning each DP rank loads its checkpoint shard independently. This might use extra GPU memory and is critical for large DP setups. If True, the checkpoint is read from storage only once; otherwise, the model weights part is read from storage DP times.

dist_ckpt_torch_dist_multiproc
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Number of extra processes per rank used during checkpoint save with the ``torch_dist`` format. This equals the number of checkpoint files created by each rank. Increasing this number can help saturate the write bandwidth. The default is 2.

dist_ckpt_assume_constant_structure
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Set to True only if the state dict structure remains constant during a single training job (including startup, data loading, training setup, and actual training). This allows caching some computations across checkpoint saves and can reduce saving time starting from the third checkpoint save in the current process.

dist_ckpt_parallel_dist_opt
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Enables parallel save/load of a distributed optimizer. Set to True to save the optimizer state in a reshardable format (allowing changes in TP, PP, etc., upon resume). Set to False to minimize the number of checkpoint files.

dist_ckpt_load_strictness
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Defines behavior for checkpoint key mismatches during loading. Options are ``assume_ok_unexpected`` (default, tries loading without any check), ``log_all`` (logs mismatches), and ``raise_all`` (raises mismatches). Setting to ``log_all`` results in a non-strict state dict load into the model. Non-default options might cause slight overhead due to extra storage interaction. It is recommended to set this flag to ``raise_all`` first to check for expected mismatches. If mismatches are expected, set it to ``log_all`` to ignore (but log) them.


Basic Sharding
--------------

Expand Down Expand Up @@ -425,3 +530,68 @@ and using the ``dist_checkpointing.save`` and ``dist_checkpointing.load`` entryp
In Megatron Core, the sharded state dictionary preparation is already implemented in a ``sharded_state_dict`` method which creates the sharded state dicts in a composable way.
For other applications (e.g. with simpler types of supported parallelisms) it might be possible to apply a straightforward conversion from a regular model state dict into a sharded state dict.


FAQs
-----------------------

**1. Q: With the default configuration using the torch_dist checkpoint format, each rank creates two files. For example, a cluster with 576 GPUs, this results in 1152 files. Is this expected behavior?**

A: This is expected behavior for the torch_dist checkpoint.

**2. Q: When writing a checkpoint, two identical copies of the checkpoint directory are created. For example, with Llama 70B, two folders, each containing approximately 1.4TB of data, are written. Is this expected behavior?**

A: This is expected behavior in NeMo. One copy is related to the last checkpoint, while the other copy is related to the top K checkpoints.

**3. Q: Where can I find details about the Megatron binary file format and its access patterns?**

A: Please refer to the documentation at `https://pytorch.org/docs/stable/distributed.checkpoint.html <https://pytorch.org/docs/stable/distributed.checkpoint.html>`__.

**4. Q: Which `dist_ckpt` configurations are valid for pre-training and fine-tuning?**

A: All ``dist_ckpt`` configs are valid for pre-training and fine-tuning. (Note that ``dist_ckpt_load_strictness`` is not yet supported in NeMo 2.0 container 24.09).

**5. Q: What is the explanation for `-last` checkpoints?**

A: The ``-last`` checkpoint is the final checkpoint in the training session. It is used to identify the most recent checkpoint from which to continue training.

**6. Q: How does `save_top_k: 1` interact with `save_best_model`?**

A: ``save_top_k`` specifies the number of checkpoints to be saved during training. The ``save_best_model`` flag determines whether to save the best model based on a monitored metric (e.g., validation loss or accuracy).

– If ``save_top_k`` and ``save_best_model=True``: Only the single best-performing checkpoint will be retained.

– If ``save_top_k>1`` and ``save_best_model=True``: NeMo will save up to ``save_top_k`` checkpoints, and the best checkpoint (determined by the monitored metric) is always guaranteed to be included.

– If ``save_best_model=False``: NeMo will save only the top K models without explicitly ensuring that the best model is preserved.

**7. Q: How does `dist_ckpt_torch_dist_multiproc` affect the `async_save=True` parameter?**

A: ``dist_ckpt_torch_dist_multiproc`` controls distributed checkpointing by defining the number of helper processes per rank to accelerate checkpoint saving. ``async_save=True`` enables asynchronous checkpointing, allowing checkpointing processes to run in the background without blocking the main training loop. These two parameters could be used orthogonally.

**8. Q: What is the expected checkpoint saving time with the Distributed Fused Adam Optimizer or Megatron Core Distributed Optimizer? How can checkpoint saving be accelerated?**

A: The Megatron Core Distributed Optimizer is recommended and is the default setting in NeMo 2.0. With Megatron Core Distributed Optimizer (model configuration ``mcore_distributed_optim``), the expected saving time should be approximately 1 second for a single checkpoint. With Distributed Fused Adam Optimizer from Apex (model configuration ``distributed_fused_adam``), the expected saving time should be longer, estimated to be about 3 seconds for a single checkpoint.

To accelerate checkpoint saving, it is recommended to set ``dist_ckpt_assume_constant_structure=True``.


Glossary
-----------------------

DP
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Data Parallelism (DP) replicates the model across multiple GPUs. Data batches are evenly distributed between GPUs, and the data-parallel GPUs process them independently. While the computation workload is efficiently distributed across GPUs, inter-GPU communication is required to keep the model replicas consistent between training steps.

TP
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Tensor Parallelism (TP) is a model-parallel partitioning method that distributes the parameter tensor of an individual layer across GPUs. In addition to reducing model state memory usage, it also saves activation memory as the per-GPU tensor sizes shrink. However, the reduced per-GPU tensor size increases CPU overhead due to smaller per-GPU kernel workloads.

PP
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Pipeline Parallelism (PP) is a technique that assigns consecutive layers or segments of a neural network to different GPUs. This division allows each GPU to process different stages of the network sequentially.

Distributed Optimizer
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
The distributed optimizer is a memory-optimized data-parallel deployment method. It shards the optimizer states and the high-precision master parameters across data-parallel GPUs instead of replicating them. At the parameter optimizer step, each data-parallel GPU updates its shard of parameters. Since each GPU needs its own gradient shard, the distributed optimizer conducts reduce-scatter of the parameter gradients instead of all-reduce of them. Then, the updated parameter shards are all-gathered across data-parallel GPUs. This approach significantly reduces the memory need of large-scale LLM training. Also, when the precision of the gradient is higher than the parameter precision, the split execution of gradient reduce-scatter and parameter all-gather can reduce the total communication volume. This split collective execution increases the total computation to overlap with the communication, which improves the overlap opportunity.

For more information, please refer to https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/features/parallelisms.html.
Loading