diff --git a/README.md b/README.md index 768cfc50c4dd..c7bde12dd0ea 100755 --- a/README.md +++ b/README.md @@ -4,6 +4,8 @@ [![License MIT](https://img.shields.io/badge/License-MIT-blue.svg)](https://github.com/Microsoft/DeepSpeed/blob/master/LICENSE) [![Docker Pulls](https://img.shields.io/docker/pulls/deepspeed/deepspeed)](https://hub.docker.com/r/deepspeed/deepspeed) +### 03/2021: DeepSpeed is hiring! Come join us: [SDE 2](https://careers.microsoft.com/us/en/job/1013160/Software-Engineer-2), [Sr. SDE](https://careers.microsoft.com/us/en/job/1017151/Senior-Software-Engineer), [Sr. Researcher](https://careers.microsoft.com/us/en/job/1016440/Senior-Researcher) + [DeepSpeed](https://www.deepspeed.ai/) is a deep learning optimization library that makes distributed training easy, efficient, and effective. @@ -31,7 +33,11 @@ information [here](https://innovation.microsoft.com/en-us/exploring-ai-at-scale) # News +* [2021/04/01] [[DeepSpeed on AzureML] Transformers and CIFAR examples are now available on AzureML GitHub](https://github.com/Azure/azureml-examples/tree/main/workflows/train/deepspeed) +* [2021/03/30] [[PyTorch Lightning Blog] Accessible Multi-Billion Parameter Model Training with PyTorch Lightning + DeepSpeed](https://medium.com/pytorch-lightning/accessible-multi-billion-parameter-model-training-with-pytorch-lightning-deepspeed-c9333ac3bb59) +* [2021/03/16] [1-bit Adam v2: NCCL-based implementation and more](https://www.deepspeed.ai/tutorials/onebit-adam/) * [2021/03/08] [ZeRO-3 Offload: Scale your models to trillion parameters without code changes while leveraging both CPUs & GPUs](https://www.deepspeed.ai/news/2021/03/07/zero3-offload.html) +* [2021/01/19] [[🤗Hugging Face Blog] Fit More and Train Faster With ZeRO via DeepSpeed and FairScale](https://huggingface.co/blog/zero-deepspeed-fairscale) * [2020/11/12] [Simplified install, JIT compiled ops, PyPI releases, and reduced dependencies](#installation) * [2020/11/10] [Efficient and robust compressed training through progressive layer dropping](https://www.deepspeed.ai/news/2020/10/28/progressive-layer-dropping-news.html) * [2020/09/10] [DeepSpeed v0.3: Extreme-scale model training for everyone](https://www.microsoft.com/en-us/research/blog/deepspeed-extreme-scale-model-training-for-everyone/) @@ -39,7 +45,6 @@ information [here](https://innovation.microsoft.com/en-us/exploring-ai-at-scale) * [Training a trillion parameters with pipeline parallelism](https://www.deepspeed.ai/news/2020/09/08/pipeline-parallelism.html) * [Up to 5x less communication and 3.4x faster training through 1-bit Adam](https://www.deepspeed.ai/news/2020/09/08/onebit-adam-news.html) * [10x bigger model training on a single GPU with ZeRO-Offload](https://www.deepspeed.ai/news/2020/09/08/ZeRO-Offload.html) -* [2020/08/07] [DeepSpeed Microsoft Research Webinar](https://note.microsoft.com/MSR-Webinar-DeepSpeed-Registration-On-Demand.html) is now available on-demand # Table of Contents diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py index c4c2acf0b0d7..3401f121bca0 100755 --- a/deepspeed/__init__.py +++ b/deepspeed/__init__.py @@ -49,8 +49,8 @@ def _parse_version(version_str): sys.modules['deepspeed.pt.loss_scaler'] = deepspeed.runtime.fp16.loss_scaler -def initialize(args, - model, +def initialize(args=None, + model=None, optimizer=None, model_parameters=None, training_data=None, @@ -62,8 +62,7 @@ def initialize(args, """Initialize the DeepSpeed Engine. Arguments: - args: a dictionary containing local_rank and deepspeed_config - file location + args: an object containing local_rank and deepspeed_config fields. This is optional if `config_params` is passed. model: Required: nn.module class before apply any wrappers @@ -88,6 +87,9 @@ def initialize(args, mini-batch of Tensor(s). Used when using batched loading from a map-style dataset. + config_params: Optional: Instead of requiring args.deepspeed_config you can pass your deepspeed config + as a dictionary instead. + Returns: A tuple of ``engine``, ``optimizer``, ``training_dataloader``, ``lr_scheduler`` @@ -108,6 +110,8 @@ def initialize(args, __git_branch__), ranks=[0]) + assert model is not None, "deepspeed.initialize requires a model" + if not isinstance(model, PipelineModule): engine = DeepSpeedEngine(args=args, model=model, diff --git a/deepspeed/launcher/runner.py b/deepspeed/launcher/runner.py index cb80f55286d9..1da8869cc718 100755 --- a/deepspeed/launcher/runner.py +++ b/deepspeed/launcher/runner.py @@ -365,6 +365,12 @@ def main(args=None): result = subprocess.Popen(cmd, env=env) result.wait() + # In case of failure must propagate the error-condition back to the caller (usually shell). The + # actual error and traceback should have been printed in the subprocess, so in order to avoid + # unnecessary noise we just quietly exit here with the same code as the subprocess + if result.returncode > 0: + sys.exit(result.returncode) + if __name__ == "__main__": main() diff --git a/deepspeed/profiling/config.py b/deepspeed/profiling/config.py index 0e389baba18b..807802670654 100644 --- a/deepspeed/profiling/config.py +++ b/deepspeed/profiling/config.py @@ -3,12 +3,15 @@ Licensed under the MIT license. """ -from deepspeed.runtime.config_utils import get_scalar_param +from deepspeed.runtime.config_utils import get_scalar_param, DeepSpeedConfigObject from deepspeed.profiling.constants import * -class DeepSpeedFlopsProfilerConfig(object): +class DeepSpeedFlopsProfilerConfig(DeepSpeedConfigObject): def __init__(self, param_dict): + """ + docstring + """ super(DeepSpeedFlopsProfilerConfig, self).__init__() self.enabled = None @@ -24,6 +27,9 @@ def __init__(self, param_dict): self._initialize(flops_profiler_dict) def _initialize(self, flops_profiler_dict): + """ + docstring + """ self.enabled = get_scalar_param(flops_profiler_dict, FLOPS_PROFILER_ENABLED, FLOPS_PROFILER_ENABLED_DEFAULT) diff --git a/deepspeed/profiling/flops_profiler/profiler.py b/deepspeed/profiling/flops_profiler/profiler.py index 7e225fc20f2b..be7d772782f2 100644 --- a/deepspeed/profiling/flops_profiler/profiler.py +++ b/deepspeed/profiling/flops_profiler/profiler.py @@ -265,7 +265,7 @@ def del_extra_repr(module): "Each module profile is listed after its name in the following order: \nnumber of parameters, percentage of total parameters, number of multiply-accumulate operations (MACs), percentage of total MACs, latency, percentage of total latency, number of floating point operations per second (FLOPS, computed as 2 * MACs / latency)." ) print( - "Note: \n1. A module can have torch.nn.functional (e.g. to compute logits) along with submodules, thus making the difference between the parent's MACs(or latency) and the sum of its submodules'.\n2. Number of floating point operations is a theoretical estimation, thus FLOPS computed using that could be larger than the maximum system throught.\n" + "Note: \n1. A module can have torch.nn.functional (e.g. to compute logits) along with submodules, thus making the difference between the parent's MACs(or latency) and the sum of its submodules'.\n2. Number of floating point operations is a theoretical estimation, thus FLOPS computed using that could be larger than the maximum system throughput.\n" ) print(self.model) diff --git a/deepspeed/runtime/activation_checkpointing/config.py b/deepspeed/runtime/activation_checkpointing/config.py index 30ac5157f843..19e904980da7 100755 --- a/deepspeed/runtime/activation_checkpointing/config.py +++ b/deepspeed/runtime/activation_checkpointing/config.py @@ -3,7 +3,7 @@ Licensed under the MIT license. """ -from deepspeed.runtime.config_utils import get_scalar_param +from deepspeed.runtime.config_utils import get_scalar_param, DeepSpeedConfigObject ######################################### # DeepSpeed Activation Checkpointing @@ -56,7 +56,7 @@ } -class DeepSpeedActivationCheckpointingConfig(object): +class DeepSpeedActivationCheckpointingConfig(DeepSpeedConfigObject): def __init__(self, param_dict): super(DeepSpeedActivationCheckpointingConfig, self).__init__() @@ -74,13 +74,6 @@ def __init__(self, param_dict): self._initialize(act_chkpt_config_dict) - """ - For json serialization - """ - - def repr(self): - return self.__dict__ - def _initialize(self, act_chkpt_config_dict): self.partition_activations = get_scalar_param( act_chkpt_config_dict, diff --git a/deepspeed/runtime/comm/nccl.py b/deepspeed/runtime/comm/nccl.py index 0faefc70aa1d..b94f7caf81c5 100644 --- a/deepspeed/runtime/comm/nccl.py +++ b/deepspeed/runtime/comm/nccl.py @@ -86,7 +86,7 @@ def compressed_allreduce(self, # worker_scale = self.compression_backend.cupy2torch(cupy_worker_scale) recvbuf_sign = self.compression_backend.cupy2torch(cupy_recvbuf_sign) - # recvbuf_scale = self.compression_backend.cupy2torch(cupy_recvbuf_scale) + #recvbuf_scale = self.compression_backend.cupy2torch(cupy_recvbuf_scale) recvbuf_scale = [ torch.zeros(1, dtype=worker_scale.dtype, @@ -106,13 +106,13 @@ def compressed_allreduce(self, cupy_sign_list_packed = None cupy_recvbuf_sign = self.compression_backend.torch2cupy(recvbuf_sign) - # cupy_recvbuf_scale = self.compression_backend.torch2cupy(torch.stack(recvbuf_scale)) + #cupy_recvbuf_scale = self.compression_backend.torch2cupy(torch.stack(recvbuf_scale)) compensated_server_m = self.compression_backend.cupy2torch( (cupy.unpackbits(cupy_recvbuf_sign.flatten())).reshape( self.size, -1)).float().add_(-0.5).mul_(2.0).mul_( - torch.stack(recvbuf_scale).mul_(1 / self.size)).sum(0) + torch.stack(recvbuf_scale).mul_(1 / self.size)).sum(0) compensated_server_m.add_(server_error) server_scale = torch.norm(compensated_server_m) / np.sqrt( compensated_server_m.numel()) @@ -172,7 +172,7 @@ def compressed_allreduce(self, (cupy.unpackbits(cupy_recvbuf_sign_server.flatten())).reshape( self.size, -1)).float().add_(-0.5).mul_(2.0).mul_( - self.compression_backend.cupy2torch( + self.compression_backend.cupy2torch( cupy_recvbuf_scale_server)).flatten().data) if original_size != worker_error_size: buffer_m = buffer_m[0:original_size] diff --git a/deepspeed/runtime/config_utils.py b/deepspeed/runtime/config_utils.py index 37f35692369b..62782852a3d2 100755 --- a/deepspeed/runtime/config_utils.py +++ b/deepspeed/runtime/config_utils.py @@ -5,10 +5,21 @@ """ Collection of DeepSpeed configuration utilities """ - +import json from collections import Counter +class DeepSpeedConfigObject(object): + """ + For json serialization + """ + def repr(self): + return self.__dict__ + + def __repr__(self): + return json.dumps(self.__dict__, sort_keys=True, indent=4) + + def get_scalar_param(param_dict, param_name, param_default_value): return param_dict.get(param_name, param_default_value) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 82d38fe3b1f6..4c529fe5c42e 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -3,10 +3,13 @@ ''' import os +import stat import torch import warnings import hashlib import torch.distributed as dist +from collections import OrderedDict +from shutil import copyfile from torch.nn.modules import Module from torch.distributed.distributed_c10d import _get_global_rank @@ -385,6 +388,9 @@ def zero_prefetch_bucket_size(self): def zero_param_persistence_threshold(self): return self._config.zero_config.param_persistence_threshold + def zero_gather_fp16_weights_on_model_save(self): + return self._config.zero_config.gather_fp16_weights_on_model_save + def fp16_enabled(self): return self._config.fp16_enabled @@ -495,9 +501,10 @@ def _configure_with_arguments(self, args, mpu): # After the distributed backend is initialized we are guaranteed the LOCAL_RANK # environment variable is set. We must align args.local_rank to this value for # backwards compatability with scripts relying on [args|self].local_rank containing - # the correct local rank info. - args.local_rank = int(os.environ['LOCAL_RANK']) - self.local_rank = args.local_rank + # the correct local rank info. _do_args_sanity_check will ensure this is the case. + self.local_rank = int(os.environ['LOCAL_RANK']) + if hasattr(args, 'local_rank'): + args.local_rank = self.local_rank config_file = args.deepspeed_config if hasattr(args, 'deepspeed_config') else None @@ -513,15 +520,14 @@ def _do_args_sanity_check(self, args): assert args.deepspeed_config is None, "Not sure how to proceed, we were given both a deepscale_config and deepspeed_config" args.deepspeed_config = args.deepscale_config - local_rank_err = "DeepSpeed requires a command line parameter of --local_rank [int] and/or setting the LOCAL_RANK environment variable." - if hasattr(args, 'local_rank'): - assert type(args.local_rank) == int, local_rank_err - if "LOCAL_RANK" in os.environ and args.local_rank >= 0: - env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + assert "LOCAL_RANK" in os.environ, "DeepSpeed requires the LOCAL_RANK environment variable, it is set by the deepspeed launcher, " \ + "deepspeed.init_distributed, or the torch.distributed launcher. If using a different launcher please ensure LOCAL_RANK is set prior to initializing deepspeed." + if hasattr(args, 'local_rank') and args.local_rank != None: + assert isinstance(args.local_rank, int), f"args.local_rank of {args.local_rank} is an unknown type {type(args.local_rank)}" + if args.local_rank >= 0: + env_local_rank = int(os.environ.get("LOCAL_RANK")) assert env_local_rank == args.local_rank, \ f"Mismatch in local rank setting, args.local_rank={args.local_rank} but env['LOCAL_RANK']={env_local_rank}." - else: - assert "LOCAL_RANK" in os.environ, local_rank_err if self.config_params is None: assert hasattr(args, 'deepspeed_config') and args.deepspeed_config is not None, \ @@ -957,7 +963,7 @@ def backward(self, loss, allreduce_gradients=True, release_loss=False): Arguments: loss: Torch tensor on which to execute backward propagation - allreduce_gradients: If this is False, then gradient averaging will be skipped. Default is True. + allreduce_gradients: is deprecated, ignored, and will soon be removed' """ if not allreduce_gradients: @@ -1372,7 +1378,7 @@ def _get_rank_zero_ckpt_name(self, checkpoints_path, tag, mp_rank, dp_rank): zero_ckpt_name = os.path.join( checkpoints_path, str(tag), - filename + '_mp_rank_{:02d}'.format(mp_rank) + 'optim_states.pt') + filename + '_mp_rank_{:02d}'.format(mp_rank) + '_optim_states.pt') return zero_ckpt_name def _get_zero_ckpt_name(self, checkpoints_path, tag): @@ -1549,13 +1555,20 @@ def _get_all_zero_checkpoints(self, load_dir, tag): mp_rank=mp_rank, dp_world_size=self.loaded_checkpoint_dp_world_size) invalid_zero_ckpt_paths = [] - for ckpt_name in zero_ckpt_names: + for i, ckpt_name in enumerate(zero_ckpt_names): if not os.path.exists(ckpt_name): + # transparently handle the old file pattern for optim_states + if 'optim_states.pt' in ckpt_name: + ckpt_name_try = ckpt_name.replace("_optim_states.pt", + "optim_states.pt") + if os.path.exists(ckpt_name_try): + zero_ckpt_names[i] = ckpt_name_try + continue invalid_zero_ckpt_paths.append(ckpt_name) if len(invalid_zero_ckpt_paths) > 0: logger.warn( - f"Client provided zero checkpoint load paths: {invalid_zero_ckpt_paths} does not exist" + f"The following zero checkpoints paths are missing: {invalid_zero_ckpt_paths}" ) return None @@ -1698,8 +1711,125 @@ def _save_checkpoint(self, save_dir, tag, client_state={}): torch.save(state, save_path) self._curr_save_path = None + def _get_param_shapes(self): + param_shapes = OrderedDict() + for name, param in self.module.named_parameters(): + param_shapes[name] = param.ds_shape if hasattr(param, + "ds_shape") else param.shape + # print(f"saving param {name} {param_shapes[name]}") + return param_shapes + + def _copy_recovery_script(self, save_path): + base_dir = os.path.dirname(os.path.dirname(__file__)) + script = "zero_to_fp32.py" + src = os.path.join(base_dir, "utils", script) + dst = os.path.join(save_path, script) + logger.info(f"creating recovery script {dst}") + copyfile(src, dst) + # make executable + os.chmod(dst, os.stat(dst).st_mode | stat.S_IEXEC) + def _save_zero_checkpoint(self, save_path, tag): zero_checkpoint_name = self._get_zero_ckpt_name(save_path, tag) - zero_sd = {'optimizer_state_dict': self.optimizer.state_dict()} + zero_sd = dict( + optimizer_state_dict=self.optimizer.state_dict(), + param_shapes=self._get_param_shapes(), + ) torch.save(zero_sd, zero_checkpoint_name) + self._copy_recovery_script(save_path) logger.info('zero checkpoint saved {}'.format(zero_checkpoint_name)) + + def _zero3_consolidated_fp16_state_dict(self): + """ + + Get a full non-partitioned state_dict with fp16 weights on cpu. + + This is similar to nn.Module.state_dict (modelled after _save_to_state_dict), but: + + 1. consolidates the weights from different partitions on gpu0 + 2. works on one layer at a time to require as little gpu0 memory as possible, by + moving the already consolidated weights to cpu + 3. takes care to keep the shared params shared when gradually copying the params to cpu + + Returns: + a consolidated fp16 ``state_dict`` on cpu on rank 0, ``None`` on other ranks + + """ + import deepspeed + + if not self.zero_optimization_partition_weights(): + raise ValueError("this function requires ZeRO-3 mode") + + state_dict = OrderedDict() if torch.distributed.get_rank() == 0 else None + shared_weights = {} + + def get_layer_state_dict(module, prefix=""): + # gather one layer at a time to be memory-efficient + with deepspeed.zero.GatheredParameters(list( + module.parameters(recurse=False))): + if torch.distributed.get_rank() == 0: + for name, param in module.named_parameters(recurse=False): + if param is None: + continue + key = prefix + name + # for shared weights we want to make sure not to unshare them when copying to cpu + data_ptr_id = param.storage().data_ptr() + if data_ptr_id in shared_weights: + # shared weights + # print(f"`{key}` is shared with `{shared_weights[data_ptr_id]}`") + state_dict[key] = state_dict[shared_weights[data_ptr_id]] + else: + state_dict[key] = param.detach().cpu() + shared_weights[data_ptr_id] = key + #print(f"param {name} {param.shape}") + #print(f"param {key} {param.shape} {state_dict[key].storage().data_ptr()}") + + # now buffers - not sure if need to take care of potentially shared weights here + for name, buf in module.named_buffers(recurse=False): + if buf is not None and name not in module._non_persistent_buffers_set: + state_dict[prefix + name] = buf.detach().cpu() + + for name, child in module.named_children(): + if child is not None: + get_layer_state_dict(child, prefix + name + ".") + + see_memory_usage("before get_layer_state_dict", force=False) + get_layer_state_dict(self.module, prefix="") + see_memory_usage("after get_layer_state_dict", force=False) + + return state_dict + + def save_fp16_model(self, save_dir, save_filename="pytorch_model.bin"): + r"""Save fp16 model weights + + This method saves the fp16 model weights at the desired destination. + + Arguments: + save_dir: Required. Directory for saving the model + save_filename: Optional. Filename to save to. Defaults to ``pytorch_model.bin`` + + Important: all processes must call this method and not just the process with rank 0. It is + because the processes need to work in sync to gather the weights. This method will hang + waiting to synchronize with other processes if it's called just for the process with rank 0. + + """ + + path = os.path.join(save_dir, save_filename) + + if self.zero_optimization_partition_weights(): + if self.zero_gather_fp16_weights_on_model_save(): + # consolidation is expensive in time and memory and therefore isn't a default + state_dict = self._zero3_consolidated_fp16_state_dict() + else: + # the model will be bogus if not consolidated so don't confuse the user by saving it + logger.info( + f"Did not save the model {path} because `stage3_gather_fp16_weights_on_model_save` is False" + ) + return + else: + state_dict = self.module.state_dict() + + if torch.distributed.get_rank() == 0: + os.makedirs(save_dir, exist_ok=True) + logger.info(f"Saving model weights to {path}") + torch.save(state_dict, path) diff --git a/deepspeed/runtime/fp16/onebit/adam.py b/deepspeed/runtime/fp16/onebit/adam.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index bfacc0af512a..b1a7a4b0aae1 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -8,6 +8,7 @@ import os import psutil +import gc from math import ceil from math import floor from bisect import bisect_left, bisect_right @@ -551,6 +552,9 @@ def see_memory_usage(message, force=False): if torch.distributed.is_initialized() and not torch.distributed.get_rank() == 0: return + # python doesn't do real-time garbage collection so do it explicitly to get the correct RAM reports + gc.collect() + # Print message except when distributed but not rank 0 logger.info(message) logger.info( @@ -564,6 +568,10 @@ def see_memory_usage(message, force=False): logger.info( f'CPU Virtual Memory: used = {used_GB} GB, percent = {vm_stats.percent}%') + # get the peak memory to report correct data, so reset the counter for the next call + if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+ + torch.cuda.reset_peak_memory_stats() + def call_to_str(base, *args, **kwargs): """Construct a string representation of a call. diff --git a/deepspeed/runtime/zero/config.py b/deepspeed/runtime/zero/config.py index 63a0e4292bd2..622ffa9ba1cb 100755 --- a/deepspeed/runtime/zero/config.py +++ b/deepspeed/runtime/zero/config.py @@ -3,13 +3,12 @@ Licensed under the MIT license. """ -from deepspeed.runtime.config_utils import get_scalar_param +from deepspeed.runtime.config_utils import get_scalar_param, DeepSpeedConfigObject from deepspeed.utils import logger from deepspeed.runtime.zero.constants import * -import json -class DeepSpeedZeroConfig(object): +class DeepSpeedZeroConfig(DeepSpeedConfigObject): def __init__(self, param_dict): super(DeepSpeedZeroConfig, self).__init__() @@ -35,6 +34,7 @@ def __init__(self, param_dict): self.param_persistence_threshold = None self.max_live_parameters = None self.max_reuse_distance = None + self.gather_fp16_weights_on_model_save = None #Stage3 Specific Parameters self.prefetch_bucket_size = None @@ -66,16 +66,6 @@ def read_zero_config_deprecated(self, param_dict): .format(ZERO_FORMAT)) return zero_config_dict - """ - For json serialization - """ - - def repr(self): - return self.__dict__ - - def __repr__(self): - return json.dumps(self.__dict__, sort_keys=True, indent=4) - def _initialize(self, zero_config_dict): self.stage = get_scalar_param(zero_config_dict, ZERO_OPTIMIZATION_STAGE, @@ -161,3 +151,8 @@ def _initialize(self, zero_config_dict): zero_config_dict, ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD, ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD_DEFAULT) + + self.gather_fp16_weights_on_model_save = get_scalar_param( + zero_config_dict, + ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE, + ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE_DEFAULT) diff --git a/deepspeed/runtime/zero/constants.py b/deepspeed/runtime/zero/constants.py index 8d4cf2c5d293..e5812980a337 100755 --- a/deepspeed/runtime/zero/constants.py +++ b/deepspeed/runtime/zero/constants.py @@ -99,6 +99,10 @@ ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD = 'stage3_param_persistence_threshold' ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD_DEFAULT = 100000 +# gathers params for saving a model - inefficient but is required in certain situations +ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE = 'stage3_gather_fp16_weights_on_model_save' +ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE_DEFAULT = False + ZERO_OPTIMIZATION_DEFAULT = { ZERO_OPTIMIZATION_STAGE: ZERO_OPTIMIZATION_STAGE_DEFAULT, @@ -133,5 +137,7 @@ ZERO_OPTIMIZATION_PREFETCH_BUCKET_SIZE: ZERO_OPTIMIZATION_PREFETCH_BUCKET_SIZE_DEFAULT, ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD: - ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD_DEFAULT + ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD_DEFAULT, + ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE: + ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE_DEFAULT } diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 05825fc90688..4465adfd7c16 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -279,6 +279,9 @@ def __init__(self, For example, if a node has 1TB of memory and 8 GPUs, we could fit a trillion parameter model with 4 nodes and 32 GPUs. + Important: If the fp16 weights of the model can't fit onto a single GPU memory + this feature must be used. + .. note:: Initializes ``torch.distributed`` if it has not already been done so. See :meth:`deepseed.init_distributed` for more information. @@ -807,8 +810,12 @@ def _partition_gradient(self, param, partition_buffer=None, accumulate=False): if start < param.ds_numel: elements = min(param.ds_numel - start, partition_size) - dest_tensor = partition_buffer.view(-1).narrow(0, 0, elements) + dest_tensor_full_buffer = partition_buffer.view(-1).narrow( + 0, + 0, + partition_size) + dest_tensor = dest_tensor_full_buffer.narrow(0, 0, elements) src_tensor = param.grad.view(-1).narrow(0, start, elements) # just copy the grad partition to the buffer @@ -841,7 +848,7 @@ def _partition_gradient(self, param, partition_buffer=None, accumulate=False): # elements)) #print("after partition gradients") - param.grad.data = dest_tensor.data + param.grad.data = dest_tensor_full_buffer.data see_memory_usage("After partitioning gradients", force=False) diff --git a/deepspeed/runtime/zero/stage2.py b/deepspeed/runtime/zero/stage2.py index bdd1de4cbdda..cd29625958c9 100755 --- a/deepspeed/runtime/zero/stage2.py +++ b/deepspeed/runtime/zero/stage2.py @@ -37,7 +37,7 @@ def split_half_float_double(tensors): ] buckets = [] for i, dtype in enumerate(dtypes): - bucket = [t for t in tensors if t is not None and t.type() == dtype] + bucket = [t for t in tensors if t.type() == dtype] if bucket: buckets.append(bucket) return buckets @@ -477,6 +477,8 @@ def independent_gradient_partition_epilogue(self): if self.overlap_comm: torch.cuda.synchronize() + # It is safe to clear previously reduced grads of other partitions + self._clear_previous_reduced_grads() if self.cpu_offload is False: for i, _ in enumerate(self.fp16_groups): @@ -638,6 +640,9 @@ def reduce_independent_p_g_buckets_and_remove_grads(self, param, i): param.grad.data = new_grad_tensor.data.view_as(param.grad) self.elements_in_ipg_bucket += param.numel() + + assert param.grad is not None, f"rank {dist.get_rank()} - Invalid to reduce Param {param_id} with None gradient" + self.grads_in_ipg_bucket.append(param.grad) self.params_in_ipg_bucket.append((i, param, param_id)) @@ -878,8 +883,12 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params): for p in params: if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0): param_id = self.get_param_id(p) - param_norm = self.norm_for_param_grads[param_id] - total_norm += param_norm.item()**2 + # as some model have trainable parameters but skipped in training, + # their backward hooks in self.create_reduce_and_remove_grad_hooks() will not run, + # so they have no norm_for_param_grads + if param_id in self.norm_for_param_grads: + param_norm = self.norm_for_param_grads[param_id] + total_norm += param_norm.item()**2 # Sum across all model parallel GPUs. total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) @@ -965,7 +974,7 @@ def reduce_ipg_grads(self): if not self.is_param_in_current_partition[param_id]: if self.overlap_comm and self.contiguous_gradients is False: - # Clear the previous grads during the next reduction + # Clear grads of other partitions during the next reduction # to avoid clearing them before the reduction is complete. if self.previous_reduced_grads is None: self.previous_reduced_grads = [] @@ -1078,16 +1087,18 @@ def allreduce_bucket(self, bucket, allreduce_always_fp32=False, rank=None, log=N return tensor + def _clear_previous_reduced_grads(self): + if self.previous_reduced_grads is not None: + for param in self.previous_reduced_grads: + param.grad = None + self.previous_reduced_grads = None + #if rank is specified do a reduction instead of an allreduce def allreduce_and_copy(self, small_bucket, rank=None, log=None): if self.overlap_comm: torch.cuda.synchronize() - if self.previous_reduced_grads is not None: - # previous_reduced_grads has the previous reduced grads, - # now it is safe to clear. - for param in self.previous_reduced_grads: - param.grad = None - self.previous_reduced_grads = None + # It is safe to clear the previously reduced grads of other partitions + self._clear_previous_reduced_grads() stream = self.reduction_stream else: stream = torch.cuda.current_stream() diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index cc97069cb103..9168ab96d6e1 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -960,10 +960,9 @@ def _create_fp16_partitions_with_defragmentation(self): #create flat buffer in CPU and move to GPU self.fp16_partitioned_groups_flat.append( - flatten_dense_tensors_aligned( - self.fp16_partitioned_groups[i], - dist.get_world_size(group=self.dp_process_group)).cuda( - torch.cuda.current_device())) + flatten_dense_tensors_aligned(self.fp16_partitioned_groups[i], + 1).cuda( + torch.cuda.current_device())) see_memory_usage( f"After flattening and moving param group {i} to GPU", force=False) @@ -975,10 +974,12 @@ def _create_fp16_partitions_with_defragmentation(self): flat_offset, total_elements) self.fp16_partitioned_groups_flat.append(fp16_partitioned_group_flat) - self._move_to_flat_buffer(self.fp16_partitioned_groups[i], - self.fp16_partitioned_groups_flat[i]) flat_offset += total_elements + # move param to flat buffer for both param offload on/off + self._move_to_flat_buffer(self.fp16_partitioned_groups[i], + self.fp16_partitioned_groups_flat[i]) + see_memory_usage(f"After Flattening param group {i}", force=False) def _create_fp32_partitions(self): @@ -1035,6 +1036,14 @@ def setup_zero_stage3_hooks(self): self.hierarchy = 0 self._register_hooks_recursively(self.module) + #reset step if in inference mode + def _end_of_forward_hook(module, *args): + + if not torch._C.is_grad_enabled(): + self.param_coordinator.reset_step() + + self.module.register_forward_hook(_end_of_forward_hook) + def persistent_parameters(self): persistent_params = [] total_persistent_parameters = 0 @@ -2259,7 +2268,7 @@ def _prepare_fp32_grad_for_sub_group(self, sub_group_id): assert single_grad_partition.numel() == self.fp32_partitioned_groups_flat[sub_group_id].numel(), \ "averaged gradients have different number of elements that partition size {} {} {} {}".format( - single_grad_partition.numel(), self.partition_size[sub_group_id], sub_group_id, partition_id) + single_grad_partition.numel(), self.fp32_partitioned_groups_flat[sub_group_id].numel(), sub_group_id, partition_id) self.fp32_partitioned_groups_flat[sub_group_id].grad = single_grad_partition @@ -2628,14 +2637,12 @@ def get_groups_without_padding(self, groups_with_padding): def _set_fp32_optimizer_param_groups(self): for sub_group_id, _ in enumerate(self.fp16_groups): param_group_id = self.sub_group_to_group_id[sub_group_id] - self.optimizer.param_groups[param_group_id]['params'] = [ - self.fp32_partitioned_groups_flat[sub_group_id] - ] + self.optimizer.param_groups[param_group_id]['params'].append( + self.fp32_partitioned_groups_flat[sub_group_id]) def _clear_fp32_optimizer_param_groups(self): - for sub_group_id, _ in enumerate(self.fp16_groups): - param_group_id = self.sub_group_to_group_id[sub_group_id] - self.optimizer.param_groups[param_group_id]['params'] = [] + for param_group in self.optimizer.param_groups: + param_group['params'] = [] def _rigid_state_dict(self): state_dict = {} diff --git a/deepspeed/utils/zero_to_fp32.py b/deepspeed/utils/zero_to_fp32.py new file mode 100644 index 000000000000..3401fd635e7c --- /dev/null +++ b/deepspeed/utils/zero_to_fp32.py @@ -0,0 +1,151 @@ +#!/usr/bin/env python + +# This script extracts fp32 consolidated weights from a zero 2 and 3 DeepSpeed checkpoints. It gets +# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in +# the future. Once extracted, the weights don't require DeepSpeed and can be used in any +# application. +# +# example: python zero_to_fp32.py global_step1 pytorch_model.bin + +import argparse +import torch +import glob +import os +from collections import OrderedDict +import deepspeed + +# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with +# DeepSpeed data structures it has to be available in the current python environment. + + +def get_optim_files(checkpoint_dir): + + if not os.path.isdir(checkpoint_dir): + raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist") + + # XXX: need to test that this simple glob rule works for multi-node setup too + optim_files = sorted(glob.glob(f"{checkpoint_dir}/*_optim_states.pt")) + + if len(optim_files) == 0: + raise FileNotFoundError( + f"can't find '*_optim_states.pt' files in directory '{checkpoint_dir}'") + + return optim_files + + +def parse_optim_states(files): + state_dicts = [] + for f in files: + state_dicts.append(torch.load(f)) + + if not "zero_stage" in state_dicts[0]['optimizer_state_dict']: + raise ValueError(f"non zero checkpoint") + zero_stage = state_dicts[0]['optimizer_state_dict']["zero_stage"] + + # the groups are named differently in each stage + if zero_stage == 2: + fp32_groups_key = "single_partition_of_fp32_groups" + elif zero_stage == 3: + fp32_groups_key = "fp32_flat_groups" + else: + raise ValueError(f"unknown zero stage {zero_stage}") + + param_shapes = state_dicts[0]["param_shapes"] + fp32_flat_groups = [ + state_dicts[i]['optimizer_state_dict'][fp32_groups_key][0] + for i in range(len(state_dicts)) + ] + world_size = state_dicts[0]['optimizer_state_dict']["partition_count"] + + return zero_stage, world_size, param_shapes, fp32_flat_groups + + +def zero3_partitioned_param_info(unpartitioned_numel, world_size): + remainder = unpartitioned_numel % world_size + padding_numel = (world_size - remainder) if remainder else 0 + partitioned_numel = int(unpartitioned_numel / world_size) + return partitioned_numel, padding_numel + + +def convert_zero_chkpt_to_fp32_consolid_state_dict(checkpoint_dir, output_file): + """ + Convert zero 2 or 3 checkpoint into a single fp32 consolidated state_dict file that can be + loaded with ``torch.load(file)`` and used for training without DeepSpeed. + + Args: + - ``checkpoint_dir``: path to the deepspeed checkpoint folder + - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin) + + """ + print(f"Processing zero checkpoint '{checkpoint_dir}'") + + optim_files = get_optim_files(checkpoint_dir) + zero_stage, world_size, param_shapes, fp32_flat_groups = parse_optim_states(optim_files) + print( + f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}") + + # Reconstruction protocol: + # + # - for zero2 we just need to concat the partitions back to back and reconsolidate over one huge + # flat buffer - no need to deal with padding since if there is any it will be only in the tail + # of the last partition so there it will be just left out + # + # - for zero3 we need to zip the partitions together at boundary of each param, re-consolidating + # each param, while dealing with padding if any + + if zero_stage == 2: + # XXX: memory usage doubles here (zero2) + full_single_fp32_vector = torch.cat(fp32_flat_groups, 0) + + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + state_dict = OrderedDict() + offset = 0 + total_numel = 0 + for name, shape in param_shapes.items(): + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + + if zero_stage == 2: + # print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ") + state_dict[name] = full_single_fp32_vector.narrow( + 0, + offset, + unpartitioned_numel).view(shape) + offset += unpartitioned_numel + + elif zero_stage == 3: + partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) + # print(f"{name} full shape: {shape} partition0 numel {partitioned_numel} partitioned_padding_numel {partitioned_padding_numel}") + # XXX: memory usage doubles here (zero3) + state_dict[name] = torch.cat( + tuple(fp32_flat_groups[i].narrow(0, + offset, + partitioned_numel) + for i in range(world_size)), + 0).view(shape) + offset += partitioned_numel + partitioned_padding_numel + + # the job is done + print(f"Saving fp32 state dict to {output_file} (total_numel={total_numel})") + + torch.save(state_dict, output_file) + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument( + "checkpoint_dir", + type=str, + help= + "path to the deepspeed checkpoint folder, e.g., path/checkpoint-1/global_step1") + parser.add_argument( + "output_file", + type=str, + help= + "path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-1/pytorch_model.bin)" + ) + args = parser.parse_args() + + convert_zero_chkpt_to_fp32_consolid_state_dict(args.checkpoint_dir, args.output_file) diff --git a/docs/Gemfile.lock b/docs/Gemfile.lock index 94dad7c80bc0..81646671de47 100644 --- a/docs/Gemfile.lock +++ b/docs/Gemfile.lock @@ -68,7 +68,7 @@ GEM jekyll-theme-time-machine (= 0.1.1) jekyll-titles-from-headings (= 0.5.3) jemoji (= 0.12.0) - kramdown (= 2.3.0) + kramdown (= 2.3.1) kramdown-parser-gfm (= 1.1.0) liquid (= 4.0.3) mercenary (~> 0.3) @@ -196,7 +196,7 @@ GEM gemoji (~> 3.0) html-pipeline (~> 2.2) jekyll (>= 3.0, < 5.0) - kramdown (2.3.0) + kramdown (2.3.1) rexml kramdown-parser-gfm (1.1.0) kramdown (~> 2.0) diff --git a/docs/_pages/config-json.md b/docs/_pages/config-json.md index d7e5adf363b4..6c34100095cd 100755 --- a/docs/_pages/config-json.md +++ b/docs/_pages/config-json.md @@ -57,7 +57,7 @@ The Adam optimizer also supports the following two params keys/values in additio | torch\_adam | Use torch's implementation of adam instead of our fused adam implementation | false | | adam\_w\_mode | Apply L2 regularization (also known as AdamW) | true | - Another example of ***optimizer*** with 1-bit Adam specific parameters is as follows. + Another example of ***optimizer*** with 1-bit Adam ```json "optimizer": { @@ -71,11 +71,20 @@ The Adam optimizer also supports the following two params keys/values in additio "eps": 1e-8, "weight_decay": 3e-7, "freeze_step": 400, - "cuda_aware": true + "cuda_aware": false, + "comm_backend_name": "nccl" } } ``` +The 1-bit Adam optimizer supports the following three params keys/values in addition to the standard Adam (learn more in our [tutorial](/tutorials/onebit-adam/)): + +| "params" key | Description | Default | +| ------------- | --------------------------------------------------------------------------- | ------- | +| freeze\_step | Number of warm up steps before 1-bit compression gets applied to the communication | 100000 | +| cuda\_aware | To indicate that the underlying MPI library supports CUDA-Aware communication | false | +| comm\_backend\_name | To indicate which backend implementation to use | "nccl" | + ### Scheduler Parameters ***scheduler***: [dictionary] diff --git a/docs/_pages/features.md b/docs/_pages/features.md index 08f2bf221672..ba955fd574db 100755 --- a/docs/_pages/features.md +++ b/docs/_pages/features.md @@ -37,7 +37,7 @@ and communication- efficient training. DeepSpeed supports a hybrid combination of data, model, and pipeline parallelism and has scaled to over [one trillion parameters using 3D parallelism]({{ site.press_release_v3 }}). Pipeline parallelism can also improve communication efficiency and has -accelerated training by up to 7x on low-banwdith clusters. +accelerated training by up to 7x on low-bandwidth clusters. ## Model Parallelism @@ -256,9 +256,9 @@ This can be enabled by setting the following in the `deepspeed_config` file. ``` -### Timing Activiation Checkpoint Functions +### Timing Activation Checkpoint Functions -When activiation checkpoingint is enabled, profiling the forward and backward time of each checkpoint function can be enabled in the `deepspeed_config` file. +When activation checkpointing is enabled, profiling the forward and backward time of each checkpoint function can be enabled in the `deepspeed_config` file. ```json { diff --git a/docs/_tutorials/azure.md b/docs/_tutorials/azure.md index 3644b4621f8f..45d41a618a23 100644 --- a/docs/_tutorials/azure.md +++ b/docs/_tutorials/azure.md @@ -10,6 +10,8 @@ benefit all your large model training jobs. If you don't already have an Azure account please see more details here: [https://azure.microsoft.com/](https://azure.microsoft.com/). +To use DeepSpeed on [Azure ML](https://azure.microsoft.com/en-us/services/machine-learning/), please take a look at easy-to-use examples for Transformers and CIFAR training from [AzureML Examples GitHub](https://github.com/Azure/azureml-examples/tree/main/workflows/train/deepspeed). + To help with launching Azure instances we suggest using the [Azure CLI](https://docs.microsoft.com/en-us/cli/azure/?view=azure-cli-latest). We have created several helper scripts to get you quickly started using DeepSpeed with Azure. diff --git a/docs/_tutorials/flops-profiler.md b/docs/_tutorials/flops-profiler.md index 3ccd8a45929f..39d0015dd4fe 100644 --- a/docs/_tutorials/flops-profiler.md +++ b/docs/_tutorials/flops-profiler.md @@ -37,11 +37,11 @@ Top 3 modules in params at depth 2 are {'Conv2d': '50.69 k', 'Linear': '11.01 k' Top 3 modules in latency at depth 2 are {'Conv2d': '11.37 ms', 'Linear': '5.27 ms', 'AvgPool2d': '5.02 ms'} ------------------------------ Detailed Profile ------------------------------ -Each module profile is listed after its name in the follwing order: +Each module profile is listed after its name in the following order: number of parameters, percentage of total parameters, number of multiply-accumulate operations (MACs), percentage of total MACs, latency, percentage of total latency, number of floating point operations per second (FLOPS, computed as 2 * MACs / latency). Note: 1. A module can have torch.nn.functional (e.g. to compute logits) along with submodules, thus making the difference between the parent's MACs(or latency) and the sum of its submodules'. -2. Number of floating point operations is a theoretical estimation, thus FLOPS computed using that could be larger than the maximum system throught. +2. Number of floating point operations is a theoretical estimation, thus FLOPS computed using that could be larger than the maximum system throughput. LeNet5( 61.71 k, 100.00% Params, 439.56 MMACs, 100.00% MACs, 25.7 ms, 100.00% latency, 34.2 GFLOPS, @@ -92,7 +92,7 @@ The DeepSpeed flops profiler can be used with the DeepSpeed runtime or as a stan ### Usage With the DeepSpeed Runtime -When using DeepSpeed for model training, the flops profiler can be configured in the `deepspeed_config` file. No explict API calls are needed to use the profiler. Refer to [flops profiler](https://www.deepspeed.ai/docs/config-json/#flops-profiler) for details. +When using DeepSpeed for model training, the flops profiler can be configured in the `deepspeed_config` file. No explicit API calls are needed to use the profiler. Refer to [flops profiler](https://www.deepspeed.ai/docs/config-json/#flops-profiler) for details. #### Example: Megatron-LM @@ -131,11 +131,11 @@ Top 3 modules in params at depth 8 are {'ColumnParallelLinear': '7.35 M', 'RowPa Top 3 modules in latency at depth 8 are {'ColumnParallelLinear': '659.23 us', 'RowParallelLinear': '587.94 us', 'FusedScaleMaskSoftmax': '370.98 us'} ------------------------------ Detailed Profile ------------------------------ -Each module profile is listed after its name in the follwing order: +Each module profile is listed after its name in the following order: number of parameters, percentage of total parameters, number of multiply-accumulate operations (MACs), percentage of total MACs, latency, percentage of total latency, number of floating point operations per second (FLOPS, computed as 2 * MACs / latency). Note: 1. A module can have torch.nn.functional (e.g. to compute logits) along with submodules, thus making the difference between the parent's MACs(or latency) and the sum of its submodules'. -2. Number of floating point operations is a theoretical estimation, thus FLOPS computed using that could be larger than the maximum system throught. +2. Number of floating point operations is a theoretical estimation, thus FLOPS computed using that could be larger than the maximum system throughput. DistributedDataParallel( 38.89 M, 100.00% Params, 314.61 GMACs, 100.00% MACs, 33.81 ms, 100.00% latency, 18.61 TFLOPS, @@ -235,11 +235,11 @@ Top 3 modules in params at depth 2 are {'Linear': '58.63 M', 'Conv2d': '2.47 M', Top 3 modules in latency at depth 2 are {'Conv2d': '13.96 ms', 'Linear': '6.23 ms', 'ReLU': '730.75 us'} ------------------------------ Detailed Profile ------------------------------ -Each module profile is listed after its name in the follwing order: +Each module profile is listed after its name in the following order: number of parameters, percentage of total parameters, number of multiply-accumulate operations (MACs), percentage of total MACs, latency, percentage of total latency, number of floating point operations per second (FLOPS, computed as 2 * MACs / latency). Note: 1. A module can have torch.nn.functional (e.g. to compute logits) along with submodules, thus making the difference between the parent's MACs(or latency) and the sum of its submodules'. -2. Number of floating point operations is a theoretical estimation, thus FLOPS computed using that could be larger than the maximum system throught. +2. Number of floating point operations is a theoretical estimation, thus FLOPS computed using that could be larger than the maximum system throughput. AlexNet( 61.1 M, 100.00% Params, 183.18 GMACs, 100.00% MACs, 22.13 ms, 100.00% latency, 16.56 TFLOPS, @@ -335,11 +335,11 @@ Top 3 modules in params at depth 7 are {'Linear': '28.35 M', 'LayerNorm': '18.43 Top 3 modules in latency at depth 7 are {'Linear': '153.7 ms', 'LayerNorm': '4.74 ms', 'Dropout': '597.95 us'} ------------------------------ Detailed Profile ------------------------------ -Each module profile is listed after its name in the follwing order: +Each module profile is listed after its name in the following order: number of parameters, percentage of total parameters, number of multiply-accumulate operations (MACs), percentage of total MACs, latency, percentage of total latency, number of floating point operations per second (FLOPS, computed as 2 * MACs / latency). Note: 1. A module can have torch.nn.functional (e.g. to compute logits) along with submodules, thus making the difference between the parent's MACs(or latency) and the sum of its submodules'. -2. Number of floating point operations is a theoretical estimation, thus FLOPS computed using that could be larger than the maximum system throught. +2. Number of floating point operations is a theoretical estimation, thus FLOPS computed using that could be larger than the maximum system throughput. BertForSequenceClassification( 109.48 M, 100.00% Params, 43.5 GMACs, 100.00% MACs, 393.7 ms, 100.00% latency, 220.97 GFLOPS, diff --git a/docs/_tutorials/getting-started.md b/docs/_tutorials/getting-started.md index 37f104f0739e..ecd3159df8c9 100644 --- a/docs/_tutorials/getting-started.md +++ b/docs/_tutorials/getting-started.md @@ -9,6 +9,7 @@ date: 2020-05-15 * Installing is as simple as `pip install deepspeed`, [see more details](/tutorials/advanced-install/). * Please see our [Azure tutorial](/tutorials/azure/) to get started with DeepSpeed on Azure! +* To get started with DeepSpeed on AzureML, please see the [AzureML Examples GitHub](https://github.com/Azure/azureml-examples/tree/main/workflows/train/deepspeed) * If you're not on Azure, we recommend using our docker image via `docker pull deepspeed/deepspeed:latest` which contains a pre-installed version of DeepSpeed and all the necessary dependencies. ## Writing DeepSpeed Models @@ -186,8 +187,8 @@ slots available. The following command launches a PyTorch training job across all available nodes and GPUs specified in `myhostfile`: ```bash -deepspeed \ - --deepspeed --deepspeed_config ds_config.json --hostfile=myhostfile +deepspeed --hostfile=myhostfile \ + --deepspeed --deepspeed_config ds_config.json ``` Alternatively, DeepSpeed allows you to restrict distributed training of your model to a @@ -264,3 +265,10 @@ not detected or passed in then DeepSpeed will query the number of GPUs on the local machine to discover the number of slots available. The `--include` and `--exclude` arguments work as normal, but the user should specify 'localhost' as the hostname. + +Also note that `CUDA_VISIBLE_DEVICES` can't be used with DeepSpeed to control +which devices should be used. For example, to use only gpu1 of the current +node, do: +```bash +deepspeed --include localhost:1 ... +``` diff --git a/docs/_tutorials/onebit-adam.md b/docs/_tutorials/onebit-adam.md index 1af80cf833a8..1a15000135c9 100644 --- a/docs/_tutorials/onebit-adam.md +++ b/docs/_tutorials/onebit-adam.md @@ -7,7 +7,7 @@ This tutorial is updated on 03/04/2021 to reflect the 1-bit Adam v2. Changes inc {: .notice--info} **Watch out!** -1) The NCCL-based implementation requires PyTorch >= 1.8 and NCCL >= 2.8.3. See details below. 2) Although 1-bit Adam is compatible with both FP16 and FP32, currently we only verified the convergence under mixed precision/FP16 training. +1) The NCCL-based implementation requires PyTorch >= 1.8 (and NCCL >= 2.8.3 when you have 64 or more GPUs). See details below. 2) Although 1-bit Adam is compatible with both FP16 and FP32, currently we only verified the convergence under mixed precision/FP16 training. 3) Currently 1-bit Adam is not compatible with pipeline parallelism. 4) Frequent checkpoint loading could hurt 1-bit Adam's convergence. See details below. {: .notice--warning} In this tutorial, we are going to introduce the 1-bit Adam optimizer in DeepSpeed. 1-bit Adam can improve model training speed on communication-constrained clusters, especially for communication-intensive large models by reducing the overall communication volume by up to 5x. Detailed description of the 1-bit Adam algorithm, its implementation in DeepSpeed, and performance evaluation is available from our [blog post](https://www.deepspeed.ai/news/2020/09/08/onebit-adam-blog-post.html). We also have a [paper](https://arxiv.org/abs/2102.02888) which provides the most complete details including algorithm, system implementation, theoretical analysis, and more evaluations. @@ -40,7 +40,7 @@ cd DeepSpeedExamples/ In 1-bit Adam v2, we introduce a new system implementation for compressed communication using the NCCL backend of PyTorch distributed. This significantly improves the usability due to NCCL’s integration with PyTorch distributed. The performance of our new NCCL-based implementation is also better than our earlier MPI-based implementation for Ethernet-based systems and on-par for InfiniBand-based systems. Thus we highly recommend users to choose this implementation. **Watch out!** -This NCCL-based implementation requires PyTorch >= 1.8 and NCCL >= 2.8.3. Currently (2021/03/04) you need to install PyTorch 1.8 as a nightly version. Currently (2021/03/04) NCCL 2.8.3 is not officially supported by PyTorch. The solution we used is by hacking in NCCL 2.8.3 via `LD_PRELOAD`: 1) Install NCCL 2.8.3. This works for us on a CUDA 11 system: `apt-get install -y libnccl2=2.8.3-1+cuda11.0 libnccl-dev=2.8.3-1+cuda11.0`. 2) Set `LD_PRELOAD` to the the library path. This works for us: `LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libnccl.so.2.8.3`. To confirm `LD_PRELOAD` is working you can see the version it uses in the NCCL logs if you have `NCCL_DEBUG=INFO`, it should say: NCCL version 2.8.3+cuda11.0. +This NCCL-based implementation requires PyTorch >= 1.8. It also requires NCCL >= 2.8.3 when you have 64 or more GPUs to avoid certain NCCL runtime bugs. Currently (2021/03/16) NCCL 2.8.3 is not officially supported by PyTorch. The solution we used is by hacking in NCCL 2.8.3 via `LD_PRELOAD`: 1) Install NCCL 2.8.3. This works for us on a CUDA 11 system: `apt-get install -y libnccl2=2.8.3-1+cuda11.0 libnccl-dev=2.8.3-1+cuda11.0`. 2) Set `LD_PRELOAD` to the the library path. This works for us: `LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libnccl.so.2.8.3`. To confirm `LD_PRELOAD` is working you can see the version it uses in the NCCL logs if you have `NCCL_DEBUG=INFO`, it should say: NCCL version 2.8.3+cuda11.0. {: .notice--warning} #### 1.2.2 MPI-based implementation @@ -103,7 +103,11 @@ Please note three new parameters `freeze_step`, `cuda_aware`, and `comm_backend_ (New in v2) `comm_backend_name` is used to indicate which backend implementation to use. You can choose between NCCL and MPI-based implementations by setting `comm_backend_name` to "nccl" and "mpi". When using NCCL-based implementation, there is no need to set `cuda_aware`. #### 1.4.1 (New in v2) Momentum masks for parameters with constant zero gradients -Because 1-bit compression cannot represent exact zero, the compression error would keep accumulating in the momentum if a parameter have constant zero gradients during training. For example, for BERT pre-training seq length 128, `bert.embeddings.position_embeddings.weight` has constant zeros in its gradient and momentum for row 129 to 512, because it only learns up to seq length 128 while the model supports up to seq length 512. Thus in 1-bit Adam v2 we added support of a momentum mask for users to specify those params that have constant exact zeros in their gradients. See [example script](https://github.com/microsoft/DeepSpeedExamples/blob/master/bing_bert/deepspeed_train.py) for how to configure this momentum mask. +Because 1-bit compression cannot represent exact zero, the compression error would keep accumulating in the momentum if a parameter have constant zero gradients during training. For example, for BERT pre-training seq length 128, `bert.embeddings.position_embeddings.weight` has constant zeros in its gradient and momentum for row 129 to 512, because it only learns up to seq length 128 while the model supports up to seq length 512. Thus in 1-bit Adam v2 we added support of a momentum mask for users to specify those params that have constant exact zeros in their gradients. See [example script](https://github.com/microsoft/DeepSpeedExamples/blob/master/bing_bert/deepspeed_train.py) for how to configure this momentum mask. One thing to note is that we don't use momentum mask saved in checkpoints since this mask could change during training (e.g., BERT seqlen 128 and 512 require different masks). So you have to provide this mask every time in your training script. + +**Watch out!** +1-bit Adam replies on an compression error compensation mechanism to maintain the convergence speed at compression stage. When loading checkpoints, we actually reset the compression errors for 3 reasons: 1) The worker and server error at each GPU are distinct, so in current implementation only rank 0's errors are saved in the checkpoint. Thus we have to reset the errors. If we want to save them correctly we need O(num_gpu*model_size) memory in order to gather all the error, which is a very large memory requirement. It's possible to save them in a distributed way, but it will make the checkpoint saving/loading much more complicated. 2) Even if we are able to save the compression errors correctly, you need to have the exact same number of GPUs in order to load them correctly. 3) We verified on BERT pre-training that occasionally resetting the compression error at checkpoint loading does not affect the convergence. However, please avoid frequent checkpoint loading which could break the error compensation mechanism thus affect the convergence. +{: .notice--warning} ## 2. BingBertSQuAD Fine-tuning with 1-bit Adam @@ -181,7 +185,6 @@ Table 1 shows the fine-tuning configuration we used in our experiments. | Weight-decay | 0.0 | | Epoch count | 2 | | **freeze_step** | 400 | -| **cuda_aware** | false | | **comm_backend_name** | "nccl" | Table 1. Fine-tuning configuration @@ -271,7 +274,6 @@ Below is the DeepSpeed configuration file for running BERT-large pre-training wi "weight_decay": 0.01, "bias_correction": false, "freeze_step": 23000, - "cuda_aware": false, "comm_backend_name": "nccl" } }, diff --git a/docs/_tutorials/pipeline.md b/docs/_tutorials/pipeline.md index 46546066ab1a..1751846830ef 100644 --- a/docs/_tutorials/pipeline.md +++ b/docs/_tutorials/pipeline.md @@ -230,7 +230,7 @@ pipeline. Each worker should load micro-batches of size a total of `engine.gradient_accumulation_steps()` times per `train_batch()`. **Watch out!** -The pipeline engine *pulls* data from an iteratior instead of iterating over +The pipeline engine *pulls* data from an iterator instead of iterating over it. It's critical that the data stream does not empty in the middle of a training batch. Each invocation of `train_batch()` will pull a total of `engine.gradient_accumulation_steps()` micro-batches of data from @@ -276,9 +276,9 @@ For example, a machine with 16 GPUs must have as much local CPU memory as 16 tim DeepSpeed provides a `LayerSpec` class that delays the construction of modules until the model layers have been partitioned across workers. -Then each worker will allocate only the layers it's assigned to. So, continuing the -example from the previous paragraph, a machine with 16 GPUs will need to allocate a -total of 1x model size on its CPU, compared to 16x in the LayerSpec example. +Then each worker will allocate only the layers it's assigned to. So, comparing to the +example from the previous paragraph, using `LayerSpec` a machine with 16 GPUs will need to +allocate a total of 1x model size on its CPU memory and not 16x. Here is an example of the abbreviated AlexNet model, but expressed only with `LayerSpec`s. Note that the syntax is almost unchanged: `nn.ReLU(inplace=True)` diff --git a/docs/_tutorials/sparse-attention.md b/docs/_tutorials/sparse-attention.md index 915fd524e1fd..184d3e621e2d 100644 --- a/docs/_tutorials/sparse-attention.md +++ b/docs/_tutorials/sparse-attention.md @@ -154,7 +154,7 @@ This module, is the parent class for all sparsity structures and contains the sh * `block`: an integer determining the block size. Current implementation of sparse self-attention is based on blocked sparse matrices. In which this parameter defines size of such square blocks; `Block X Block`. * `different_layout_per_head`: a boolean determining if each head should be assigned a different sparsity layout; default is false and this will be satisfied based on availability. -* **Fixed** (FixedSparistyConfig): +* **Fixed** (FixedSparsityConfig): This structure is based on [Generative Modeling with Sparse Transformers](https://arxiv.org/abs/1904.10509) from OpenAI, in which local and global attention is fixed by the given parameters: * `num_local_blocks`: an integer determining the number of blocks in local attention window. As it is illustrated in the below figure (adapted from original paper), tokens in a local window, attend to all tokens local to them. In the case of autoregressive model, as in the figure, tokens attend to tokens appearing before them in the local window. And in the case of Masked model such as BERT, attention is bidirectional. * `num_global_blocks`: an integer determining how many consecutive blocks in a local window is used as the representative of the window for global attention; illustrated in the figure below as well. diff --git a/docs/_tutorials/zero.md b/docs/_tutorials/zero.md index e594427f460f..ad6e222707e0 100644 --- a/docs/_tutorials/zero.md +++ b/docs/_tutorials/zero.md @@ -3,7 +3,7 @@ title: "Zero Redundancy Optimizer (ZeRO)" --- If you have not done so already, we advise that you read the DeepSpeed tutorials on [Getting Started](/getting-started/) and [Megatron-LM GPT-2](/tutorials/megatron/) before stepping through this tutorial. -In this tutorial, we will apply the ZeRO optimizer to the [Megatron-LM GPT-2](https://github.com/NVIDIA/Megatron-LM) model. ZeRO is a powerful set of memory optimization techniques that enable effective FP16 training of large models with trillons of parameters, such as [GPT-2](https://openai.com/blog/better-language-models/) and [Turing-NLG 17B](https://www.microsoft.com/en-us/research/blog/turing-nlg-a-17-billion-parameter-language-model-by-microsoft/). Compared to the alternative model parallelism approaches for training large models, a key appeal of ZeRO is that no model code modifications are required. As this tutorial will demonstrate, *using ZeRO in a DeepSpeed model is quick and easy because all you need is to change a few configurations in the DeepSpeed configuration JSON*. No code changes are needed. +In this tutorial, we will apply the ZeRO optimizer to the [Megatron-LM GPT-2](https://github.com/NVIDIA/Megatron-LM) model. ZeRO is a powerful set of memory optimization techniques that enable effective FP16 training of large models with trillions of parameters, such as [GPT-2](https://openai.com/blog/better-language-models/) and [Turing-NLG 17B](https://www.microsoft.com/en-us/research/blog/turing-nlg-a-17-billion-parameter-language-model-by-microsoft/). Compared to the alternative model parallelism approaches for training large models, a key appeal of ZeRO is that no model code modifications are required. As this tutorial will demonstrate, *using ZeRO in a DeepSpeed model is quick and easy because all you need is to change a few configurations in the DeepSpeed configuration JSON*. No code changes are needed. ## ZeRO Overview ZeRO leverages the aggregate computation and memory resources of data parallelism to reduce the memory and compute requirements of each device (GPU) used for model training. ZeRO reduces the memory consumption of each GPU by partitioning the various model training states (weights, gradients, and optimizer states) across the available devices (GPUs and CPUs) in the distributed training hardware. Concretely, ZeRO is being implemented as incremental stages of optimizations, where optimizations in earlier stages are available in the later stages. To deep dive into ZeRO, please see our [paper](https://arxiv.org/abs/1910.02054v3). @@ -226,7 +226,7 @@ class ParallelTransformerLayer(MegatronModule): #### Allocating Massive Megatron-LM Models -We make two further changes to model initalization in order to support models +We make two further changes to model initialization in order to support models that exceed *local* system memory, but not *total* system memory. 1. Allocate the model in a memory-scalable fashion. The model parameters will diff --git a/docs/code-docs/source/optimizers.rst b/docs/code-docs/source/optimizers.rst index 89fc47ac547b..d7b338561b96 100755 --- a/docs/code-docs/source/optimizers.rst +++ b/docs/code-docs/source/optimizers.rst @@ -17,4 +17,4 @@ FusedLamb (GPU) OneBitAdam (GPU) ---------------------------- -.. autoclass:: deepspeed.runtime.fp16.OneBitAdam +.. autoclass:: deepspeed.runtime.fp16.onebit.adam.OneBitAdam diff --git a/docs/index.md b/docs/index.md index ee21bd3928fb..497f88bab5c3 100755 --- a/docs/index.md +++ b/docs/index.md @@ -4,6 +4,8 @@ toc: true toc_label: "Contents" --- +03/2021: DeepSpeed is hiring! Come join us: [SDE 2](https://careers.microsoft.com/us/en/job/1013160/Software-Engineer-2), [Sr. SDE](https://careers.microsoft.com/us/en/job/1017151/Senior-Software-Engineer), [Sr. Researcher](https://careers.microsoft.com/us/en/job/1016440/Senior-Researcher) + DeepSpeed is a deep learning optimization library that makes distributed training easy, efficient, and effective. @@ -28,7 +30,11 @@ initiative to enable next-generation AI capabilities at scale, where you can fin information [here](https://innovation.microsoft.com/en-us/exploring-ai-at-scale). # What's New? +* [2021/04/02] [[DeepSpeed on AzureML] Transformers and CIFAR examples are now available on AzureML GitHub](https://github.com/Azure/azureml-examples/tree/main/workflows/train/deepspeed) +* [2021/03/30] [[PyTorch Lightning Blog] Accessible Multi-Billion Parameter Model Training with PyTorch Lightning + DeepSpeed](https://medium.com/pytorch-lightning/accessible-multi-billion-parameter-model-training-with-pytorch-lightning-deepspeed-c9333ac3bb59) +* [2021/03/16] [1-bit Adam v2: NCCL-based implementation and more](https://www.deepspeed.ai/tutorials/onebit-adam/) * [2021/03/08] [ZeRO-3 Offload: Scale your models to trillion parameters without code changes while leveraging both CPUs & GPUs](https://www.deepspeed.ai/news/2021/03/07/zero3-offload.html) +* [2021/01/19] [[🤗Hugging Face Blog] Fit More and Train Faster With ZeRO via DeepSpeed and FairScale](https://huggingface.co/blog/zero-deepspeed-fairscale) * [2020/11/12] [Simplified install, JIT compiled ops, PyPI releases, and reduced dependencies](#installation) * [2020/11/10] [Efficient and robust compressed training through progressive layer dropping](https://www.deepspeed.ai/news/2020/10/28/progressive-layer-dropping-news.html) * [2020/09/10] [DeepSpeed v0.3: Extreme-scale model training for everyone]({{ site.press_release_v3 }}) diff --git a/tests/onebit/test_com_reduce_host.py b/tests/onebit/test_com_reduce_host.py deleted file mode 100644 index 3a575828638e..000000000000 --- a/tests/onebit/test_com_reduce_host.py +++ /dev/null @@ -1,86 +0,0 @@ -from mpi4py import MPI -import time -import torch -import torch.distributed as dist -import numpy as np -import deepspeed -from deepspeed.runtime.fp16.onebit.onebitadam import OnebitAdam - -comm = MPI.COMM_WORLD -size = comm.Get_size() -rank = comm.Get_rank() - -#TODO: Detect the hostname we are running on automatically -torch.distributed.init_process_group(backend='nccl', - init_method='tcp://worker-1:2245', - world_size=size, - rank=rank) - -dummy_model = [torch.nn.Parameter(torch.ones(10))] - -# Set cuda_aware to False to use host buffers for communication -dummy_optim = OnebitAdam(dummy_model, cuda_aware=False) - -device = torch.device('cuda', rank % torch.cuda.device_count()) - - -def torch_sim(a): - a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0) - scale = a.norm() / np.sqrt(a.numel()) - a_compressed = scale * a_sign - a_sign = None - worker_error = a - a_compressed - dist.all_reduce(a_compressed) - a_compressed.mul_(1 / dist.get_world_size()) - a_server_sign = a_compressed.sign().add_(1).bool().float().add_(-0.5).mul_(2.0) - a_list = torch.chunk(a_compressed, chunks=dist.get_world_size()) - server_scale = [chunk_a.norm() / np.sqrt(chunk_a.numel()) for chunk_a in a_list] - a_sign_list = torch.chunk(a_server_sign, dist.get_world_size()) - a_server_compressed = torch.cat( - [server_scale[i] * a_sign_list[i] for i in range(dist.get_world_size())]) - rank = dist.get_rank() - server_error = a_list[rank] - server_scale[rank] * a_sign_list[rank] - torch.cuda.synchronize() - torch.distributed.barrier() - return a_server_compressed, worker_error, server_error - - -tensor_size = 100 * 2**20 -server_size = int(tensor_size / size) -if tensor_size % (8 * size) != 0: - right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size))) -else: - right_tensor_size = tensor_size -right_server_size = right_tensor_size // size -# Adding bias to the initialization of the gradient we are communicating -# In order to get rid of the case where some elements in the gradient are too small -a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank -worker_error = torch.zeros(right_tensor_size, device=device) -server_error = torch.zeros(right_server_size, device=device) -a_torch, worker_error_torch, server_error_torch = torch_sim(a) -torch.cuda.empty_cache() -local_rank = rank % torch.cuda.device_count() -a_after = dummy_optim.Compressed_Allreduce(a, - worker_error, - server_error, - rank, - size, - comm, - local_rank) -threshold = 1e-6 -magnitude_threshold = 1e-6 -diff_mask = (a_after - a_torch) > threshold -diff_server_mask = torch.chunk(diff_mask, size)[rank] -mpi_server = torch.chunk(a_after, size)[rank] + server_error -torch_server = torch.chunk(a_torch, size)[rank] + server_error_torch - -# If the number in the compensated_server_m is too small (e.g 1e-8), then calling sign() might be problematic -# The test would skip those numbers that are too small in compensated_server_m -if torch.sum(diff_server_mask) == 0: - print('Successfully passed the test for 1bit Adam at Rank {}'.format(rank)) -else: - check_mag_mask = mpi_server[diff_server_mask] > magnitude_threshold - if torch.sum(check_mag_mask) == 0: - print('Successfully passed the test for 1bit Adam at Rank {}'.format(rank)) - else: - print('Fails at {} of positions'.format(torch.sum(check_mag_mask))) diff --git a/tests/onebit/test_mpi_backend.py b/tests/onebit/test_mpi_backend.py index 6ef7df42a81d..785021cf0935 100644 --- a/tests/onebit/test_mpi_backend.py +++ b/tests/onebit/test_mpi_backend.py @@ -11,14 +11,10 @@ size = comm.Get_size() rank = comm.Get_rank() -#TODO: Detect the hostname we are running on automatically -torch.distributed.init_process_group(backend='nccl', - init_method='tcp://worker-0:2245', - world_size=size, - rank=rank) +deepspeed.init_distributed(dist_backend='nccl') # Change cuda_aware to True to test out CUDA-Aware MPI communication -backend = MpiBackend(cuda_aware=True) +backend = MpiBackend(cuda_aware=False) device = torch.device('cuda', rank % torch.cuda.device_count()) diff --git a/tests/onebit/test_mpi_perf.py b/tests/onebit/test_mpi_perf.py index 4b572c814317..6017ec873c21 100644 --- a/tests/onebit/test_mpi_perf.py +++ b/tests/onebit/test_mpi_perf.py @@ -18,13 +18,9 @@ size = comm.Get_size() rank = comm.Get_rank() -#TODO: Detect the hostname we are running on automatically -torch.distributed.init_process_group(backend='nccl', - init_method='tcp://worker-0:2245', - world_size=size, - rank=rank) - -backend = MpiBackend(cuda_aware=True) +deepspeed.init_distributed(dist_backend='nccl') +# Change cuda_aware to True to test out CUDA-Aware MPI communication +backend = MpiBackend(cuda_aware=False) device = torch.device('cuda', rank % torch.cuda.device_count()) diff --git a/tests/onebit/test_nccl_backend.py b/tests/onebit/test_nccl_backend.py index 8935977ad5a2..c3f138b221d5 100644 --- a/tests/onebit/test_nccl_backend.py +++ b/tests/onebit/test_nccl_backend.py @@ -4,6 +4,11 @@ import numpy as np import argparse import deepspeed +<<<<<<< HEAD:tests/onebit/test_com_reduce_host.py +from deepspeed.runtime.fp16.onebit.onebitadam import OnebitAdam +======= +import os +>>>>>>> ab5534fc4c0f8ca21ada321f9730d723aa31288b:tests/onebit/test_nccl_backend.py from deepspeed.runtime.comm.nccl import NcclBackend @@ -11,7 +16,8 @@ parser.add_argument('--local_rank', type=int, default=-1) args = parser.parse_args() -dist.init_process_group(backend='nccl') +deepspeed.init_distributed(dist_backend='nccl') +args.local_rank = int(os.environ['LOCAL_RANK']) torch.cuda.set_device(args.local_rank) device = torch.device("cuda", args.local_rank) @@ -76,9 +82,18 @@ def torch_sim(a): # If the number in the compensated_server_m is too small (e.g 1e-8), then calling sign() might be problematic # The test would skip those numbers that are too small in compensated_server_m +<<<<<<< HEAD:tests/onebit/test_com_reduce_host.py +if torch.sum(diff_server_mask) == 0: + print('Successfully passed the test for 1bit Adam at Rank {}'.format(rank)) +else: + check_mag_mask = mpi_server[diff_server_mask] > magnitude_threshold + if torch.sum(check_mag_mask) == 0: + print('Successfully passed the test for 1bit Adam at Rank {}'.format(rank)) +======= if test_correctness: if torch.sum(diff_server_mask) == 0: print('Successfully passed the test for NCCL Backend at Rank {}'.format(rank)) +>>>>>>> ab5534fc4c0f8ca21ada321f9730d723aa31288b:tests/onebit/test_nccl_backend.py else: check_mag_mask = mpi_server[diff_server_mask] > magnitude_threshold if torch.sum(check_mag_mask) == 0: diff --git a/tests/onebit/test_nccl_perf.py b/tests/onebit/test_nccl_perf.py index c45ff205621f..1374cda4ddce 100644 --- a/tests/onebit/test_nccl_perf.py +++ b/tests/onebit/test_nccl_perf.py @@ -4,6 +4,7 @@ import numpy as np import argparse import deepspeed +import os from deepspeed.runtime.comm.nccl import NcclBackend from deepspeed.utils.timer import SynchronizedWallClockTimer @@ -15,7 +16,8 @@ parser.add_argument('--local_rank', type=int, default=-1) args = parser.parse_args() -dist.init_process_group(backend='nccl') +deepspeed.init_distributed(dist_backend='nccl') +args.local_rank = int(os.environ['LOCAL_RANK']) torch.cuda.set_device(args.local_rank) device = torch.device("cuda", args.local_rank) diff --git a/tests/onebit/test_server_error.py b/tests/onebit/test_server_error.py deleted file mode 100644 index e4b680a6cffb..000000000000 --- a/tests/onebit/test_server_error.py +++ /dev/null @@ -1,87 +0,0 @@ -from mpi4py import MPI -import time -import torch -import torch.distributed as dist -import numpy as np -import deepspeed -from deepspeed.runtime.fp16.onebit.onebitadam import OnebitAdam - -comm = MPI.COMM_WORLD -size = comm.Get_size() -rank = comm.Get_rank() - -torch.distributed.init_process_group(backend='nccl', - init_method='tcp://worker-0:2245', - world_size=size, - rank=rank) - -dummy_model = [torch.nn.Parameter(torch.ones(10))] -dummy_optim = OnebitAdam(dummy_model, cuda_aware=False) - -device = torch.device('cuda', rank % torch.cuda.device_count()) - - -def torch_sim(a): - a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0) - scale = a.norm() / np.sqrt(a.numel()) - a_compressed = scale * a_sign - a_sign = None - worker_error = a - a_compressed - dist.all_reduce(a_compressed) - a_compressed.mul_(1 / dist.get_world_size()) - a_server_sign = a_compressed.sign().add_(1).bool().float().add_(-0.5).mul_(2.0) - a_list = torch.chunk(a_compressed, chunks=dist.get_world_size()) - server_scale = [chunk_a.norm() / np.sqrt(chunk_a.numel()) for chunk_a in a_list] - a_sign_list = torch.chunk(a_server_sign, dist.get_world_size()) - a_server_compressed = torch.cat( - [server_scale[i] * a_sign_list[i] for i in range(dist.get_world_size())]) - rank = dist.get_rank() - server_error = a_list[rank] - server_scale[rank] * a_sign_list[rank] - torch.cuda.synchronize() - torch.distributed.barrier() - return a_server_compressed, worker_error, server_error - - -# Input Tensor size -tensor_size = 100 * 2**20 - -server_size = int(tensor_size / size) -if tensor_size % (8 * size) != 0: - right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size))) -else: - right_tensor_size = tensor_size - -right_server_size = right_tensor_size // size - -# The -0.5 is required for avoiding sign flips/errors -a = torch.rand(tensor_size, device=device) - 0.5 - -worker_error = torch.zeros(right_tensor_size, device=device) -server_error = torch.zeros(right_server_size, device=device) -a_torch, worker_error_torch, server_error_torch = torch_sim(a) -torch.cuda.empty_cache() -local_rank = rank % torch.cuda.device_count() - -# Test the 1-bit Adam optimizer -a_after = dummy_optim.Compressed_Allreduce(a, - worker_error, - server_error, - rank, - size, - comm, - local_rank) - -# If the error is below the threshold, it is acceptable for training -threshold = 1e-6 - -diff_pos = ((a_after - a_torch) > threshold) - -if rank == 0: - before_diff = torch.chunk(a_after - a_torch, - size)[rank] + server_error - server_error_torch - if torch.norm(before_diff) / torch.norm(torch.chunk(a_after, - size)[rank]) < threshold: - print('Successfully passed the test') - else: - print('The difference for the tensor before allgather is {}'.format( - torch.norm(before_diff))) diff --git a/tests/unit/test_checkpointing.py b/tests/unit/test_checkpointing.py index 0fbe354933c4..765c44c8e551 100755 --- a/tests/unit/test_checkpointing.py +++ b/tests/unit/test_checkpointing.py @@ -47,7 +47,7 @@ def compare_model_states(saved_model, loaded_model, compare_optimizer=True): if FP16_DeepSpeedZeroOptimizer_Stage3 is not None and isinstance( saved_model.optimizer, FP16_DeepSpeedZeroOptimizer_Stage3): - for p0, p1 in zip(saved_model.optimizer.fp32_groups_flat, loaded_model.optimizer.fp32_groups_flat): + for p0, p1 in zip(saved_model.optimizer.fp32_partitioned_groups_flat, loaded_model.optimizer.fp32_partitioned_groups_flat): assert torch.allclose(p0, p1, atol=1e-07), f"Fp32 model states {p0} is not equal to {p1}" elif isinstance(saved_model.optimizer, FP16_DeepSpeedZeroOptimizer): @@ -303,12 +303,13 @@ def _test_checkpoint_fused_optimizer(args, 'deepspeed_adam'), (3, False, - 'Adam')]) + 'Adam'), + (3, + True, + 'deepspeed_adam')]) def test_checkpoint_zero_optimizer(tmpdir, zero_stage, use_cpu_offload, adam_optimizer): if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: pytest.skip("cpu-adam is not compatible") - if zero_stage == 3: - pytest.skip('Skip checkpointing tests for ZeRO3') config_dict = { "train_batch_size": 2, @@ -324,8 +325,10 @@ def test_checkpoint_zero_optimizer(tmpdir, zero_stage, use_cpu_offload, adam_opt } }, "fp16": { - "enabled": True + "enabled": True, + "initial_scale_power": 8 }, + "wall_clock_breakdown": True, "zero_optimization": { "stage": zero_stage, "cpu_offload": use_cpu_offload @@ -340,9 +343,7 @@ def _test_checkpoint_zero_optimizer(args, hidden_dim, load_optimizer_states): if zero_stage == 3: - global FP16_DeepSpeedZeroOptimizer_Stage3 - from deepspeed.runtime.zero.stage3 import FP16_DeepSpeedZeroOptimizer_Stage3 - with deepspeed.ScatteredParameters(zero_modules=True): + with deepspeed.zero.Init(): models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)] else: models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)] @@ -371,15 +372,16 @@ def _test_checkpoint_zero_optimizer(args, 'deepspeed_adam'), (3, False, - 'Adam')]) + 'Adam'), + (3, + True, + 'deepspeed_adam')]) def test_checkpoint_zero_no_optimizer(tmpdir, zero_stage, use_cpu_offload, adam_optimizer): if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: pytest.skip("cpu-adam is not compatible") - if zero_stage == 3: - pytest.skip('Skip checkpointing tests for ZeRO3') config_dict = { "train_batch_size": 2, @@ -413,7 +415,7 @@ def _test_checkpoint_zero_no_optimizer(args, if zero_stage == 3: global FP16_DeepSpeedZeroOptimizer_Stage3 from deepspeed.runtime.zero.stage3 import FP16_DeepSpeedZeroOptimizer_Stage3 - with deepspeed.ScatteredParameters(zero_modules=True): + with deepspeed.zero.Init(): models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)] else: models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)] @@ -445,12 +447,13 @@ def _test_checkpoint_zero_no_optimizer(args, 'deepspeed_adam'), (3, False, - 'Adam')]) + 'Adam'), + (3, + True, + 'deepspeed_adam')]) def test_checkpoint_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, adam_optimizer): if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: pytest.skip("cpu-adam is not compatible") - if zero_stage == 3: - pytest.skip('Skip checkpointing tests for ZeRO3') config_dict = { "train_batch_size": 2, @@ -493,7 +496,7 @@ def _test_checkpoint_lr_scheduler(args, if zero_stage == 3: global FP16_DeepSpeedZeroOptimizer_Stage3 from deepspeed.runtime.zero.stage3 import FP16_DeepSpeedZeroOptimizer_Stage3 - with deepspeed.ScatteredParameters(zero_modules=True): + with deepspeed.zero.Init(): models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)] else: models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)] @@ -526,14 +529,15 @@ def _test_checkpoint_lr_scheduler(args, (2, True, 'deepspeed_adam'), + (3, + False, + 'Adam'), (3, True, - 'Adam')]) + 'deepspeed_adam')]) def test_checkpoint_no_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, adam_optimizer): if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: pytest.skip("cpu-adam is not compatible") - if zero_stage == 3: - pytest.skip('Skip checkpointing tests for ZeRO3') config_dict = { "train_batch_size": 2, @@ -570,7 +574,7 @@ def _test_checkpoint_no_lr_scheduler(args, load_optimizer_states, load_lr_scheduler_states): if zero_stage == 3: - with deepspeed.ScatteredParameters(zero_modules=True): + with deepspeed.zero.Init(): models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)] else: models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)] diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index 4cabefe71a33..7de3a40fabeb 100755 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -226,3 +226,83 @@ def _helper(): model.step() _helper() + + +def test_none_args(tmpdir): + config_dict = { + "train_batch_size": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015 + } + }, + "fp16": { + "enabled": True + } + } + + @distributed_test(world_size=1) + def _helper(): + model = SimpleModel(hidden_dim=10) + model, _, _, _ = deepspeed.initialize(args=None, model=model, config_params=config_dict) + data_loader = random_dataloader(model=model, + total_samples=5, + hidden_dim=10, + device=model.device) + for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + + _helper() + + +def test_no_args(tmpdir): + config_dict = { + "train_batch_size": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015 + } + }, + "fp16": { + "enabled": True + } + } + + @distributed_test(world_size=1) + def _helper(): + model = SimpleModel(hidden_dim=10) + model, _, _, _ = deepspeed.initialize(model=model, config_params=config_dict) + data_loader = random_dataloader(model=model, + total_samples=5, + hidden_dim=10, + device=model.device) + for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + + _helper() + + +def test_no_model(tmpdir): + config_dict = { + "train_batch_size": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015 + } + }, + "fp16": { + "enabled": True + } + } + + @distributed_test(world_size=1) + def _helper(): + model = SimpleModel(hidden_dim=10) + with pytest.raises(AssertionError): + model, _, _, _ = deepspeed.initialize(model=None, config_params=config_dict) + + with pytest.raises(AssertionError): + model, _, _, _ = deepspeed.initialize(model, config_params=config_dict) diff --git a/tests/unit/test_fp16.py b/tests/unit/test_fp16.py index 5012614f97b0..dbd40c322be9 100755 --- a/tests/unit/test_fp16.py +++ b/tests/unit/test_fp16.py @@ -347,9 +347,6 @@ def test_zero_static_scale(tmpdir, zero_stage, use_cpu_offload): if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: pytest.skip("cpu-adam is not compatible") - if zero_stage == 3: - pytest.skip("skip for now") - config_dict = { "train_batch_size": 4, "steps_per_print": 1, @@ -371,8 +368,9 @@ def test_zero_static_scale(tmpdir, zero_stage, use_cpu_offload): args = args_from_dict(tmpdir, config_dict) @distributed_test(world_size=2) - def _test_zero_static_scale(args, zero_stage): - hidden_dim = 10 + def _test_zero_static_scale(args, zero_stage, hidden_dim): + #making hidden size not divisible by DP for covering this scenario + hidden_dim = hidden_dim model = SimpleModel(hidden_dim) model, optim, _, _ = deepspeed.initialize(args=args, @@ -393,7 +391,10 @@ def _test_zero_static_scale(args, zero_stage): model.backward(loss) model.step() - _test_zero_static_scale(args=args, zero_stage=zero_stage) + #test when hidden_dim is not aligned with world size + _test_zero_static_scale(args=args, zero_stage=zero_stage, hidden_dim=9) + #test when hidden_dim is aligned with world size + _test_zero_static_scale(args=args, zero_stage=zero_stage, hidden_dim=10) def test_zero_static_scale_deprecated_format(tmpdir): diff --git a/tests/unit/test_onebit.py b/tests/unit/test_onebit.py index 1d505b8d682f..8e0056be0cff 100644 --- a/tests/unit/test_onebit.py +++ b/tests/unit/test_onebit.py @@ -6,6 +6,7 @@ import json import os import numpy as np +import time from common import distributed_test from simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict, create_deepspeed_args @@ -15,13 +16,6 @@ pytest.skip("NCCL-based 1-bit compression requires torch 1.8 or higher", allow_module_level=True) -try: - from apex import amp - _amp_available = True -except ImportError: - _amp_available = False -amp_available = pytest.mark.skip(_amp_available, reason="apex/amp is not installed") - def test_onebitadam_fp16_basic(tmpdir): config_dict = { @@ -105,6 +99,204 @@ def _test_onebitadam_fp32_basic(args, model, hidden_dim): _test_onebitadam_fp32_basic(args=args, model=model, hidden_dim=hidden_dim) +def test_onebitadam_exp_avg_mask(tmpdir): + config_dict = { + "train_batch_size": 2, + "steps_per_print": 1, + "optimizer": { + "type": "OneBitAdam", + "params": { + "lr": 0.00015, + "weight_decay": 0.01, + "freeze_step": 2, + "cuda_aware": False, + "comm_backend_name": "nccl" + } + }, + "gradient_clipping": 1.0, + "fp16": { + "enabled": True, + "loss_scale": 0, + "initial_scale_power": 16 + } + } + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 10 + + model = SimpleModel(hidden_dim) + param_optimizer = list(model.named_parameters()) + mask1 = torch.zeros_like(param_optimizer[0][1].data) + for col in range(mask1.size()[1]): + mask1[0][col] += 1 + mask1 = torch.flatten(mask1) + optimizer_grouped_parameters = [{ + 'params': [param_optimizer[0][1]], + 'weight_decay': 0.01, + 'exp_avg_mask': mask1 + }, + { + 'params': [param_optimizer[1][1]], + 'weight_decay': 0.01 + }] + + @distributed_test(world_size=[2]) + def _test_onebitadam_exp_avg_mask(args, model, hidden_dim): + model, optimizer, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=optimizer_grouped_parameters) + data_loader = random_dataloader(model=model, + total_samples=50, + hidden_dim=hidden_dim, + device=model.device) + for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + # Test whether the momentum mask works + for v in optimizer.state.values(): + if v['exp_avg'].size() == mask1.size(): + assert torch.allclose(v['exp_avg'], v['exp_avg'].mul_(mask1.to(device=v['exp_avg'].device)), atol=1e-07), f"Momentum mask is not working properly" + + _test_onebitadam_exp_avg_mask(args=args, model=model, hidden_dim=hidden_dim) + + +def test_onebitadam_checkpointing(tmpdir): + config_dict = { + "train_batch_size": 2, + "steps_per_print": 1, + "optimizer": { + "type": "OneBitAdam", + "params": { + "lr": 0.00015, + "weight_decay": 0.01, + "freeze_step": 2, + "cuda_aware": False, + "comm_backend_name": "nccl" + } + }, + "gradient_clipping": 1.0, + "fp16": { + "enabled": True, + "loss_scale": 0, + "initial_scale_power": 16 + } + } + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 10 + + model = SimpleModel(hidden_dim) + param_optimizer = list(model.named_parameters()) + mask1 = torch.zeros_like(param_optimizer[0][1].data) + mask2 = torch.zeros_like(param_optimizer[0][1].data) + for col in range(mask1.size()[1]): + mask1[0][col] += 1 + mask2[1][col] += 1 + mask1 = torch.flatten(mask1) + mask2 = torch.flatten(mask2) + + optimizer_grouped_parameters_1 = [{ + 'params': [param_optimizer[0][1]], + 'weight_decay': 0.01, + 'exp_avg_mask': mask1 + }, + { + 'params': [param_optimizer[1][1]], + 'weight_decay': 0.01 + }] + + optimizer_grouped_parameters_2 = [{ + 'params': [param_optimizer[0][1]], + 'weight_decay': 0.01, + 'exp_avg_mask': mask2 + }, + { + 'params': [param_optimizer[1][1]], + 'weight_decay': 0.01 + }] + + optimizer_grouped_parameters_3 = [{ + 'params': [param_optimizer[0][1]], + 'weight_decay': 0.01 + }, + { + 'params': [param_optimizer[1][1]], + 'weight_decay': 0.01 + }] + + @distributed_test(world_size=[2]) + def _test_onebitadam_checkpointing(mask1, mask2, args, model, hidden_dim): + model_1, optimizer_1, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=optimizer_grouped_parameters_1) + data_loader = random_dataloader(model=model_1, + total_samples=10, + hidden_dim=hidden_dim, + device=model_1.device) + for n, batch in enumerate(data_loader): + loss = model_1(batch[0], batch[1]) + model_1.backward(loss) + model_1.step() + # Test whether momentum mask still exist after saving checkpoint + assert optimizer_1.optimizer.adam_freeze_key is True + mask1 = mask1.to(device=optimizer_1.param_groups[0]['exp_avg_mask'].device) + assert torch.allclose(optimizer_1.param_groups[0]['exp_avg_mask'], mask1, atol=1e-07), f"Incorrect momentum mask" + save_folder = os.path.join(tmpdir, 'saved_checkpoint') + # optimizer_1.optimizer.gather_compression_errors() + model_1.save_checkpoint(save_folder, tag=None) + time.sleep(5) + assert torch.allclose(optimizer_1.param_groups[0]['exp_avg_mask'], mask1, atol=1e-07), f"Momentum mask should not change after saving checkpoint" + + + model_2, optimizer_2, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=optimizer_grouped_parameters_2) + # Test whether momentum mask stays the same after loading checkpoint + mask2 = mask2.to(device=optimizer_2.param_groups[0]['exp_avg_mask'].device) + assert torch.allclose(optimizer_2.param_groups[0]['exp_avg_mask'], mask2, atol=1e-07), f"Incorrect momentum mask" + model_2.load_checkpoint(save_folder, + tag=None, + load_optimizer_states=True, + load_lr_scheduler_states=True) + assert torch.allclose(optimizer_2.param_groups[0]['exp_avg_mask'], mask2, atol=1e-07), f"Momentum mask should not change after loading checkpoint" + # Test whether worker&server error is resetted + for v in optimizer_2.state.values(): + assert 'worker_error' not in v, f"Incorrect worker error" + assert 'server_error' not in v, f"Incorrect server error" + assert optimizer_2.optimizer.adam_freeze_key is True + + model_3, optimizer_3, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=optimizer_grouped_parameters_3) + optimizer_3.optimizer.freeze_step = 20 + data_loader = random_dataloader(model=model_3, + total_samples=50, + hidden_dim=hidden_dim, + device=model_3.device) + for n, batch in enumerate(data_loader): + loss = model_3(batch[0], batch[1]) + model_3.backward(loss) + model_3.step() + assert optimizer_3.optimizer.adam_freeze_key is True + # Test whether momentum mask stays the same after loading checkpoint + assert 'exp_avg_mask' not in optimizer_3.param_groups[0], f"Incorrect momentum mask" + model_3.load_checkpoint(save_folder, + tag=None, + load_optimizer_states=True, + load_lr_scheduler_states=True) + assert 'exp_avg_mask' not in optimizer_3.param_groups[0], f"Momentum mask should not change after loading checkpoint" + # Test whether worker&server error is resetted + for v in optimizer_3.state.values(): + assert 'worker_error' not in v, f"Incorrect worker error" + assert 'server_error' not in v, f"Incorrect server error" + assert optimizer_3.optimizer.adam_freeze_key is False + + _test_onebitadam_checkpointing(mask1, + mask2, + args=args, + model=model, + hidden_dim=hidden_dim) + + def test_compressed_allreduce_basic(tmpdir): @distributed_test(world_size=[1, 2]) def _test_compressed_allreduce_basic(): diff --git a/tests/unit/test_pipe.py b/tests/unit/test_pipe.py index 30d4314a8441..65ae0023b8ec 100755 --- a/tests/unit/test_pipe.py +++ b/tests/unit/test_pipe.py @@ -169,6 +169,7 @@ def train_cifar(model, args, num_steps=400, average_dp_losses=True, fp16=True, s return losses +@pytest.mark.skip(reason="been seeing nondeterministic failures, skipping for now") @pytest.mark.parametrize('topo', [ PipeTopo(num_pp=1, diff --git a/version.txt b/version.txt index 0b9c0199636e..e4737652ca5a 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.3.12 +0.3.13