Enable MCore checkpointing optimizations #9505

Merged (13 commits, Jul 5, 2024)
Dockerfile.ci (1 addition, 1 deletion)
@@ -34,7 +34,7 @@ WORKDIR /workspace
# Install NeMo requirements
ARG TE_TAG=bfe21c3d68b0a9951e5716fb520045db53419c5e
ARG MODELOPT_VERSION=0.13.0
-ARG MCORE_TAG=02871b4df8c69fac687ab6676c4246e936ce92d0
+ARG MCORE_TAG=0ab8dd4c7520408683fdb9f8ac119eff7d38fc0e
ARG APEX_TAG=810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c
RUN \
  --mount=type=bind,source=requirements,target=requirements \
examples/nlp/language_modeling/conf/megatron_gpt_config.yaml (4 additions, 0 deletions)
@@ -177,6 +177,10 @@ model:
  dist_ckpt_format: 'zarr' # Set to 'torch_dist' to use PyTorch distributed checkpoint format.
  dist_ckpt_load_on_device: True # whether to load checkpoint weights directly on GPU or to CPU
  dist_ckpt_parallel_save: False # if true, each worker will write its own part of the dist checkpoint
+  dist_ckpt_parallel_load: False # if true, each worker will load part of the dist checkpoint and exchange with NCCL. May use some extra GPU memory
+  dist_ckpt_torch_dist_multiproc: 2 # number of extra processes per rank used during ckpt save with PyTorch distributed format
+  dist_ckpt_assume_constant_structure: False # set to True only if the state dict structure doesn't change within a single job. Allows caching some computation across checkpoint saves.
+  dist_ckpt_parallel_dist_opt: True # parallel save/load of a DistributedOptimizer. 'True' allows performant save and reshardable checkpoints. Set to 'False' only to minimize the number of checkpoint files.

  ## Activation Checkpointing
  # NeMo Megatron supports 'selective' activation checkpointing where only the memory intensive part of attention is checkpointed.
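Most of these keys are consumed by `DistributedCheckpointIO.from_config` further down in this PR, while `dist_ckpt_parallel_dist_opt` is read in `megatron_trainer_builder.py`. A minimal sketch of a model config carrying the new knobs, assuming an OmegaConf-style config as NeMo uses; the particular values are illustrative, not recommendations:

from omegaconf import OmegaConf

# Hypothetical model-config excerpt; key names and defaults mirror the YAML above.
model_cfg = OmegaConf.create(
    {
        'dist_ckpt_format': 'torch_dist',              # async/multiproc features target this format
        'dist_ckpt_load_on_device': True,
        'dist_ckpt_parallel_save': True,               # fully parallel save across ranks
        'dist_ckpt_parallel_load': False,              # off by default: may cost extra GPU memory
        'dist_ckpt_torch_dist_multiproc': 2,           # extra save processes per rank
        'dist_ckpt_assume_constant_structure': False,  # enable only for a static state dict structure
        'dist_ckpt_parallel_dist_opt': True,           # reshardable DistributedOptimizer checkpoints
    }
)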
nemo/collections/nlp/parts/megatron_trainer_builder.py (1 addition, 1 deletion)
@@ -90,7 +90,7 @@ def _training_strategy(self) -> Union[NLPDDPStrategy, NLPFSDPStrategy]:
            find_unused_parameters=False,
            nccl_communicator_config_path=self.cfg.model.get('nccl_communicator_config_path', None),
            sharp=self.cfg.model.get('sharp', False),
-            dist_ckpt_parallel_save=self.cfg.model.get('dist_ckpt_parallel_save', False),
+            dist_ckpt_parallel_save=self.cfg.model.get('dist_ckpt_parallel_dist_opt', True),
        )

    def _grad_scaler(self) -> GradScaler:
nemo/collections/nlp/parts/nlp_overrides.py (1 addition, 1 deletion)
@@ -1007,7 +1007,7 @@ def dummy():
        model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer)
    model.trainer.strategy.setup_environment()
    sharded_state_dict = model.sharded_state_dict()
-    checkpoint_io = DistributedCheckpointIO(model.cfg.get('dist_ckpt_format', 'zarr'))
+    checkpoint_io = DistributedCheckpointIO.from_config(model.cfg, async_save=False)
    checkpoint_io.save_checkpoint(sharded_state_dict, dist_ckpt_dir)

    if HAVE_MODELOPT and hasattr(model, "get_model_module_list"):
nemo/core/optim/mcore_optim.py (1 addition, 3 deletions)
@@ -58,9 +58,7 @@ def load_state_dict(self, state_dict):
    def sharded_state_dict(
        self, model_sharded_state_dict, optimizer_state_dict=None, is_loading=False, dist_ckpt_parallel_save=False
    ):
-        # TODO(@akoumparouli, @mikolajblaz): switch to sharding_type once support for fully_sharded_model_space merged in mcore.
-        # sharding_type = 'fully_sharded_model_space' if dist_ckpt_parallel_save else 'dp_zero_gather_scatter'
-        sharding_type = 'dp_zero_gather_scatter'
+        sharding_type = 'fully_sharded_model_space' if dist_ckpt_parallel_save else 'dp_zero_gather_scatter'
        return self.mcore_optimizer.sharded_state_dict(
            model_sharded_state_dict, is_loading=is_loading, sharding_type=sharding_type
        )
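For context on the re-enabled branch: `NLPDDPStrategy` (see `megatron_trainer_builder.py` above) reads `model.dist_ckpt_parallel_dist_opt` and forwards it here as `dist_ckpt_parallel_save`. A minimal sketch of that flow, with `mcore_opt` standing in for the optimizer wrapper defined in this file and `model_cfg` a hypothetical config object:

# Sketch: how the config flag selects the optimizer sharding type.
# 'fully_sharded_model_space' -> performant, reshardable checkpoints;
# 'dp_zero_gather_scatter'    -> fewer checkpoint files, DP-size-dependent layout.
parallel_dist_opt = model_cfg.get('dist_ckpt_parallel_dist_opt', True)
optim_sharded_sd = mcore_opt.sharded_state_dict(
    model_sharded_state_dict,
    is_loading=False,
    dist_ckpt_parallel_save=parallel_dist_opt,
)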
nemo/utils/callbacks/dist_ckpt_io.py (83 additions, 23 deletions)
@@ -32,16 +32,29 @@
    from megatron.core import dist_checkpointing
    from megatron.core.dist_checkpointing.dict_utils import extract_matching_values
    from megatron.core.dist_checkpointing.mapping import ShardedBase
+    from megatron.core.dist_checkpointing.serialization import (
+        get_default_load_sharded_strategy,
+        get_default_save_sharded_strategy,
+    )
    from megatron.core.dist_checkpointing.strategies import tensorstore
-
-    from nemo.utils.callbacks.torch_dist_async import AsyncCallsQueue, AsyncRequest, TorchDistAsyncSaveShardedStrategy
+    from megatron.core.dist_checkpointing.strategies.async_utils import AsyncCallsQueue, AsyncRequest
+    from megatron.core.dist_checkpointing.strategies.base import SaveShardedStrategy
+    from megatron.core.dist_checkpointing.strategies.fully_parallel import (
+        FullyParallelLoadStrategyWrapper,
+        FullyParallelSaveStrategyWrapper,
+    )
+    from megatron.core.dist_checkpointing.strategies.torch import TorchDistSaveShardedStrategy
+    from megatron.core.parallel_state import get_data_parallel_group

    HAVE_MEGATRON_CORE = True

-except (ImportError, ModuleNotFoundError) as IMPORT_ERROR_EXC:
+except (ImportError, ModuleNotFoundError) as e:

    HAVE_MEGATRON_CORE = False
-    IMPORT_ERROR = "megatron-core was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt."
+    IMPORT_ERROR = (
+        "megatron-core was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt."
+        f" Exact error: {e}"
+    )


@contextmanager
@@ -87,7 +100,7 @@ class AsyncFinalizableCheckpointIO(_WrappingCheckpointIO):

    def __init__(self, checkpoint_io: AsyncCompatibleCheckpointIO) -> None:
        if not HAVE_MEGATRON_CORE:
-            raise ImportError(IMPORT_ERROR) from IMPORT_ERROR_EXC
+            raise ImportError(IMPORT_ERROR)
        if not isinstance(checkpoint_io, AsyncCompatibleCheckpointIO):
            raise ValueError(f'Incompatible wrapped checkpoint_io type: {type(checkpoint_io)}')

@@ -177,22 +190,38 @@ class DistributedCheckpointIO(AsyncCompatibleCheckpointIO):
            always loads on device). Defaults to True.
        async_save (bool): whether to save asynchronously. Should be set to True if
            this class will be wrapped with AsyncFinalizableCheckpointIO.
+        torch_dist_multiproc (int, optional): number of extra processes per rank
+            used during ckpt save with PyTorch distributed format. Defaults to None,
+            which means using the MCore default (2).
+        parallel_save (bool): parallelizes the save across ranks. Defaults to True.
+        parallel_load (bool): parallelizes the load across ranks (followed by a
+            params all-gather). Defaults to False due to the extra memory usage.
    """

    def __init__(
        self,
        save_ckpt_format: str,
        load_directly_on_device: bool = True,
        async_save: bool = False,
+        torch_dist_multiproc: Optional[int] = None,
+        assume_constant_structure: bool = False,
+        parallel_save: bool = True,
+        parallel_load: bool = False,
    ):
        super().__init__()
        if not HAVE_MEGATRON_CORE:
-            raise ImportError(IMPORT_ERROR) from IMPORT_ERROR_EXC
+            raise ImportError(IMPORT_ERROR)

        self.save_ckpt_format = save_ckpt_format
        self.load_directly_on_device = load_directly_on_device
        self.async_save = async_save
-        self.save_sharded_strategy = self._determine_dist_ckpt_save_strategy()
+        self.torch_dist_multiproc = torch_dist_multiproc
+        self.assume_constant_structure = assume_constant_structure
+        self.parallel_save = parallel_save
+        self.parallel_load = parallel_load
+
+        self._save_sharded_strategy = None
+        self.validated_consistency = False

    @classmethod
    def from_config(cls, model_cfg: dict, async_save: bool = False):
@@ -208,6 +237,9 @@ def from_config(cls, model_cfg: dict, async_save: bool = False):
            save_ckpt_format=model_cfg.get('dist_ckpt_format', 'zarr'),
            load_directly_on_device=model_cfg.get('dist_ckpt_load_on_device', True),
            async_save=async_save,
+            torch_dist_multiproc=model_cfg.get('dist_ckpt_torch_dist_multiproc', None),
+            parallel_save=model_cfg.get('dist_ckpt_parallel_save', True),
+            parallel_load=model_cfg.get('dist_ckpt_parallel_load', False),
        )

    @_debug_time('DistributedCheckpointIO.save_checkpoint')
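A small usage sketch of the updated classmethod; the literal dict below is hypothetical, and absent keys fall back to the defaults shown above:

# Hypothetical: build checkpoint I/O straight from a model config.
checkpoint_io = DistributedCheckpointIO.from_config(
    {'dist_ckpt_format': 'torch_dist', 'dist_ckpt_torch_dist_multiproc': 4},
    async_save=True,  # caller opts in; only valid with the torch_dist format
)
assert checkpoint_io.parallel_save      # defaults to True
assert not checkpoint_io.parallel_load  # defaults to False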
@@ -224,16 +256,15 @@ def save_checkpoint(
        fs = get_filesystem(path)
        fs.makedirs(path, exist_ok=True)

-        dist_checkpointing.save(
-            sharded_state_dict=checkpoint, checkpoint_dir=path, sharded_strategy=self.save_sharded_strategy
-        )
+        validate_sharding_integrity = not (self.validated_consistency and self.assume_constant_structure)
+        self.validated_consistency = True
+        return dist_checkpointing.save(
+            sharded_state_dict=checkpoint,
+            checkpoint_dir=path,
+            sharded_strategy=self.save_sharded_strategy,
+            validate_access_integrity=validate_sharding_integrity,
+            async_sharded_save=self.async_save,
+        )
-        if not self.async_save:
-            return None
-        # NOTE: this logic will be simplified in MCore v0.7
-        assert self.save_sharded_strategy.async_request is not None
-        async_request = self.save_sharded_strategy.async_request
-        self.save_sharded_strategy.async_request = None
-        return async_request

    @_debug_time('DistributedCheckpointIO.load_checkpoint')
    def load_checkpoint(
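With this change `dist_checkpointing.save` itself returns the `AsyncRequest` when `async_sharded_save=True`, so NeMo no longer pulls it off the strategy afterwards. A rough sketch of how such a request is driven, using the `AsyncCallsQueue` imported above; the exact method names are an assumption based on MCore's async_utils and may differ between versions:

# Sketch: schedule an async save, keep training, finalize later.
queue = AsyncCallsQueue()
async_request = checkpoint_io.save_checkpoint(checkpoint, path)  # AsyncRequest when async_save=True
queue.schedule_async_request(async_request)  # assumed API: hands work to a background process
# ... training continues ...
queue.maybe_finalize_async_calls(blocking=True)  # assumed API: runs pending finalize fns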
@@ -267,6 +298,16 @@ def load_checkpoint(
        else:
            sharded_strategy = None

+        if self.parallel_load:
+            if sharded_strategy is None:
+                sharded_strategy = get_default_load_sharded_strategy(path)
+            sharded_strategy = FullyParallelLoadStrategyWrapper(
+                sharded_strategy, get_data_parallel_group(with_context_parallel=True)
+            )
+
+        if sharded_strategy is not None:
+            logging.info(f'Using {sharded_strategy} dist-ckpt load strategy.')
+
        if not strict:
            sharded_state_dict = self.adjust_non_strict_load(path, sharded_state_dict)

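The parallel-load path spreads the read across the data-parallel group and re-assembles full tensors with NCCL all-gathers. The same wrapper can be applied outside this class; a minimal sketch using the MCore pieces imported at the top of the file (`ckpt_dir` and `sharded_state_dict` are assumed to exist):

# Sketch: fully parallel load, mirroring the branch added above.
load_strategy = FullyParallelLoadStrategyWrapper(
    get_default_load_sharded_strategy(ckpt_dir),
    get_data_parallel_group(with_context_parallel=True),
)
state_dict = dist_checkpointing.load(
    sharded_state_dict=sharded_state_dict,
    checkpoint_dir=ckpt_dir,
    sharded_strategy=load_strategy,
)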
@@ -309,17 +350,36 @@ def remove_checkpoint(self, path: _PATH) -> None:
        """
        shutil.rmtree(path, ignore_errors=True)

+    @property
+    def save_sharded_strategy(self) -> 'SaveShardedStrategy':
+        if self._save_sharded_strategy is None:
+            self._save_sharded_strategy = self._determine_dist_ckpt_save_strategy()
+        return self._save_sharded_strategy
+
    def _determine_dist_ckpt_save_strategy(self):
        """Determine the saving strategy based on constructor args.

-        If self.async_save is True instantiates an async PyT Dist strategy,
-        otherwise relies on MCore to create a proper strategy based on ckpt format.
+        Relies on the default MCore strategy unless extra PyTorch Distributed
+        format arguments are passed in the config. For a fully parallel save,
+        a parallelization wrapper is applied on top.
        """
-        save_strategy = (self.save_ckpt_format, 1)
-        if self.async_save:
-            if save_strategy[0] != 'torch_dist':
-                raise ValueError('Async dist-ckpt save supported only for torch_dist format')
-            save_strategy = TorchDistAsyncSaveShardedStrategy('torch_dist', 1)
+        if self.async_save and self.save_ckpt_format != 'torch_dist':
+            raise ValueError('Async dist-ckpt save supported only for torch_dist format')
+
+        torch_dist_kwargs = {} if self.torch_dist_multiproc is None else dict(thread_count=self.torch_dist_multiproc)
+        if self.save_ckpt_format == 'torch_dist' and torch_dist_kwargs:
+            save_strategy = TorchDistSaveShardedStrategy(self.save_ckpt_format, 1, **torch_dist_kwargs)
+        else:
+            save_strategy = get_default_save_sharded_strategy(self.save_ckpt_format, 1)
+
+        # MCore v0.8 introduces the `use_cached_ckpt_structure` attribute
+        if hasattr(save_strategy, 'use_cached_ckpt_structure'):
+            save_strategy.use_cached_ckpt_structure = self.assume_constant_structure
+
+        if self.parallel_save:
+            save_strategy = FullyParallelSaveStrategyWrapper(
+                save_strategy, get_data_parallel_group(with_context_parallel=True), self.assume_constant_structure
+            )

        logging.info(f'Using {save_strategy} dist-ckpt save strategy.')
        return save_strategy
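Since the save strategy is now resolved lazily through the `save_sharded_strategy` property, misconfiguration surfaces at the first save rather than in `__init__`. A few illustrative constructor combinations (a sketch only; no MCore parallel state is touched until the first save):

# Default MCore strategy for the format, no parallel wrapper.
plain_io = DistributedCheckpointIO('zarr', parallel_save=False)

# Explicit TorchDistSaveShardedStrategy with thread_count=4, then wrapped
# in FullyParallelSaveStrategyWrapper (parallel_save defaults to True).
fast_io = DistributedCheckpointIO('torch_dist', torch_dist_multiproc=4)

# Invalid: async save requires torch_dist; raises ValueError at the first
# save, when the lazy property builds the strategy.
bad_io = DistributedCheckpointIO('zarr', async_save=True)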