
Commit

Merge branch 'main' into dpykhtar/data_sampler_fix
dimapihtar authored Jul 12, 2024
2 parents b4c54f6 + 599b60f commit 8521aa1
Showing 7 changed files with 36 additions and 40 deletions.
6 changes: 3 additions & 3 deletions Dockerfile.ci
@@ -34,7 +34,7 @@ WORKDIR /workspace
# Install NeMo requirements
ARG TE_TAG=7d576ed25266a17a7b651f2c12e8498f67e0baea
ARG MODELOPT_VERSION=0.13.0
ARG MCORE_TAG=de1b7c223303f6ba21e0540f27361334116efcbc
ARG MCORE_TAG=c0164bcfd4f8213a10a6b1e47ef80721a68b4fb6
ARG APEX_TAG=810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c
RUN \
--mount=type=bind,source=requirements,target=requirements \
@@ -69,14 +69,14 @@ git clone https://github.com/state-spaces/mamba.git && \
git checkout v2.0.3 && \
python setup.py install && \
cd .. && \
rm -rf mamba
rm -rf mamba

git clone https://github.com/Dao-AILab/causal-conv1d && \
cd causal-conv1d && \
git checkout v1.2.2.post1 && \
python setup.py install && \
cd .. && \
rm -rf causal-conv1d
rm -rf causal-conv1d

EOF

@@ -182,6 +182,7 @@ model:
dist_ckpt_torch_dist_multiproc: 2 # number of extra processes per rank used during ckpt save with PyTorch distributed format
dist_ckpt_assume_constant_structure: False # set to True only if the state dict structure doesn't change within a single job. Allows caching some computation across checkpoint saves.
dist_ckpt_parallel_dist_opt: True # parallel save/load of a DistributedOptimizer. 'True' allows performant save and reshardable checkpoints. Set to 'False' only in order to minimize the number of checkpoint files.
dist_ckpt_load_strictness: null # defines checkpoint keys mismatch behavior (only during dist-ckpt load). Choices: assume_ok_unexpected (default - try loading without any check), log_all (log mismatches), raise_all (raise mismatches)

## Activation Checkpointing
# NeMo Megatron supports 'selective' activation checkpointing where only the memory intensive part of attention is checkpointed.
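
As an illustrative sketch (not part of this commit): the new dist_ckpt_load_strictness key is read by DistributedCheckpointIO.from_config in the Python diff below; the dictionary standing in for the parsed model config is an assumption.

from nemo.utils.callbacks.dist_ckpt_io import DistributedCheckpointIO

# Assumed stand-in for the parsed model config; in a real run these keys come
# from the YAML section above.
model_cfg = {
    'dist_ckpt_format': 'torch_dist',
    'dist_ckpt_load_on_device': True,
    'dist_ckpt_load_strictness': 'log_all',  # null keeps the default assume_ok_unexpected behavior
}

# from_config() forwards the value as the constructor's load_strictness argument.
checkpoint_io = DistributedCheckpointIO.from_config(model_cfg, async_save=False)
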
50 changes: 22 additions & 28 deletions nemo/utils/callbacks/dist_ckpt_io.py
@@ -17,7 +17,7 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
from time import time
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union

import pytorch_lightning as pl
from lightning_fabric.plugins import CheckpointIO
@@ -44,6 +44,7 @@
FullyParallelSaveStrategyWrapper,
)
from megatron.core.dist_checkpointing.strategies.torch import TorchDistSaveShardedStrategy
from megatron.core.dist_checkpointing.validation import StrictHandling
from megatron.core.parallel_state import get_data_parallel_group

HAVE_MEGATRON_CORE = True
@@ -188,6 +189,9 @@ class DistributedCheckpointIO(AsyncCompatibleCheckpointIO):
load_directly_on_device (bool, optional): if True, loads the weights directly
on GPU. Has effect only for `zarr` based checkpoints (PyT Distributed
always loads on device). Defaults to True.
load_strictness (StrictHandling, optional): defines loading strictness.
If not None, overwrites the `strict` flag passed to `load_checkpoint`.
Defaults to None.
async_save (bool): whether to save asynchronously. Should be set to True if
this class will be wrapped with AsyncFinalizableCheckpointIO.
torch_dist_multiproc (int, optional): number of extra processes per rank
@@ -202,6 +206,7 @@ def __init__(
self,
save_ckpt_format: str,
load_directly_on_device: bool = True,
load_strictness: Optional['StrictHandling'] = None,
async_save: bool = False,
torch_dist_multiproc: Optional[int] = None,
assume_constant_structure: bool = False,
@@ -215,6 +220,7 @@ def __init__(

self.save_ckpt_format = save_ckpt_format
self.load_directly_on_device = load_directly_on_device
self.load_strictness = load_strictness
self.async_save = async_save
self.torch_dist_multiproc = torch_dist_multiproc
self.assume_constant_structure = assume_constant_structure
@@ -238,6 +244,7 @@ def from_config(cls, model_cfg: dict, async_save: bool = False):
return cls(
save_ckpt_format=model_cfg.get('dist_ckpt_format', 'zarr'),
load_directly_on_device=model_cfg.get('dist_ckpt_load_on_device', True),
load_strictness=model_cfg.get('dist_ckpt_load_strictness', None),
async_save=async_save,
torch_dist_multiproc=model_cfg.get('dist_ckpt_torch_dist_multiproc', None),
parallel_save=model_cfg.get('dist_ckpt_parallel_save', False),
@@ -275,7 +282,7 @@ def load_checkpoint(
path: _PATH,
map_location: Optional[Any] = None,
sharded_state_dict: Dict[str, Any] = None,
strict: Optional[bool] = True,
strict: Union[None, bool, 'StrictHandling'] = None,
validate_access_integrity: Optional[bool] = True,
) -> Dict[str, Any]:
"""Loads a distributed checkpoint.
@@ -287,6 +294,10 @@ def load_checkpoint(
defines the loading procedure for the distributed checkpoint.
Defaults to None to comply with the CheckpointIO interface,
but it's a required argument.
strict (bool, StrictHandling, optional): adjust load strictness. bool value
is translated to StrictHandling instance. Gets overwritten by
`self.load_strictness`. Defaults to None. If `self.load_strictness`
is also None, strict becomes StrictHandling.ASSUME_OK_UNEXPECTED.
Returns:
Dict[str, Any]: loaded checkpoint.
@@ -311,40 +322,23 @@ def load_checkpoint(
if sharded_strategy is not None:
logging.info(f'Using {sharded_strategy} dist-ckpt load strategy.')

if not strict:
sharded_state_dict = self.adjust_non_strict_load(path, sharded_state_dict)
if isinstance(strict, bool):
strict = StrictHandling.ASSUME_OK_UNEXPECTED if strict else StrictHandling.LOG_ALL
if self.load_strictness is not None:
# Overwrites function argument
strict = self.load_strictness
if strict is None:
# Default behavior
strict = StrictHandling.ASSUME_OK_UNEXPECTED

return dist_checkpointing.load(
sharded_state_dict=sharded_state_dict,
checkpoint_dir=path,
sharded_strategy=sharded_strategy,
validate_access_integrity=validate_access_integrity,
strict=strict,
)

def adjust_non_strict_load(self, path: _PATH, sharded_state_dict: Dict[str, Any]):
ckpt_sharded_metadata = dist_checkpointing.load_tensors_metadata(path)
loaded_keys = []
missing_keys = []
unexpected_keys = []

def should_remove_missing_sharded_base(x: Any):
if isinstance(x, ShardedBase):
if x.key in ckpt_sharded_metadata:
loaded_keys.append(x.key)
return False
else:
unexpected_keys.append(x.key)
return True
return False

_, sharded_state_dict = extract_matching_values(sharded_state_dict, should_remove_missing_sharded_base)
logging.info(f'The following keys are not in the checkpoint and will not be loaded: {unexpected_keys}')

# TODO: compute missing_keys by:
# 1. all_gather_object of loaded_keys
# 2. missing_keys = ckpt_sharded_metadata.keys() - loaded_keys
return sharded_state_dict

@_debug_time('DistributedCheckpointIO.remove_checkpoint')
def remove_checkpoint(self, path: _PATH) -> None:
"""Remove a distributed checkpoint.
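
A short usage sketch of the new strictness handling (illustrative only; the checkpoint path and the sharded_state_dict variable are assumed, while the class, arguments, and StrictHandling values come from the diff above):

from megatron.core.dist_checkpointing.validation import StrictHandling
from nemo.utils.callbacks.dist_ckpt_io import DistributedCheckpointIO

# load_strictness overrides any `strict` value passed to load_checkpoint();
# with both left unset the default is StrictHandling.ASSUME_OK_UNEXPECTED.
checkpoint_io = DistributedCheckpointIO(
    save_ckpt_format='torch_dist',
    load_strictness=StrictHandling.RAISE_ALL,  # raise on any key mismatch
)

# A bool strict is translated (True -> ASSUME_OK_UNEXPECTED, False -> LOG_ALL),
# but load_strictness above takes precedence.
# sharded_state_dict is assumed to come from the model, e.g. model.sharded_state_dict().
loaded_state = checkpoint_io.load_checkpoint(
    '/path/to/dist_ckpt',              # assumed checkpoint directory
    sharded_state_dict=sharded_state_dict,
    strict=False,
)
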
2 changes: 1 addition & 1 deletion tutorials/asr/Confidence_Ensembles.ipynb
@@ -39,7 +39,7 @@
"\n",
"# option #2: download NeMo repo\n",
"if 'google.colab' in str(get_ipython()) or not os.path.exists(os.path.join(NEMO_DIR, \"nemo\")):\n",
" BRANCH = \"main\"\n",
" BRANCH = 'main'\n",
" !git clone -b $BRANCH https://github.com/NVIDIA/NeMo $WORKSPACE_DIR/NeMo\n",
" NEMO_DIR = os.path.join(WORKSPACE_DIR, 'NeMo')\n",
"\n",
@@ -45,7 +45,8 @@
"5. Restart the runtime (Runtime -> Restart Runtime) for any upgraded packages to take effect\n",
"\"\"\"\n",
"\n",
"GIT_USER, GIT_BRANCH = 'NVIDIA', 'main'\n",
"GIT_USER = 'NVIDIA'\n",
"BRANCH = 'main'\n",
"\n",
"if 'google.colab' in str(get_ipython()):\n",
"\n",
@@ -56,7 +57,7 @@
" !pip install matplotlib>=3.3.2\n",
"\n",
" ## Install NeMo\n",
" !python -m pip install git+https://github.com/{GIT_USER}/NeMo.git@{GIT_BRANCH}#egg=nemo_toolkit[all]\n",
" !python -m pip install git+https://github.com/{GIT_USER}/NeMo.git@{BRANCH}#egg=nemo_toolkit[all]\n",
"\n",
" ## Install TorchAudio\n",
" !pip install torchaudio>=0.13.0 -f https://download.pytorch.org/whl/torch_stable.html"
@@ -210,7 +211,7 @@
"# Copy script\n",
"get_librispeech_script = os.path.join(scripts_dir, 'get_librispeech_data.py')\n",
"if not os.path.exists(get_librispeech_script):\n",
" !wget -P $scripts_dir https://raw.githubusercontent.com/{GIT_USER}/NeMo/{GIT_BRANCH}/scripts/dataset_processing/get_librispeech_data.py\n",
" !wget -P $scripts_dir https://raw.githubusercontent.com/{GIT_USER}/NeMo/{BRANCH}/scripts/dataset_processing/get_librispeech_data.py\n",
"\n",
"# Download the data\n",
"if not speech_dir.is_dir():\n",
@@ -260,7 +261,7 @@
"# Copy script\n",
"get_demand_script = os.path.join(scripts_dir, 'get_demand_data.py')\n",
"if not os.path.exists(get_demand_script):\n",
" !wget -P $scripts_dir https://raw.githubusercontent.com/{GIT_USER}/NeMo/{GIT_BRANCH}/scripts/dataset_processing/get_demand_data.py\n",
" !wget -P $scripts_dir https://raw.githubusercontent.com/{GIT_USER}/NeMo/{BRANCH}/scripts/dataset_processing/get_demand_data.py\n",
"\n",
"if not noise_dir.is_dir():\n",
" noise_dir.mkdir(exist_ok=True)\n",
@@ -323,7 +324,7 @@
"# Copy script\n",
"add_noise_script = os.path.join(scripts_dir, 'add_noise.py')\n",
"if not os.path.exists(add_noise_script):\n",
" !wget -P $scripts_dir https://raw.githubusercontent.com/{GIT_USER}/NeMo/{GIT_BRANCH}/scripts/dataset_processing/add_noise.py\n",
" !wget -P $scripts_dir https://raw.githubusercontent.com/{GIT_USER}/NeMo/{BRANCH}/scripts/dataset_processing/add_noise.py\n",
"\n",
"# Generate noisy datasets and save the noise component as well.\n",
"noisy_dir = data_dir / 'noisy'\n",
@@ -494,7 +495,7 @@
"config_path = config_dir / 'masking.yaml'\n",
"\n",
"if not config_path.is_file():\n",
" !wget https://raw.githubusercontent.com/{GIT_USER}/NeMo/{GIT_BRANCH}/examples/audio/conf/masking.yaml -P {config_dir.as_posix()}\n",
" !wget https://raw.githubusercontent.com/{GIT_USER}/NeMo/{BRANCH}/examples/audio/conf/masking.yaml -P {config_dir.as_posix()}\n",
"\n",
"config = OmegaConf.load(config_path)\n",
"config = OmegaConf.to_container(config, resolve=True)\n",
2 changes: 1 addition & 1 deletion tutorials/nlp/Token_Classification-BioMegatron.ipynb
@@ -7,7 +7,7 @@
"metadata": {},
"outputs": [],
"source": [
"BRANCH='main'"
"BRANCH = 'main'"
]
},
{
2 changes: 1 addition & 1 deletion tutorials/nlp/lora.ipynb
@@ -31,7 +31,7 @@
"outputs": [],
"source": [
"%cd /NeMo/tutorials/nlp\n",
"BRANCH='main'\n",
"BRANCH = 'main'\n",
"import os\n",
"import wget\n",
"import sys\n",
