diff --git a/Dockerfile.ci b/Dockerfile.ci
index b376aacd0bfe..ac36e6429475 100644
--- a/Dockerfile.ci
+++ b/Dockerfile.ci
@@ -34,7 +34,7 @@ WORKDIR /workspace
 # Install NeMo requirements
 ARG TE_TAG=bfe21c3d68b0a9951e5716fb520045db53419c5e
 ARG MODELOPT_VERSION=0.13.0
-ARG MCORE_TAG=02871b4df8c69fac687ab6676c4246e936ce92d0
+ARG MCORE_TAG=0ab8dd4c7520408683fdb9f8ac119eff7d38fc0e
 ARG APEX_TAG=810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c
 RUN \
   --mount=type=bind,source=requirements,target=requirements \
diff --git a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml
index 98bf7d448845..ac1f4a37b232 100755
--- a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml
+++ b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml
@@ -177,6 +177,10 @@ model:
   dist_ckpt_format: 'zarr' # Set to 'torch_dist' to use PyTorch distributed checkpoint format.
   dist_ckpt_load_on_device: True # whether to load checkpoint weights directly on GPU or to CPU
   dist_ckpt_parallel_save: False # if true, each worker will write its own part of the dist checkpoint
+  dist_ckpt_parallel_load: False # if true, each worker will load part of the dist checkpoint and exchange with NCCL. Might use some extra GPU memory
+  dist_ckpt_torch_dist_multiproc: 2 # number of extra processes per rank used during ckpt save with PyTorch distributed format
+  dist_ckpt_assume_constant_structure: False # set to True only if the state dict structure doesn't change within a single job. Allows caching some computation across checkpoint saves.
+  dist_ckpt_parallel_dist_opt: True # parallel save/load of a DistributedOptimizer. 'True' allows performant save and reshardable checkpoints. Set to 'False' only in order to minimize the number of checkpoint files.

   ## Activation Checkpointing
   # NeMo Megatron supports 'selective' activation checkpointing where only the memory intensive part of attention is checkpointed.
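# Illustrative sketch (not part of the patch): how the new model.dist_ckpt_* keys above are
# expected to reach DistributedCheckpointIO through the from_config() helper shown further down
# in this diff. The key names mirror the YAML above; the concrete values and async_save=True are
# assumptions for illustration, and async saving requires the 'torch_dist' format per the code below.
from nemo.utils.callbacks.dist_ckpt_io import DistributedCheckpointIO

model_cfg = {
    'dist_ckpt_format': 'torch_dist',     # async save is only supported for torch_dist
    'dist_ckpt_load_on_device': True,
    'dist_ckpt_torch_dist_multiproc': 2,  # extra save processes per rank
    'dist_ckpt_parallel_save': True,      # fully parallel save across data-parallel ranks
    'dist_ckpt_parallel_load': False,     # parallel load trades extra GPU memory for speed
}
checkpoint_io = DistributedCheckpointIO.from_config(model_cfg, async_save=True)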
diff --git a/nemo/collections/nlp/parts/megatron_trainer_builder.py b/nemo/collections/nlp/parts/megatron_trainer_builder.py
index 194168008dc4..f4276fd1b8f9 100644
--- a/nemo/collections/nlp/parts/megatron_trainer_builder.py
+++ b/nemo/collections/nlp/parts/megatron_trainer_builder.py
@@ -90,7 +90,7 @@ def _training_strategy(self) -> Union[NLPDDPStrategy, NLPFSDPStrategy]:
             find_unused_parameters=False,
             nccl_communicator_config_path=self.cfg.model.get('nccl_communicator_config_path', None),
             sharp=self.cfg.model.get('sharp', False),
-            dist_ckpt_parallel_save=self.cfg.model.get('dist_ckpt_parallel_save', False),
+            dist_ckpt_parallel_save=self.cfg.model.get('dist_ckpt_parallel_dist_opt', True),
         )

     def _grad_scaler(self) -> GradScaler:
diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py
index 43c330f257ec..ad220aaa3539 100644
--- a/nemo/collections/nlp/parts/nlp_overrides.py
+++ b/nemo/collections/nlp/parts/nlp_overrides.py
@@ -1007,7 +1007,7 @@ def dummy():
                 model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer)
             model.trainer.strategy.setup_environment()
             sharded_state_dict = model.sharded_state_dict()
-            checkpoint_io = DistributedCheckpointIO(model.cfg.get('dist_ckpt_format', 'zarr'))
+            checkpoint_io = DistributedCheckpointIO.from_config(model.cfg, async_save=False)
             checkpoint_io.save_checkpoint(sharded_state_dict, dist_ckpt_dir)

         if HAVE_MODELOPT and hasattr(model, "get_model_module_list"):
diff --git a/nemo/core/optim/mcore_optim.py b/nemo/core/optim/mcore_optim.py
index 234680f49249..9feb70cc90a1 100644
--- a/nemo/core/optim/mcore_optim.py
+++ b/nemo/core/optim/mcore_optim.py
@@ -58,9 +58,7 @@ def load_state_dict(self, state_dict):
     def sharded_state_dict(
         self, model_sharded_state_dict, optimizer_state_dict=None, is_loading=False, dist_ckpt_parallel_save=False
     ):
-        # TODO(@akoumparouli, @mikolajblaz): switch to sharding_type once support for fully_sharded_model_space merged in mcore.
-        # sharding_type = 'fully_sharded_model_space' if dist_ckpt_parallel_save else 'dp_zero_gather_scatter'
-        sharding_type = 'dp_zero_gather_scatter'
+        sharding_type = 'fully_sharded_model_space' if dist_ckpt_parallel_save else 'dp_zero_gather_scatter'
         return self.mcore_optimizer.sharded_state_dict(
             model_sharded_state_dict, is_loading=is_loading, sharding_type=sharding_type
         )
diff --git a/nemo/utils/callbacks/dist_ckpt_io.py b/nemo/utils/callbacks/dist_ckpt_io.py
index 31ab0c84dd3a..65eea827e851 100644
--- a/nemo/utils/callbacks/dist_ckpt_io.py
+++ b/nemo/utils/callbacks/dist_ckpt_io.py
@@ -32,16 +32,29 @@
     from megatron.core import dist_checkpointing
     from megatron.core.dist_checkpointing.dict_utils import extract_matching_values
     from megatron.core.dist_checkpointing.mapping import ShardedBase
+    from megatron.core.dist_checkpointing.serialization import (
+        get_default_load_sharded_strategy,
+        get_default_save_sharded_strategy,
+    )
     from megatron.core.dist_checkpointing.strategies import tensorstore
-
-    from nemo.utils.callbacks.torch_dist_async import AsyncCallsQueue, AsyncRequest, TorchDistAsyncSaveShardedStrategy
+    from megatron.core.dist_checkpointing.strategies.async_utils import AsyncCallsQueue, AsyncRequest
+    from megatron.core.dist_checkpointing.strategies.base import SaveShardedStrategy
+    from megatron.core.dist_checkpointing.strategies.fully_parallel import (
+        FullyParallelLoadStrategyWrapper,
+        FullyParallelSaveStrategyWrapper,
+    )
+    from megatron.core.dist_checkpointing.strategies.torch import TorchDistSaveShardedStrategy
+    from megatron.core.parallel_state import get_data_parallel_group

     HAVE_MEGATRON_CORE = True

-except (ImportError, ModuleNotFoundError) as IMPORT_ERROR_EXC:
+except (ImportError, ModuleNotFoundError) as e:
     HAVE_MEGATRON_CORE = False
-    IMPORT_ERROR = "megatron-core was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt."
+    IMPORT_ERROR = (
+        "megatron-core was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt."
+        f" Exact error: {e}"
+    )


 @contextmanager
@@ -87,7 +100,7 @@ class AsyncFinalizableCheckpointIO(_WrappingCheckpointIO):
     def __init__(self, checkpoint_io: AsyncCompatibleCheckpointIO) -> None:
         if not HAVE_MEGATRON_CORE:
-            raise ImportError(IMPORT_ERROR) from IMPORT_ERROR_EXC
+            raise ImportError(IMPORT_ERROR)
         if not isinstance(checkpoint_io, AsyncCompatibleCheckpointIO):
             raise ValueError(f'Incompatible wrapped checkpoint_io type: {type(checkpoint_io)}')

@@ -177,6 +190,12 @@ class DistributedCheckpointIO(AsyncCompatibleCheckpointIO):
             always loads on device). Defaults to True.
         async_save (bool): whether to save asynchronously. Should be set to True if this class
             will be wrapped with AsyncFinalizableCheckpointIO.
+        torch_dist_multiproc (int, optional): number of extra processes per rank
+            used during ckpt save with the PyTorch distributed format. Defaults to None,
+            which means the MCore default (2) is used.
+        parallel_save (bool): parallelizes the save across ranks. Defaults to True.
+        parallel_load (bool): parallelizes the load across ranks (followed by an all-gather
+            of the params). Defaults to False because it requires some extra GPU memory.
""" def __init__( @@ -184,15 +203,25 @@ def __init__( save_ckpt_format: str, load_directly_on_device: bool = True, async_save: bool = False, + torch_dist_multiproc: Optional[int] = None, + assume_constant_structure: bool = False, + parallel_save: bool = True, + parallel_load: bool = False, ): super().__init__() if not HAVE_MEGATRON_CORE: - raise ImportError(IMPORT_ERROR) from IMPORT_ERROR_EXC + raise ImportError(IMPORT_ERROR) self.save_ckpt_format = save_ckpt_format self.load_directly_on_device = load_directly_on_device self.async_save = async_save - self.save_sharded_strategy = self._determine_dist_ckpt_save_strategy() + self.torch_dist_multiproc = torch_dist_multiproc + self.assume_constant_structure = assume_constant_structure + self.parallel_save = parallel_save + self.parallel_load = parallel_load + + self._save_sharded_strategy = None + self.validated_consistency = False @classmethod def from_config(cls, model_cfg: dict, async_save: bool = False): @@ -208,6 +237,9 @@ def from_config(cls, model_cfg: dict, async_save: bool = False): save_ckpt_format=model_cfg.get('dist_ckpt_format', 'zarr'), load_directly_on_device=model_cfg.get('dist_ckpt_load_on_device', True), async_save=async_save, + torch_dist_multiproc=model_cfg.get('dist_ckpt_torch_dist_multiproc', None), + parallel_save=model_cfg.get('dist_ckpt_parallel_save', True), + parallel_load=model_cfg.get('dist_ckpt_parallel_load', False), ) @_debug_time('DistributedCheckpointIO.save_checkpoint') @@ -224,16 +256,15 @@ def save_checkpoint( fs = get_filesystem(path) fs.makedirs(path, exist_ok=True) - dist_checkpointing.save( - sharded_state_dict=checkpoint, checkpoint_dir=path, sharded_strategy=self.save_sharded_strategy + validate_sharding_integrity = not (self.validated_consistency and self.assume_constant_structure) + self.validated_consistency = True + return dist_checkpointing.save( + sharded_state_dict=checkpoint, + checkpoint_dir=path, + sharded_strategy=self.save_sharded_strategy, + validate_access_integrity=validate_sharding_integrity, + async_sharded_save=self.async_save, ) - if not self.async_save: - return None - # NOTE: this logic will be simplified in MCore v0.7 - assert self.save_sharded_strategy.async_request is not None - async_request = self.save_sharded_strategy.async_request - self.save_sharded_strategy.async_request = None - return async_request @_debug_time('DistributedCheckpointIO.load_checkpoint') def load_checkpoint( @@ -267,6 +298,16 @@ def load_checkpoint( else: sharded_strategy = None + if self.parallel_load: + if sharded_strategy is None: + sharded_strategy = get_default_load_sharded_strategy(path) + sharded_strategy = FullyParallelLoadStrategyWrapper( + sharded_strategy, get_data_parallel_group(with_context_parallel=True) + ) + + if sharded_strategy is not None: + logging.info(f'Using {sharded_strategy} dist-ckpt load strategy.') + if not strict: sharded_state_dict = self.adjust_non_strict_load(path, sharded_state_dict) @@ -309,17 +350,36 @@ def remove_checkpoint(self, path: _PATH) -> None: """ shutil.rmtree(path, ignore_errors=True) + @property + def save_sharded_strategy(self) -> 'SaveShardedStrategy': + if self._save_sharded_strategy is None: + self._save_sharded_strategy = self._determine_dist_ckpt_save_strategy() + return self._save_sharded_strategy + def _determine_dist_ckpt_save_strategy(self): """Determine the saving strategy based on constructor args. - If self.async_save is True instantiates an async PyT Dist strategy, - otherwise relies on MCore to create a proper strategy based on ckpt format. 
+        Relies on the default MCore strategy, unless extra PyTorch Distributed format arguments
+        are passed in the config or a fully parallel save is requested, in which case
+        a parallelization wrapper is applied.
         """
-        save_strategy = (self.save_ckpt_format, 1)
-        if self.async_save:
-            if save_strategy[0] != 'torch_dist':
-                raise ValueError('Async dist-ckpt save supported only for torch_dist format')
-            save_strategy = TorchDistAsyncSaveShardedStrategy('torch_dist', 1)
+        if self.async_save and self.save_ckpt_format != 'torch_dist':
+            raise ValueError('Async dist-ckpt save supported only for torch_dist format')
+
+        torch_dist_kwargs = {} if self.torch_dist_multiproc is None else dict(thread_count=self.torch_dist_multiproc)
+        if self.save_ckpt_format == 'torch_dist' and torch_dist_kwargs:
+            save_strategy = TorchDistSaveShardedStrategy(self.save_ckpt_format, 1, **torch_dist_kwargs)
+        else:
+            save_strategy = get_default_save_sharded_strategy(self.save_ckpt_format, 1)
+
+        # MCore v0.8 introduces the `use_cached_ckpt_structure` attribute
+        if hasattr(save_strategy, 'use_cached_ckpt_structure'):
+            save_strategy.use_cached_ckpt_structure = self.assume_constant_structure
+
+        if self.parallel_save:
+            save_strategy = FullyParallelSaveStrategyWrapper(
+                save_strategy, get_data_parallel_group(with_context_parallel=True), self.assume_constant_structure
+            )

         logging.info(f'Using {save_strategy} dist-ckpt save strategy.')
         return save_strategy
diff --git a/nemo/utils/callbacks/torch_dist_async.py b/nemo/utils/callbacks/torch_dist_async.py
deleted file mode 100644
index 1cd226af9cdb..000000000000
--- a/nemo/utils/callbacks/torch_dist_async.py
+++ /dev/null
@@ -1,298 +0,0 @@
-# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from collections import deque
-from pathlib import Path
-from time import time
-from typing import Callable, List, NamedTuple, Optional, Tuple
-
-import torch
-from megatron.core.dist_checkpointing.mapping import ShardedStateDict
-from megatron.core.dist_checkpointing.strategies.filesystem_async import FileSystemWriterAsync
-from megatron.core.dist_checkpointing.strategies.state_dict_saver import (
-    save_state_dict_async_finalize,
-    save_state_dict_async_plan,
-)
-from megatron.core.dist_checkpointing.strategies.torch import (
-    MCoreSavePlanner,
-    TorchDistSaveShardedStrategy,
-    _replace_state_dict_keys_with_sharded_keys,
-    mcore_to_pyt_state_dict,
-)
-from torch import multiprocessing as mp
-
-from nemo.utils import logging
-
-
-class TorchDistAsyncSaveShardedStrategy(TorchDistSaveShardedStrategy):
-    """Async save strategy for the PyT Distributed format.
-
-    NOTE: this class will be removed and replaced with an MCore version
-    """
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.async_request = None
-
-    def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
-        """Translates MCore ShardedTensors to PyT ShardedTensors and saves in PyT Distributed format.
-
-        Args:
-            sharded_state_dict (ShardedStateDict): sharded state dict to save
-            checkpoint_dir (Path): checkpoint directory
-
-        Returns: None
-        """
-        # Translate the state dict
-        (
-            sharded_state_dict,
-            flat_mapping,
-            rename_mapping,
-        ) = _replace_state_dict_keys_with_sharded_keys(sharded_state_dict, self.keep_only_main_replica)
-        pyt_state_dict = mcore_to_pyt_state_dict(sharded_state_dict, False)
-        # Use PyT saving mechanism
-        writer = FileSystemWriterAsync(checkpoint_dir, thread_count=self.thread_count)
-
-        save_state_dict_ret = save_state_dict_async_plan(
-            pyt_state_dict,
-            writer,
-            None,
-            planner=MCoreSavePlanner(),
-        )
-        self.async_request = self._get_save_and_finalize_callbacks(writer, save_state_dict_ret)
-        return self.async_request
-
-    def _get_save_and_finalize_callbacks(self, writer, save_state_dict_ret):
-        save_fn_args = writer.get_save_function_and_args()
-        if save_fn_args is None:  # this check can be removed with MCore v0.7
-            save_fn_args = None, ()
-        save_fn, save_args = save_fn_args
-
-        def finalize_fn():
-            save_state_dict_async_finalize(*save_state_dict_ret)
-            torch.distributed.barrier()
-
-        return AsyncRequest(save_fn, save_args, [finalize_fn])
-
-
-class AsyncRequest(NamedTuple):
-    """Represents an async request that needs to be scheduled for execution.
-
-    NOTE: this class will be removed and replaced with an MCore version
-
-    Args:
-        async_fn (Callable, optional): async function to call. None represents noop.
-        async_fn_args (Tuple): args to pass to `async_fn`.
-        finalize_fns (List[Callable]): list of functions to call to finalize the request.
-            These functions will be called synchronously after `async_fn` is done
-            *on all ranks*.
-    """
-
-    async_fn: Optional[Callable]
-    async_fn_args: Tuple
-    finalize_fns: List[Callable]
-    is_frozen: bool = False
-
-    def add_finalize_fn(self, fn: Callable) -> None:
-        """Adds a new finalize function to the request.
-
-        Args:
-            fn (Callable): function to add to the async request. This function
-                will be called *after* existing finalization functions.
-
-        Returns:
-            None
-        """
-        if self.is_frozen:
-            raise RuntimeError('Cannot add finalization functions to a frozen AsyncRequest')
-        self.finalize_fns.append(fn)
-
-    def execute_sync(self) -> None:
-        """Helper to synchronously execute the request.
-
-        This logic is equivalent to what should happen in case of the async call.
-        """
-        if self.async_fn is not None:
-            self.async_fn(*self.async_fn_args)
-        torch.distributed.barrier()
-        for finalize_fn in self.finalize_fns:
-            finalize_fn()
-
-    def freeze(self) -> 'AsyncRequest':
-        """Freezes the async request, disallowing adding new finalization functions.
-
-        Returns:
-            AsyncRequest: new async request with all same fields except for the
-                `is_frozen` flag.
-        """
-        return self._replace(is_frozen=True)
-
-
-class DistributedAsyncCaller:
-    """Wrapper around mp.Process that ensures correct semantic of distributed finalization.
-
-    NOTE: this class will be removed and replaced with an MCore version
-
-    Starts process asynchronously and allows checking if all processes on all ranks are done.
-    """
-
-    def __init__(self):
-        self.process: Optional[mp.Process] = None
-        self.start_time: Optional[float] = None
-
-    def schedule_async_call(
-        self,
-        async_fn: Optional[Callable],
-        save_args: Tuple,
-    ) -> None:
-        """Spawn a process with `async_fn` as the target.
-
-        This method must be called on all ranks.
-
-        Args:
-            async_fn (Callable, optional): async function to call. If None,
-                no process will be started.
-            save_args (Tuple): async function args.
- """ - if async_fn is None: - return # nothing to do - torch.cuda.synchronize() - ctx = mp.get_context('fork') - self.start_time = time() - self.process = ctx.Process( - target=async_fn, - args=save_args, - ) - self.process.start() - - def is_current_async_call_done(self, blocking=False) -> bool: - """Check if async save is finished on all ranks. - - For semantic correctness, requires rank synchronization in each check. - This method must be called on all ranks. - - Args: - blocking (bool, optional): if True, will wait until the call is done - on all ranks. Otherwise, returns immediately if at least one rank - is still active. Defaults to False. - - Returns: - bool: True if all ranks are done (immediately of after active wait - if `blocking` is True), False if at least one rank is still active. - """ - # The following takes the same overhead as torch.distributed.barrier (single integer all-reduce) - is_alive = int(self.process.is_alive()) if self.process is not None else 0 - ten = torch.tensor([is_alive], dtype=torch.int, device=torch.cuda.current_device()) - logging.debug(f"[rank {torch.distributed.get_rank()}] DistributedAsyncCaller is_alive:{is_alive}") - torch.distributed.all_reduce(ten) - if ten[0] > 0 and not blocking: - return False - else: - if self.process is not None: - logging.debug(f"rank: {torch.distributed.get_rank()}, joining self.process") - self.process.join() - self.process = None - - logging.debug( - f"DistributedAsyncCaller: Async process join finished after {time() - self.start_time:.2f}s from forking" - ) - self.start_time = None - return True - - -class _ActiveAsyncRequest(NamedTuple): - """Helper to represent an active async call. - - NOTE: this class will be removed and replaced with an MCore version - - Args: - idx (int): index of the call (starting from 0) - async_caller (DistributedAsyncCaller): async caller instance that represents - the async process handling the async request - async_request (AsyncRequest): async request that is being called - """ - - idx: int - async_caller: DistributedAsyncCaller - async_request: AsyncRequest - - -class AsyncCallsQueue: - """Manages a queue of async calls. - - NOTE: this class will be removed and replaced with an MCore version - - Allows adding a new async call with `schedule_async_request` and finalizing - active calls with `maybe_finalize_async_calls`. - """ - - def __init__(self): - self.async_calls: deque[_ActiveAsyncRequest] = deque([]) - self.call_idx: int = -1 - - def schedule_async_request(self, async_request: AsyncRequest) -> int: - """Start a new async call and add it to a queue of active async calls. - - This method must be called on all ranks. - - Args: - async_request (AsyncRequest): async request to start. - - Returns: - int: index of the async call that was started. - This can help the user keep track of the async calls. - """ - self.call_idx += 1 - async_caller = DistributedAsyncCaller() - async_request = async_request.freeze() - async_caller.schedule_async_call(async_request.async_fn, async_request.async_fn_args) - self.async_calls.append(_ActiveAsyncRequest(self.call_idx, async_caller, async_request)) - return self.call_idx - - def maybe_finalize_async_calls(self, blocking=False) -> List[int]: - """Finalizes all available calls. - - This method must be called on all ranks. - - Args: - blocking (bool, optional): if True, will wait until all active requests - are done. Otherwise, finalizes only the async request that already - finished. Defaults to False. 
-        Returns:
-            List[int]: list of indices (as returned by `schedule_async_request`)
-                of async calls that have been successfully finalized.
-        """
-        call_idx_finalized = []
-        while self.async_calls:
-            next_async_done = self.async_calls[0].async_caller.is_current_async_call_done(blocking)
-            if not next_async_done:
-                break
-            call_idx, _, async_request = self.async_calls.popleft()
-            for finalize_fn in async_request.finalize_fns:
-                finalize_fn()
-            ten = torch.tensor([call_idx], dtype=torch.int, device=torch.cuda.current_device())
-            torch.distributed.all_reduce(ten, op=torch.distributed.ReduceOp.MAX)
-            assert (
-                ten.item() == call_idx
-            ), 'Unmatched async calls. That probably means not all ranks are participating in async finalization'
-            call_idx_finalized.append(call_idx)
-        return call_idx_finalized
-
-    def get_num_unfinalized_calls(self):
-        """Get the number of active async calls."""
-        return len(self.async_calls)
-
-    def close(self):
-        """Finalize all calls upon closing."""
-        self.maybe_finalize_async_calls(blocking=True)
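# Illustrative sketch (not part of the patch): the NeMo-local async machinery deleted above is
# replaced by the MCore equivalents imported in dist_ckpt_io.py, so the calling pattern stays the
# same while the classes move to megatron.core. Assumes torch.distributed is initialized, and that
# `checkpoint_io`, `sharded_state_dict` and `ckpt_dir` (hypothetical names) are a
# DistributedCheckpointIO configured with async_save=True, a model sharded state dict, and a
# target checkpoint directory, as in the hunks above.
from megatron.core.dist_checkpointing.strategies.async_utils import AsyncCallsQueue

async_queue = AsyncCallsQueue()

# save_checkpoint() now returns the AsyncRequest produced by
# dist_checkpointing.save(..., async_sharded_save=True) directly, instead of stashing it on the
# removed TorchDistAsyncSaveShardedStrategy.
async_request = checkpoint_io.save_checkpoint(sharded_state_dict, ckpt_dir)
async_queue.schedule_async_request(async_request)

# Later (e.g. after a training step), finalize any saves that have already finished, without
# blocking; this is roughly the bookkeeping AsyncFinalizableCheckpointIO performs for the trainer.
async_queue.maybe_finalize_async_calls(blocking=False)

# On teardown, block until every outstanding save is finalized on all ranks.
async_queue.close()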