From b7272f063919943b8ab02e346265f487dcb3c986 Mon Sep 17 00:00:00 2001 From: haohanchen-yagao <54413235+haohanchen-yagao@users.noreply.github.com> Date: Thu, 13 Oct 2022 15:15:57 -0700 Subject: [PATCH] Update SMMP GPT sample (#3433) * update smp * update smp * fp16 change * minor fix * minor fix * pin transformer version * Update SMMP notebooks * update gpt2 script * update notebook * minor fix * minor fix * minor fix * minor fix * fix * update gptj script and noteboook * update memory tracker * minor fix * fix * fix gptj notebook * Update training/distributed_training/pytorch/model_parallel/gpt-j/11_train_gptj_smp_tensor_parallel_notebook.ipynb Co-authored-by: Miyoung * Fix typos&expressions * reformat Co-authored-by: Miyoung Co-authored-by: Aaron Markham --- .../gpt-j/01_train_gptj_smp_notebook.ipynb | 2 +- ...in_gptj_smp_tensor_parallel_notebook.ipynb | 15 +- .../model_parallel/gpt-j/fp16/__init__.py | 30 - .../pytorch/model_parallel/gpt-j/fp16/fp16.py | 1054 ----------------- .../model_parallel/gpt-j/fp16/fp16util.py | 412 ------- .../model_parallel/gpt-j/fp16/loss_scaler.py | 274 ----- .../gpt-j/fp16/megatron/clip_grads.py | 156 --- .../gpt-j/fp16/megatron/fp16.py | 698 ----------- .../gpt-j/fp16/megatron/grad_scaler.py | 130 -- .../model_parallel/gpt-j/memory_tracker.py | 81 +- .../model_parallel/gpt-j/requirements.txt | 2 +- .../gpt-j/sharded_data_parallel_checkpoint.py | 240 ++++ .../gpt-j/train_gptj_smp_script.py | 31 +- .../train_gptj_smp_tensor_parallel_script.py | 672 +++-------- .../model_parallel/gpt2/fp16/__init__.py | 30 - .../pytorch/model_parallel/gpt2/fp16/fp16.py | 1027 ---------------- .../model_parallel/gpt2/fp16/fp16util.py | 406 ------- .../model_parallel/gpt2/fp16/loss_scaler.py | 271 ----- .../model_parallel/gpt2/memory_tracker.py | 78 +- .../model_parallel/gpt2/requirements.txt | 1 + .../gpt2/sharded_data_parallel_checkpoint.py | 240 ++++ .../gpt2/smp-train-gpt-simple.ipynb | 22 +- .../model_parallel/gpt2/train_gpt_simple.py | 658 +++------- 23 files changed, 1009 insertions(+), 5521 deletions(-) delete mode 100644 training/distributed_training/pytorch/model_parallel/gpt-j/fp16/__init__.py delete mode 100755 training/distributed_training/pytorch/model_parallel/gpt-j/fp16/fp16.py delete mode 100644 training/distributed_training/pytorch/model_parallel/gpt-j/fp16/fp16util.py delete mode 100755 training/distributed_training/pytorch/model_parallel/gpt-j/fp16/loss_scaler.py delete mode 100644 training/distributed_training/pytorch/model_parallel/gpt-j/fp16/megatron/clip_grads.py delete mode 100644 training/distributed_training/pytorch/model_parallel/gpt-j/fp16/megatron/fp16.py delete mode 100644 training/distributed_training/pytorch/model_parallel/gpt-j/fp16/megatron/grad_scaler.py create mode 100644 training/distributed_training/pytorch/model_parallel/gpt-j/sharded_data_parallel_checkpoint.py delete mode 100644 training/distributed_training/pytorch/model_parallel/gpt2/fp16/__init__.py delete mode 100755 training/distributed_training/pytorch/model_parallel/gpt2/fp16/fp16.py delete mode 100644 training/distributed_training/pytorch/model_parallel/gpt2/fp16/fp16util.py delete mode 100644 training/distributed_training/pytorch/model_parallel/gpt2/fp16/loss_scaler.py create mode 100644 training/distributed_training/pytorch/model_parallel/gpt2/sharded_data_parallel_checkpoint.py diff --git a/training/distributed_training/pytorch/model_parallel/gpt-j/01_train_gptj_smp_notebook.ipynb b/training/distributed_training/pytorch/model_parallel/gpt-j/01_train_gptj_smp_notebook.ipynb index c8ca985ce6..4b4a0a61bc 100644 --- a/training/distributed_training/pytorch/model_parallel/gpt-j/01_train_gptj_smp_notebook.ipynb +++ b/training/distributed_training/pytorch/model_parallel/gpt-j/01_train_gptj_smp_notebook.ipynb @@ -21,7 +21,7 @@ "This notebook depends on the following files and folders:\n", "\n", "1. `train_gptj_smp_script.py`: This is an entrypoint script that is passed to the PyTorch estimator in the notebook instructions. This script is responsible for end to end training of the GPT-J model with SMP. The script has additional comments at places where the SMP API is used.\n", - "2. `fp16`: This folder is used for 16-bit float training, which contains a fp16 optimizer and various fp16 utilities.\n", + "2. `memory_tracker.py`: This contains the functions to track memory usage.\n", "3. `learning_rates.py`: This contains the functions for learning rate schedule.\n", "4. `requirements.txt`: This will install the dependencies, like the right version of huggingface transformers.\n", "5. `preprocess.py`: This will download and preprocess the sst2/glue dataset.\n", diff --git a/training/distributed_training/pytorch/model_parallel/gpt-j/11_train_gptj_smp_tensor_parallel_notebook.ipynb b/training/distributed_training/pytorch/model_parallel/gpt-j/11_train_gptj_smp_tensor_parallel_notebook.ipynb index 9f4d8ddb7c..516d06bf7f 100644 --- a/training/distributed_training/pytorch/model_parallel/gpt-j/11_train_gptj_smp_tensor_parallel_notebook.ipynb +++ b/training/distributed_training/pytorch/model_parallel/gpt-j/11_train_gptj_smp_tensor_parallel_notebook.ipynb @@ -11,7 +11,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "This notebook walks you through how to use the tensor parallelism feature provided by the SageMaker model parallelism library. You'll learn how to train the GPT-J model with tensor parallelism on the GLUE sst2 dataset.\n", + "This notebook walks you through how to use the tensor parallelism feature provided by the SageMaker model parallelism library. You'll learn how to run FP16 training of the GPT-J model with tensor parallelism on the GLUE sst2 dataset.\n", "\n", "## Install and Upgrade Libraries\n", "\n", @@ -82,7 +82,7 @@ "import os\n", "\n", "from sagemaker import get_execution_role\n", - "from sagemaker.huggingface import HuggingFace\n", + "from sagemaker.pytorch import PyTorch\n", "from smexperiments.experiment import Experiment\n", "from smexperiments.trial import Trial\n", "import boto3\n", @@ -611,6 +611,7 @@ " \"activation_checkpointing\": 1,\n", " \"activation_strategy\": \"each\",\n", " \"optimize\": \"speed\",\n", + " \"zipped_data\": 0,\n", " # below flag loads model and optimizer state from checkpoint_s3_uri\n", " # 'load_partial': 1,\n", "}\n", @@ -809,7 +810,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Create a SageMaker HuggingFace 🤗 Estimator\n", + "### Create a SageMaker PyTorch Estimator\n", "\n", "The following cell constructs a PyTorch estimator using the parameters defined above. To see how the SageMaker tensor parallelism modules and functions are applied to the script, see the `train_gptj_smp_tensor_parallel_script.py` file and the private preview documentation. " ] @@ -826,7 +827,7 @@ " kwargs[\"security_group_ids\"] = [fsx_security_group_id]\n", " kwargs[\"subnets\"] = [fsx_subnet]\n", "\n", - "smp_estimator = HuggingFace(\n", + "smp_estimator = PyTorch(\n", " entry_point=\"train_gptj_smp_tensor_parallel_script.py\",\n", " source_dir=os.getcwd(),\n", " role=role,\n", @@ -851,18 +852,16 @@ " \"partitions\": hyperparameters[\"pipeline_parallel_degree\"],\n", " \"shard_optimizer_state\": hyperparameters[\"shard_optimizer_state\"] > 0,\n", " \"prescaled_batch\": hyperparameters[\"prescaled_batch\"] > 0,\n", - " \"fp16_params\": hyperparameters[\"fp16\"] > 0,\n", + " \"fp16\": hyperparameters[\"fp16\"] > 0,\n", " \"optimize\": hyperparameters[\"optimize\"],\n", " \"auto_partition\": False if hyperparameters[\"manual_partition\"] else True,\n", " \"default_partition\": 0,\n", - " \"fp16_params\": hyperparameters[\"fp16\"] > 0,\n", " \"optimize\": hyperparameters[\"optimize\"],\n", " },\n", " }\n", " },\n", " },\n", - " pytorch_version=\"1.10\",\n", - " transformers_version=\"4.17\",\n", + " framework_version=\"1.12\",\n", " py_version=\"py38\",\n", " output_path=s3_output_bucket,\n", " checkpoint_s3_uri=checkpoint_s3_uri if not use_fsx else None,\n", diff --git a/training/distributed_training/pytorch/model_parallel/gpt-j/fp16/__init__.py b/training/distributed_training/pytorch/model_parallel/gpt-j/fp16/__init__.py deleted file mode 100644 index 714cb2bb22..0000000000 --- a/training/distributed_training/pytorch/model_parallel/gpt-j/fp16/__init__.py +++ /dev/null @@ -1,30 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -# Modifications Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from .fp16util import ( - BN_convert_float, - network_to_half, - prep_param_lists, - model_grads_to_master_grads, - master_params_to_model_params, - tofp16, - to_python_float, - convert_module, - convert_network, - FP16Model, -) - -from .fp16 import * -from .loss_scaler import * diff --git a/training/distributed_training/pytorch/model_parallel/gpt-j/fp16/fp16.py b/training/distributed_training/pytorch/model_parallel/gpt-j/fp16/fp16.py deleted file mode 100755 index d49a351d6e..0000000000 --- a/training/distributed_training/pytorch/model_parallel/gpt-j/fp16/fp16.py +++ /dev/null @@ -1,1054 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -# Modifications Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Stable version of apex FP16 Optimizer""" -import copy - -import amp_C -import smdistributed.modelparallel.torch as smp -import torch -from apex.multi_tensor_apply import multi_tensor_applier -from smdistributed.modelparallel.torch.state_mod import state as smp_state -from smdistributed.modelparallel.torch.utils import get_distribution_axis -from torch import nn -from torch._six import inf -from torch.autograd import Variable -from torch.nn.parameter import Parameter - -from .fp16util import ( - get_pp_merged_fp32_from_fp16_param_groups, - get_tp_merged_fp32_from_fp16_param_groups, - master_params_to_model_params, - model_grads_to_master_grads, - model_params_to_master_params, - register_optimizer_hooks, -) -from .loss_scaler import DynamicLossScaler, LossScaler - -FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor) -HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor) - - -def load_fp16_optimizer_finetuning(model, optimizer, state_dict): - opt_state_dict = state_dict["optimizer"] - - def param_name_to_index(self): - param_id_to_index = self._param_id_to_index() - name_to_index = {} - for name, param in model.named_parameters(): - fp16_param_id = id(param) - if fp16_param_id in self.fp32paramid_from_fp16paramid: - param_id = self.fp32paramid_from_fp16paramid[fp16_param_id] - else: - param_id = fp16_param_id - if param_id in param_id_to_index: - name_to_index[name] = param_id_to_index[param_id] - return name_to_index - - def _param_index_to_param_local(self): - param_id_to_index = self._param_id_to_index() - param_index_to_param = {} - - if not model: - return param_index_to_param - - for param in model.local_parameters(): - fp16_param_id = id(param) - if fp16_param_id in self.fp32paramid_from_fp16paramid: - param_id = self.fp32paramid_from_fp16paramid[fp16_param_id] - else: - param_id = fp16_param_id - if param_id in param_id_to_index: - param_index_to_param[param_id_to_index[param_id]] = param - - return param_index_to_param - - def hook_fn(model, optimizer): - print(f"Inside hook_fn, loading for finetuning") - from functools import partial - - optimizer.param_name_to_index = partial(param_name_to_index, optimizer) - optimizer._param_index_to_param_local = partial(_param_index_to_param_local, optimizer) - optimizer.fp32_from_fp16 = opt_state_dict["fp32_from_fp16"] - - for current_group, saved_group in zip( - optimizer.fp32_from_fp16_groups, optimizer.fp32_from_fp16 - ): - for current, saved in zip(current_group, saved_group): - current.data.copy_(saved.data) - - model.register_post_partition_hook(hook_fn) - - -def _get_param_index_to_id(param_id_to_index_tp_group): - param_index_to_id_tp_group = [] - for param_id_to_index_map in param_id_to_index_tp_group: - param_index_to_id_map = {} - for param_id, param_index in param_id_to_index_map.items(): - param_index_to_id_map[param_index] = param_id - param_index_to_id_tp_group.append(param_index_to_id_map) - return param_index_to_id_tp_group - - -def save_fp16_optimizer(args, model, optimizer, partial=True): - optimizer_state_dict = {} - loss_scaler = optimizer.loss_scaler - _model = loss_scaler.model - loss_scaler.model = None - _loss_scaler = copy.deepcopy(loss_scaler) - loss_scaler.model = _model - optimizer_state_dict["loss_scaler"] = _loss_scaler - optimizer_state_dict["dynamic_loss_scale"] = optimizer.dynamic_loss_scale - optimizer_state_dict["overflow"] = optimizer.overflow - optimizer_state_dict["first_closure_call_this_step"] = optimizer.first_closure_call_this_step - cpu_fp32_from_fp16_groups = [ - [param.cpu() for param in group] for group in optimizer.fp32_from_fp16_groups - ] - if optimizer.master_params_created: - register_optimizer_hooks(model) - if partial: - optimizer_state_dict["optimizer_state_dict"] = optimizer.local_state_dict( - gather_if_shard=args.gather_if_shard > 0 - ) - if args.shard_optimizer_state and args.gather_if_shard > 0: - if smp.rdp_rank() == 0: - print( - "With shard_optimizer_state=True, gather full fp32_from_fp16_groups for the rdp_group on rdp rank 0" - ) - gathered_cpu_fp32_from_fp16_groups = [cpu_fp32_from_fp16_groups] - for src in range(1, smp.rdp_size()): - gathered_cpu_fp32_from_fp16_groups.append( - smp.recv_from(src, smp.RankType.RDP_RANK) - ) - optimizer_state_dict["fp32_from_fp16"] = gathered_cpu_fp32_from_fp16_groups - else: - smp.send(cpu_fp32_from_fp16_groups, 0, smp.RankType.RDP_RANK) - optimizer_state_dict["fp32_from_fp16"] = cpu_fp32_from_fp16_groups - else: - optimizer_state_dict["fp32_from_fp16"] = cpu_fp32_from_fp16_groups - if smp.pp_size() > 1: - print( - "WARNING: Ensure that partition decision doesnt change between runs (you can ensure this by setting use_times=False in smp config)." - "If you want to save and load with partition decision changing between runs, use full save and load instead." - ) - else: - optimizer_state_dict["optimizer_state_dict"] = optimizer.state_dict() - if smp.tp_size() > 1 and not args.shard_optimizer_state: - ( - tp_merged_fp32_from_fp16_groups, - param_name_groups, - ) = get_tp_merged_fp32_from_fp16_param_groups(optimizer, cpu_fp32_from_fp16_groups) - ( - pp_merged_fp32_from_fp16_groups, - param_name_groups, - ) = get_pp_merged_fp32_from_fp16_param_groups( - optimizer, tp_merged_fp32_from_fp16_groups, param_name_groups - ) - else: - raise ValueError( - "Loading full optimizer state is not supported, when TP is not enabled or shard_optimizer_state is enabled" - ) - optimizer_state_dict["fp32_from_fp16"] = pp_merged_fp32_from_fp16_groups - optimizer_state_dict["param_name_groups"] = param_name_groups - return optimizer_state_dict - - -def load_fp16_optimizer(args, model, optimizer, state_dict, partial=True): - opt_state_dict = state_dict["optimizer"] - - if optimizer.master_params_created: - register_optimizer_hooks(model) - - def hook_fn(model, optimizer): - optimizer.load_state_dict(opt_state_dict["optimizer_state_dict"]) - if partial: - if args.shard_optimizer_state and args.gather_if_shard > 0: - optimizer.fp32_from_fp16 = opt_state_dict["fp32_from_fp16"][smp.rdp_rank()] - else: - optimizer.fp32_from_fp16 = opt_state_dict["fp32_from_fp16"] - - for current_group, saved_group in zip( - optimizer.fp32_from_fp16_groups, optimizer.fp32_from_fp16 - ): - for current, saved in zip(current_group, saved_group): - current.data.copy_(saved.data) - - else: - optimizer.fp32_from_fp16 = opt_state_dict["fp32_from_fp16"] - param_name_groups = opt_state_dict["param_name_groups"] - param_id_to_index = optimizer._param_id_to_index() - param_index_to_name_tp_group = smp_state.param_index_to_name_tp_group - param_index_to_name = param_index_to_name_tp_group[smp.tp_rank()] - for group_idx, (current_group, saved_group) in enumerate( - zip(optimizer.fp32_from_fp16_groups, optimizer.fp32_from_fp16) - ): - for current in current_group: - param_id = id(current) - param_index = param_id_to_index[param_id] - param_name = param_index_to_name[param_index] - arr_index = param_name_groups[group_idx][param_name] - saved = saved_group[arr_index] - if optimizer.master_distribution_axis[param_id] is not None: - axis = optimizer.master_distribution_axis[param_id] - slice_size = saved.size(axis) // smp.tp_size() - saved = torch.narrow( - saved.data, axis, slice_size * smp.tp_rank(), slice_size - ).contiguous() - else: - saved = saved.data - current.data.copy_(saved) - - model.register_post_partition_hook(hook_fn) - - -def clip_grad_norm_fp32( - parameters, param_is_distributed, shard_optimizer_state, max_norm, norm_type=2 -): - """Clips gradient norm of an iterable of parameters whose gradients - are in fp32. - This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and - added functionality to handle model parallel parameters. Note that - the gradients are modified in place. - Arguments: - parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a - single Tensor that will have gradients normalized - max_norm (float or int): max norm of the gradients - norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for - infinity norm. - Returns: - Total norm of the parameters (viewed as a single vector). - """ - - if isinstance(parameters, torch.Tensor): - parameters = [parameters] - - # Filter parameters based on: - # - grad should not be none - # - parameter should not be shared - # - should not be a replica due to tensor model parallelism - torch.cuda.set_device(smp.local_rank()) - grads = [] - grads_for_norm = [] - for param in parameters: - grad_not_none = param.grad is not None - is_not_shared = not hasattr(param, "shared") or not param.shared - is_not_tp_duplicate = smp.tp_rank() == 0 or ( - param in param_is_distributed and param_is_distributed[param] - ) - if grad_not_none: - grad = param.grad.detach() - # Make sure the grads are in fp32 - assert param.grad.type() == "torch.cuda.FloatTensor" - grads.append(grad) - if is_not_shared and is_not_tp_duplicate: - grads_for_norm.append(grad) - - # Norm parameters. - max_norm = float(max_norm) - norm_type = float(norm_type) - total_norm = torch.tensor(0.0, device=torch.device("cuda")) - - # Calculate norm. - if norm_type == inf: - if len(grads_for_norm) > 0: - total_norm = max(grad.abs().max() for grad in grads_for_norm) - total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) - # Take max across all model-parallel GPUs. - # Reducing across all ranks since gradients may be different across data parallel ranks - # when optimizer state sharding is enabled. - group = ( - smp.get_world_process_group() if shard_optimizer_state else smp.get_mp_process_group() - ) - torch.distributed.all_reduce( - total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=group - ) - total_norm = total_norm_cuda[0].item() - - else: - if norm_type == 2.0: - dummy_overflow_buf = torch.cuda.IntTensor( - [0], device=torch.device("cuda", smp.local_rank()) - ) - # Use apex's multi-tensor applier for efficiency reasons. - # Multi-tensor applier takes a function and a list of list - # and performs the operation on that list all in one kernel. - if len(grads_for_norm) > 0: - grad_norm, _ = multi_tensor_applier( - amp_C.multi_tensor_l2norm, - dummy_overflow_buf, - [grads_for_norm], - False, # no per-parameter norm - ) - # Since we will be summing across data parallel groups, - # we need the pow(norm-type). - total_norm = grad_norm**norm_type - - else: - for grad in grads_for_norm: - grad_norm = torch.norm(grad, norm_type) - total_norm += grad_norm**norm_type - - # Sum across all model-parallel GPUs. - group = ( - smp.get_world_process_group() if shard_optimizer_state else smp.get_mp_process_group() - ) - torch.distributed.all_reduce(total_norm, op=torch.distributed.ReduceOp.SUM, group=group) - total_norm = total_norm.item() ** (1.0 / norm_type) - - # Scale. - if len(grads) > 0: - clip_coeff = max_norm / (total_norm + 1.0e-6) - if clip_coeff < 1.0: - dummy_overflow_buf = torch.cuda.IntTensor( - [0], device=torch.device("cuda", smp.local_rank()) - ) - multi_tensor_applier( - amp_C.multi_tensor_scale, dummy_overflow_buf, [grads, grads], clip_coeff - ) - - return total_norm - - -def conversion_helper(val, conversion): - """Apply conversion to val. Recursively apply conversion if `val` is a nested tuple/list structure.""" - if not isinstance(val, (tuple, list)): - return conversion(val) - rtn = [conversion_helper(v, conversion) for v in val] - if isinstance(val, tuple): - rtn = tuple(rtn) - return rtn - - -def fp32_to_fp16(val): - """Convert fp32 `val` to fp16""" - - def half_conversion(val): - val_typecheck = val - if isinstance(val_typecheck, (Parameter, Variable)): - val_typecheck = val.data - if isinstance(val_typecheck, FLOAT_TYPES): - val = val.half() - return val - - return conversion_helper(val, half_conversion) - - -def fp16_to_fp32(val): - """Convert fp16 `val` to fp32""" - - def float_conversion(val): - val_typecheck = val - if isinstance(val_typecheck, (Parameter, Variable)): - val_typecheck = val.data - if isinstance(val_typecheck, HALF_TYPES): - val = val.float() - return val - - return conversion_helper(val, float_conversion) - - -class FP16_Module(nn.Module): - def __init__(self, module): - super(FP16_Module, self).__init__() - self.add_module("module", module.half()) - - def forward(self, *inputs, **kwargs): - return fp16_to_fp32(self.module(*(fp32_to_fp16(inputs)), **kwargs)) - - def state_dict(self, destination=None, prefix="", keep_vars=False): - return self.module.state_dict(destination, prefix, keep_vars) - - def state_dict_for_save_checkpoint(self, destination=None, prefix="", keep_vars=False): - return self.module.state_dict_for_save_checkpoint(destination, prefix, keep_vars) - - def load_state_dict(self, state_dict, strict=True): - self.module.load_state_dict(state_dict, strict=strict) - - -class FP16_Optimizer(object): - """ - :class:`FP16_Optimizer` is designed to wrap an existing PyTorch optimizer, - and manage static or dynamic loss scaling and master weights in a manner transparent to the user. - For standard use, only two lines must be changed: creating the :class:`FP16_Optimizer` instance, - and changing the call to ``backward``. - - Example:: - - model = torch.nn.Linear(D_in, D_out).cuda().half() - optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) - # Name the FP16_Optimizer instance to replace the existing optimizer - # (recommended but not required): - optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) - ... - # loss.backward() becomes: - optimizer.backward(loss) - ... - - Example with dynamic loss scaling:: - - ... - optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True) - # optional arg to control dynamic loss scaling behavior - # dynamic_loss_args={'scale_window' : 500}) - # Usually, dynamic_loss_args is not necessary. - - Args: - init_optimizer (torch.optim.optimizer): Existing optimizer created with the parameters to optimize. Internally, :class:`FP16_Optimizer` replaces the passed optimizer's fp16 parameters, if any, with fp32 master parameters copied from the original ones. :class:`FP16_Optimizer` also stores references to the original fp16 parameters, and updates these fp16 parameters from the master fp32 copy at the end of each :attr:`step`. - static_loss_scale (float, optional, default=1.0): Loss scale used internally to scale gradients computed by the model. Any fp16 gradients will be copied to fp32, then downscaled before being applied to the fp32 master params, so ``static_loss_scale`` should not affect learning rate. - dynamic_loss_scale (bool, optional, default=False): Use dynamic loss scaling. If True, this will override any ``static_loss_scale`` option. - dynamic_loss_args (dict, optional, default=None): Dict of kwargs that will be forwarded to the internal :class:`DynamicLossScaler` instance's constructor. Keys of this dict must match kwargs accepted by :class:`DynamicLossScaler`'s constructor. If ``dynamic_loss_args`` is unspecified, :class:`DynamicLossScaler`'s defaults will be used. - verbose (bool, optional, default=True): By default, FP16_Optimizer's constructor prints out the parameters and parameter groups it is ingesting, as a sanity check. If this becomes annoying (e.g. for large models), it can be disabled by passing ``verbose=False``. ``verbose=False`` will not disable printing when the loss scale is readjusted during dynamic loss scaling. - - ``init_optimizer`` is expected to have been constructed in the ordinary way. - It is recommended (although not required) that the newly constructed :class:`FP16_Optimizer` instance be - named to replace ``init_optimizer``, for two reasons: - First, it means that references to the same name - later in the file will not have to change. - Second, :class:`FP16_Optimizer` reserves the right (as an implementation detail) to - modify ``init_optimizer``. If you do choose a unique name for the new - :class:`FP16_Optimizer` instance, you should only work with this new instance, - because the preexisting optimizer might no longer behave as expected. - - ``init_optimizer`` may be any Pytorch optimizer. - It may contain a mixture of fp16 and fp32 parameters organized into any number of - ``param_groups`` with different hyperparameters. The :class:`FP16_Optimizer` constructor will - ingest these ``param_groups`` and remember them. - - Calls to :: - - loss.backward() - - must be replaced with :: - - optimizer.backward(loss) - - because :class:`FP16_Optimizer` requires ownership of the backward pass to implement - loss scaling and copies to master gradients. - - .. note:: - Loss scaling, either static or dynamic, is orthogonal to learning rate, because gradients - are downscaled before being applied. This means that adjusting the loss scale, or using - dynamic loss scaling, should not require retuning the learning rate or any other - hyperparameters. - - - **Advanced options** - - **Closures**: :class:`FP16_Optimizer` can wrap a Pytorch optimizer that receives a closure. - See docstring for :attr:`step`. - - **Gradient clipping**: Use :attr:`clip_master_grads`. - - **Multiple losses**: If your model accumulates gradients from multiple losses, - this can be made more efficient by supplying ``update_master_grads=False`` - to :attr:`backward`. See docstring for :attr:`backward`. - - **Manually adjusting loss scale**: The current loss scale can be retrieved or set via :: - - print(optimizer.loss_scale) - optimizer.loss_scale = new_loss_scale - - For static loss scaling, manually adjusting the loss scale over time is a reasonable - thing to do. During later epochs, gradients may become smaller, and a - higher loss scale may be required, analogous to scheduling the learning rate. Dynamic loss - scaling is more subtle (see :class:`DynamicLossScaler`) and in this case, manually adjusting - the loss scale is not recommended. - - **Multi_GPU training**: If the wrapped ``init_optimizer`` was created from a model wrapped in - Pytorch DistributedDataParallel or Apex DistributedDataParallel, :class:`FP16_Optimizer` - should still work as intended. - """ - - def __init__( - self, - model, - init_optimizer, - static_loss_scale=1.0, - dynamic_loss_scale=False, - dynamic_loss_args=None, - use_smp=False, - verbose=False, - params_have_main_grad=False, - shard_optimizer_state=False, - ): - if not torch.cuda.is_available: - raise SystemError("Cannot use fp16 without CUDA.") - - self.verbose = verbose - self.model = model - - self.optimizer = init_optimizer - # init_state_dict sets up an alternative way to cast per-param state tensors. - # Stashing here in case https://github.com/pytorch/pytorch/issues/7733 makes it necessary. - # init_state_dict = init_optimizer.state_dict() - - self.fp16_groups = [] - self.fp32_from_fp16_groups = [] - self.fp32_from_fp32_groups = [] - self.fp32_from_fp16_paramid_groups = [] - self.static_loss_scale = static_loss_scale - self.dynamic_loss_scale = dynamic_loss_scale - self.dynamic_loss_args = dynamic_loss_args - self.use_smp = use_smp - self.master_params_created = False - self.shard_optimizer_state = shard_optimizer_state - self.warned_set_grads_to_none = False - if not self.use_smp: - self.init_master_params() - - self.master_is_distributed = {} - self.master_distribution_axis = {} - self.params_have_main_grad = params_have_main_grad - - if self.dynamic_loss_scale: - if self.dynamic_loss_args is not None: - self.dynamic_loss_args["use_smp"] = self.use_smp - self.loss_scaler = DynamicLossScaler( - self.model, self.shard_optimizer_state, **self.dynamic_loss_args - ) - else: - self.loss_scaler = DynamicLossScaler( - self.model, self.shard_optimizer_state, use_smp=self.use_smp - ) - else: - self.loss_scaler = LossScaler( - self.model, self.shard_optimizer_state, self.static_loss_scale, use_smp=self.use_smp - ) - - def init_master_params(self): - - if self.use_smp: - torch.cuda.set_device(smp.local_rank()) - register_optimizer_hooks(self.model) - self.fp32paramid_from_fp16paramid = {} - - # only need to create contiguous buffer for fp16 params which require grads - contig_buffer_size = 0 - for param_group in self.optimizer.param_groups: - for param in param_group["params"]: - if param.requires_grad and param.type() == "torch.cuda.HalfTensor": - contig_buffer_size += param.numel() - - self.fp32_param_buffer = torch.empty( - contig_buffer_size, - device=torch.device("cuda", smp.local_rank()), - dtype=torch.float32, - requires_grad=True, - ) - offset = 0 - for i, param_group in enumerate(self.optimizer.param_groups): - self.maybe_print("FP16_Optimizer processing param group {}:".format(i)) - fp16_params_this_group = [] - fp32_params_this_group = [] - fp32_from_fp16_params_this_group = [] - fp32_from_fp16_paramids_this_group = [] - for i, param in enumerate(param_group["params"]): - if param.requires_grad: - if param.type() == "torch.cuda.HalfTensor": - self.maybe_print( - "FP16_Optimizer received torch.cuda.HalfTensor with {}".format( - param.size() - ) - ) - fp16_params_this_group.append(param) - - with torch.no_grad(): - master_param_buffer = self.fp32_param_buffer.narrow( - 0, offset, param.numel() - ).view_as(param) - master_param_buffer.copy_(param.float()) - offset += param.numel() - - master_param = nn.Parameter( - master_param_buffer, requires_grad=param.requires_grad - ) - - self.master_is_distributed[ - master_param - ] = self.model.is_distributed_parameter(param) - self.master_distribution_axis[id(master_param)] = get_distribution_axis( - param - ) - param_group["params"][i] = master_param - fp32_from_fp16_params_this_group.append(master_param) - fp32_from_fp16_paramids_this_group.append(id(master_param)) - # Reset existing state dict key to the new master param. - # We still need to recast per-param state tensors, if any, to FP32. - if param in self.optimizer.state: - self.optimizer.state[master_param] = self.optimizer.state.pop(param) - self.fp32paramid_from_fp16paramid[id(param)] = id(master_param) - elif param.type() == "torch.cuda.FloatTensor": - self.maybe_print( - "FP16_Optimizer received torch.cuda.FloatTensor with {}".format( - param.size() - ) - ) - fp32_params_this_group.append(param) - param_group["params"][i] = param - else: - raise TypeError( - "Wrapped parameters must be either " - "torch.cuda.FloatTensor or torch.cuda.HalfTensor. " - "Received {}".format(param.type()) - ) - self.fp16_groups.append(fp16_params_this_group) - self.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group) - self.fp32_from_fp16_paramid_groups.append(fp32_from_fp16_paramids_this_group) - self.fp32_from_fp32_groups.append(fp32_params_this_group) - - # Leverage state_dict() and load_state_dict() to recast preexisting per-param state tensors - self.optimizer.load_state_dict(self.optimizer.state_dict()) - # alternative way to cast per-param state tensors: - # self.optimizer.load_state_dict(init_state_dict) - - self.overflow = False - self.first_closure_call_this_step = True - self.master_params_created = True - - def maybe_print(self, msg): - if self.verbose: - print(msg) - - def __getstate__(self): - raise RuntimeError("FP16_Optimizer should be serialized using state_dict().") - - def __setstate__(self, state): - raise RuntimeError("FP16_Optimizer should be deserialized using load_state_dict().") - - def zero_grad(self, set_grads_to_None=False): - """ - Zero fp32 and fp16 parameter grads. - """ - # In principle, only the .grad attributes of the model params need to be zeroed, - # because gradients are copied into the FP32 master params. However, we zero - # all gradients owned by the optimizer, just to be safe: - if self.shard_optimizer_state and set_grads_to_None and not self.warned_set_grads_to_none: - print( - "WARNING: Will not set fp16 gradients to None since shard_optimizer_state is enabled." - ) - self.warned_set_grads_to_none = True - - for group in self.optimizer.param_groups: - for p in group["params"]: - if set_grads_to_None: - p.grad = None - else: - if p.grad is not None: - if p.grad.grad_fn is not None: - p.grad.detach_() - else: - p.grad.requires_grad_(False) - p.grad.zero_() - - # Zero fp16 gradients owned by the model: - for fp16_group in self.fp16_groups: - for param in fp16_group: - # if shard_optimizer_state is true, do not set fp16 grads to None since - # it will be part of the contiguous buffer - if set_grads_to_None and not self.shard_optimizer_state: - param.grad = None - else: - if param.grad is not None: - if param.grad.grad_fn is not None: - param.grad.detach_() - else: - param.grad.requires_grad_(False) - param.grad.zero_() - - def _check_overflow(self): - params = [] - for group in self.fp16_groups: - for param in group: - params.append(param) - for group in self.fp32_from_fp32_groups: - for param in group: - params.append(param) - self.overflow = self.loss_scaler.has_overflow(params) - - def _update_scale(self, has_overflow=False): - self.loss_scaler.update_scale(has_overflow) - - def _master_params_to_model_params(self): - for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups): - master_params_to_model_params(fp16_group, fp32_from_fp16_group) - - def _model_params_to_master_params(self): - for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups): - model_params_to_master_params(fp16_group, fp32_from_fp16_group) - - # To consider: Integrate distributed with this wrapper by registering a hook on each variable - # that does the overflow check, gradient copy + downscale, and fp32 - # allreduce in a different stream. - def _model_grads_to_master_grads(self, loss_scale=1.0): - for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups): - model_grads_to_master_grads( - fp16_group, - fp32_from_fp16_group, - loss_scale=loss_scale, - params_have_main_grad=self.params_have_main_grad, - ) - - def _downscale_master(self): - if self.loss_scale != 1.0: - for group in self.optimizer.param_groups: - grads = [p.grad for p in group["params"] if p.grad is not None] - _overflow_buf = torch.cuda.IntTensor([0]) - multi_tensor_applier( - amp_C.multi_tensor_scale, _overflow_buf, [grads, grads], 1.0 / self.loss_scale - ) - - def clip_master_grads(self, max_norm, norm_type=2): - """ - Clips fp32 master gradients via ``torch.nn.utils.clip_grad_norm``. - - Args: - max_norm (float or int): max norm of the gradients - norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for - infinity norm. - - Returns: - Total norm of the current fp32 gradients (viewed as a single vector). - - .. warning:: - Returns -1 if the most recently computed fp16 gradients overflowed (that is, if ``self.overflow`` is ``True``). - """ - if not self.overflow: - fp32_params = [] - for param_group in self.optimizer.param_groups: - for param in param_group["params"]: - fp32_params.append(param) - return clip_grad_norm_fp32( - fp32_params, - self.master_is_distributed, - self.shard_optimizer_state, - max_norm, - norm_type, - ) - else: - return -1 - - def state_dict(self): - """ - Returns a dict containing the current state of this :class:`FP16_Optimizer` instance. - This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict - of the contained Pytorch optimizer. - Example:: - - checkpoint = {} - checkpoint['model'] = model.state_dict() - checkpoint['optimizer'] = optimizer.state_dict() - torch.save(checkpoint, "saved.pth") - """ - if not self.use_smp: - state_dict = {} - state_dict["loss_scaler"] = self.loss_scaler - state_dict["dynamic_loss_scale"] = self.dynamic_loss_scale - state_dict["overflow"] = self.overflow - state_dict["first_closure_call_this_step"] = self.first_closure_call_this_step - state_dict["optimizer_state_dict"] = self.optimizer.state_dict() - state_dict["fp32_from_fp16"] = self.fp32_from_fp16_groups - return state_dict - else: - return self.optimizer.state_dict() - - def load_state_dict(self, state_dict): - """ - Loads a state_dict created by an earlier call to state_dict(). - If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``, - whose parameters in turn came from ``model``, it is expected that the user - will call ``model.load_state_dict()`` before - ``fp16_optimizer_instance.load_state_dict()`` is called. - - Example:: - - model = torch.nn.Linear(D_in, D_out).cuda().half() - optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) - optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) - ... - checkpoint = torch.load("saved.pth") - model.load_state_dict(checkpoint['model']) - optimizer.load_state_dict(checkpoint['optimizer']) - """ - if not self.use_smp: - # I think it should actually be ok to reload the optimizer before the model. - self.loss_scaler = state_dict["loss_scaler"] - self.dynamic_loss_scale = state_dict["dynamic_loss_scale"] - self.overflow = state_dict["overflow"] - self.first_closure_call_this_step = state_dict["first_closure_call_this_step"] - self.optimizer.load_state_dict(state_dict["optimizer_state_dict"]) - # At this point, the optimizer's references to the model's fp32 parameters are up to date. - # The optimizer's hyperparameters and internal buffers are also up to date. - # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still - # out of date. There are two options. - # 1: Refresh the master params from the model's fp16 params. - # This requires less storage but incurs precision loss. - # 2: Save and restore the fp32 master copies separately. - # We choose option 2. - # - # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device - # of their associated parameters, because it's possible those buffers might not exist yet in - # the current optimizer instance. In our case, as long as the current FP16_Optimizer has been - # constructed in the same way as the one whose state_dict we are loading, the same master params - # are guaranteed to exist, so we can just copy_() from the saved master params. - for current_group, saved_group in zip( - self.fp32_from_fp16_groups, state_dict["fp32_from_fp16"] - ): - for current, saved in zip(current_group, saved_group): - current.data.copy_(saved.data) - else: - self.optimizer.load_state_dict(state_dict) - - def reload_model_params(self): - self._model_params_to_master_params() - - def step(self, closure=None): # could add clip option. - """ - If no closure is supplied, :attr:`step` should be called after - ``fp16_optimizer_obj.backward(loss)``. - :attr:`step` updates the fp32 master copy of parameters using the optimizer supplied to - :class:`FP16_Optimizer`'s constructor, then copies the updated fp32 params into the fp16 params - originally referenced by :class:`FP16_Optimizer`'s constructor, so the user may immediately run - another forward pass using their model. - - If a closure is supplied, :attr:`step` may be called without a prior call to - :attr:`backward(loss)`. - This control flow is identical to `ordinary Pytorch optimizer use`_ with closures. - However, the user should take care that any ``loss.backward()`` call within the closure - has been replaced by ``fp16_optimizer_obj.backward(loss)``. - - Args: - closure (optional): Closure that will be supplied to the underlying optimizer originally passed to :class:`FP16_Optimizer`'s constructor. closure should call :attr:`zero_grad()` on the :class:`FP16_Optimizer` object, compute the loss, call :attr:`backward(loss)`, and return the loss. - - Example with closure:: - - # optimizer is assumed to be an FP16_Optimizer object, previously constructed from an - # existing pytorch optimizer. - for input, target in dataset: - def closure(): - optimizer.zero_grad() - output = model(input) - loss = loss_fn(output, target) - # loss.backward() becomes: - optimizer.backward(loss) - return loss - optimizer.step(closure) - - .. warning:: - Currently, calling :attr:`step` with a closure is not compatible with dynamic loss scaling. - - .. _`ordinary Pytorch optimizer use`: - http://pytorch.org/docs/master/optim.html#optimizer-step-closure - """ - - scale = self.loss_scaler.loss_scale - self._update_scale(self.overflow) - - if self.overflow: - self.maybe_print( - "OVERFLOW! Skipping step. Attempted loss scale: {}, reducing to {}".format( - scale, self.loss_scale - ) - ) - return - - if closure is not None: - retval = self._step_with_closure(closure) - else: - retval = self.optimizer.step() - - self._master_params_to_model_params() - - return retval - - def _step_with_closure(self, closure): - def wrapped_closure(): - # helpful for debugging - # print("Calling wrapped_closure, first_closure_call_this_step = {}" - # .format(self.first_closure_call_this_step)) - if self.first_closure_call_this_step: - # We expect that the fp16 params are initially fresh on entering self.step(), - # so _master_params_to_model_params() is unnecessary the first time wrapped_closure() - # is called within self.optimizer.step(). - self.first_closure_call_this_step = False - else: - # If self.optimizer.step() internally calls wrapped_closure more than once, - # it may update the fp32 params after each call. However, self.optimizer - # doesn't know about the fp16 params at all. If the fp32 params get updated, - # we can't rely on self.optimizer to refresh the fp16 params. We need - # to handle that manually: - self._master_params_to_model_params() - # Our API expects the user to give us ownership of the backward() call by - # replacing all calls to loss.backward() with optimizer.backward(loss). - # This requirement holds whether or not the call to backward() is made within a closure. - # If the user is properly calling optimizer.backward(loss) within "closure," - # calling closure() here will give the fp32 master params fresh gradients - # for the optimizer to play with, so all wrapped_closure needs to do is call - # closure() and return the loss. - temp_loss = closure() - while self.overflow: - scale = self.loss_scaler.loss_scale - self._update_scale(self.overflow) - self.maybe_print( - "OVERFLOW within closure! Skipping step. Attempted loss scale: {}, " - "reducing to {}".format(scale, self.loss_scale) - ) - temp_loss = closure() - return temp_loss - - retval = self.optimizer.step(wrapped_closure) - - self.first_closure_call_this_step = True - - return retval - - def backward(self, loss, update_master_grads=True, retain_graph=False): - """ - :attr:`backward` performs the following conceptual steps: - - 1. fp32_loss = loss.float() (see first Note below) - 2. scaled_loss = fp32_loss*loss_scale - 3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's leaves (which may be fp16, fp32, or a mixture, depending how your model was defined). - 4. fp16 grads are then copied to the master params' ``.grad`` attributes (see second Note), which are guaranteed to be fp32. - 5. Finally, master grads are divided by loss_scale. - - In this way, after :attr:`backward`, the master params have fresh gradients, - and :attr:`step` may be called. - - .. note:: - :attr:`backward` internally converts the loss to fp32 before applying the loss scale. - This provides some additional safety against overflow if the user has supplied an - fp16 loss value. - However, for maximum overflow safety, the user should - compute the loss criterion (MSE, cross entropy, etc) in fp32 before supplying it to - :attr:`backward`. - - .. warning:: - The gradients found in a model's leaves after the call to - :attr:`backward` should not be regarded as valid in general, - because it's possible - they have been scaled (and in the case of dynamic loss scaling, - the scale factor may change over time). - If the user wants to inspect gradients after a call to :attr:`backward`, - only the master gradients should be regarded as valid. These can be retrieved via - :attr:`inspect_master_grad_data()`. - - Args: - loss: The loss output by the user's model. loss may be either float or half (but see first Note above). - update_master_grads (bool, optional, default=True): Option to copy fp16 grads to fp32 grads on this call. By setting this to False, the user can delay the copy, which is useful to eliminate redundant fp16->fp32 grad copies if :attr:`backward` is being called on multiple losses in one iteration. If set to False, the user becomes responsible for calling :attr:`update_master_grads` before calling :attr:`step`. - retain_graph (bool, optional, default=False): Forwards the usual ``retain_graph=True`` option to the internal call to ``loss.backward``. If ``retain_graph`` is being used to accumulate gradient values from multiple backward passes before calling ``optimizer.step``, passing ``update_master_grads=False`` is also recommended (see Example below). - - Example:: - - # Ordinary operation: - optimizer.backward(loss) - - # Naive operation with multiple losses (technically valid, but less efficient): - # fp32 grads will be correct after the second call, but - # the first call incurs an unnecessary fp16->fp32 grad copy. - optimizer.backward(loss1) - optimizer.backward(loss2) - - # More efficient way to handle multiple losses: - # The fp16->fp32 grad copy is delayed until fp16 grads from all - # losses have been accumulated. - optimizer.backward(loss1, update_master_grads=False) - optimizer.backward(loss2, update_master_grads=False) - optimizer.update_master_grads() - """ - # To consider: try multiple backward passes using retain_grad=True to find - # a loss scale that works. After you find a loss scale that works, do a final dummy - # backward pass with retain_graph=False to tear down the graph. Doing this would avoid - # discarding the iteration, but probably wouldn't improve overall efficiency. - self.loss_scaler.backward(loss.float(), retain_graph=retain_graph) - if update_master_grads: - self.update_master_grads() - - def update_master_grads(self): - """ - Copy the ``.grad`` attribute from stored references to fp16 parameters to - the ``.grad`` attribute of the fp32 master parameters that are directly - updated by the optimizer. :attr:`update_master_grads` only needs to be called if - ``fp16_optimizer_obj.backward`` was called with ``update_master_grads=False``. - """ - if self.dynamic_loss_scale: - self._check_overflow() - if self.overflow: - return - self._model_grads_to_master_grads(self.loss_scale) - # self._downscale_master() - - def inspect_master_grad_data(self): - """ - When running with :class:`FP16_Optimizer`, - ``.grad`` attributes of a model's fp16 leaves should not be - regarded as truthful, because they might be scaled. - After a call to :attr:`fp16_optimizer_obj.backward(loss)`, if no overflow was encountered, - the fp32 master params' ``.grad`` - attributes will contain valid gradients properly divided by the loss scale. However, - because :class:`FP16_Optimizer` flattens some parameters, accessing them may be - nonintuitive. :attr:`inspect_master_grad_data` - allows those gradients to be viewed with shapes corresponding to their associated model leaves. - - Returns: - List of lists (one list for each parameter group). The list for each parameter group - is a list of the ``.grad.data`` attributes of the fp32 master params belonging to that group. - """ - if self.overflow: - print( - "Warning: calling FP16_Optimizer.inspect_master_grad_data while in an overflow state. " - "Gradients are currently invalid (may be inf, nan, or stale). Returning None." - ) - return None - else: - # The optimizer owns only references to master params. - master_grads_data = [] - for param_group in self.optimizer.param_groups: - master_grads_this_group = [] - for param in param_group["params"]: - if param.grad is not None: - master_grads_this_group.append(param.grad.data) - else: - master_grads_this_group.append(None) - master_grads_data.append(master_grads_this_group) - return master_grads_data - - # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale" - - def _get_loss_scale(self): - return self.loss_scaler.loss_scale - - def _set_loss_scale(self, value): - self.loss_scaler.cur_scale = value - - loss_scale = property(_get_loss_scale, _set_loss_scale) - - # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state" - def _get_state(self): - return self.optimizer.state - - def _set_state(self, value): - self.optimizer.state = value - - state = property(_get_state, _set_state) - - # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups" - # (for example, to adjust the learning rate) - def _get_param_groups(self): - return self.optimizer.param_groups - - def _set_param_groups(self, value): - self.optimizer.param_groups = value - - param_groups = property(_get_param_groups, _set_param_groups) diff --git a/training/distributed_training/pytorch/model_parallel/gpt-j/fp16/fp16util.py b/training/distributed_training/pytorch/model_parallel/gpt-j/fp16/fp16util.py deleted file mode 100644 index 39975c4b86..0000000000 --- a/training/distributed_training/pytorch/model_parallel/gpt-j/fp16/fp16util.py +++ /dev/null @@ -1,412 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -# Modifications Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.nn as nn -from torch.autograd import Variable -from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors - -from apex.multi_tensor_apply import multi_tensor_applier -import amp_C -import smdistributed.modelparallel.torch as smp -from smdistributed.modelparallel.torch.state_mod import state as smp_state - - -class tofp16(nn.Module): - """ - Utility module that implements:: - - def forward(self, input): - return input.half() - """ - - def __init__(self): - super(tofp16, self).__init__() - - def forward(self, input): - return input.half() - - -def BN_convert_float(module): - """ - Utility function for network_to_half(). - - Retained for legacy purposes. - """ - if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.affine is True: - module.float() - for child in module.children(): - BN_convert_float(child) - return module - - -def network_to_half(network): - """ - Convert model to half precision in a batchnorm-safe way. - - Retained for legacy purposes. It is recommended to use FP16Model. - """ - return nn.Sequential(tofp16(), BN_convert_float(network.half())) - - -def convert_module(module, dtype): - """ - Converts a module's immediate parameters and buffers to dtype. - """ - for param in module.parameters(recurse=False): - if param is not None: - if param.data.dtype.is_floating_point: - param.data = param.data.to(dtype=dtype) - if param._grad is not None and param._grad.data.dtype.is_floating_point: - param._grad.data = param._grad.data.to(dtype=dtype) - - for buf in module.buffers(recurse=False): - if buf is not None and buf.data.dtype.is_floating_point: - buf.data = buf.data.to(dtype=dtype) - - -def convert_network(network, dtype): - """ - Converts a network's parameters and buffers to dtype. - """ - for module in network.modules(): - if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.affine is True: - continue - convert_module(module, dtype) - return network - - -class FP16Model(nn.Module): - """ - Convert model to half precision in a batchnorm-safe way. - """ - - def __init__(self, network): - super(FP16Model, self).__init__() - self.network = convert_network(network, dtype=torch.half) - - def forward(self, *inputs): - inputs = tuple(t.half() for t in inputs) - return self.network(*inputs) - - -def backwards_debug_hook(grad): - raise RuntimeError("master_params recieved a gradient in the backward pass!") - - -def prep_param_lists(model, flat_master=False): - """ - Creates a list of FP32 master parameters for a given model, as in - `Training Neural Networks with Mixed Precision: Real Examples`_. - - Args: - model (torch.nn.Module): Existing Pytorch model - flat_master (bool, optional, default=False): Flatten the master parameters into a single tensor, as a performance optimization. - Returns: - A tuple (``model_params``, ``master_params``). ``model_params`` is a list of the model's parameters for later use with :func:`model_grads_to_master_grads` and :func:`master_params_to_model_params`. ``master_params`` is a list of FP32 master gradients. If ``flat_master=True``, ``master_params`` will be a list with one element. - - Example:: - - model_params, master_params = prep_param_lists(model) - - .. warning:: - Currently, if ``flat_master=True``, all the model's parameters must be the same type. If the model has parameters of different types, use ``flat_master=False``, or use :class:`FP16_Optimizer`. - - .. _`Training Neural Networks with Mixed Precision: Real Examples`: - http://on-demand.gputechconf.com/gtc/2018/video/S81012/ - """ - model_params = [param for param in model.parameters() if param.requires_grad] - - if flat_master: - # Give the user some more useful error messages - try: - # flatten_dense_tensors returns a contiguous flat array. - # http://pytorch.org/docs/master/_modules/torch/_utils.html - master_params = _flatten_dense_tensors([param.data for param in model_params]).float() - except BaseException: - print( - "Error in prep_param_lists: model may contain a mixture of parameters " - "of different types. Use flat_master=False, or use F16_Optimizer." - ) - raise - master_params = torch.nn.Parameter(master_params) - master_params.requires_grad = True - # master_params.register_hook(backwards_debug_hook) - if master_params.grad is None: - master_params.grad = master_params.new(*master_params.size()) - return model_params, [master_params] - else: - master_params = [param.clone().float().detach() for param in model_params] - for param in master_params: - param.requires_grad = True - return model_params, master_params - - -def model_grads_to_master_grads( - model_params, master_params, flat_master=False, loss_scale=1.0, params_have_main_grad=False -): - """ - Copy model gradients to master gradients. - - Args: - model_params: List of model parameters created by :func:`prep_param_lists`. - master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`model_grads_to_master_grads`. - """ - if flat_master: - # The flattening may incur one more deep copy than is necessary. - master_params[0].grad.data.copy_( - _flatten_dense_tensors([p.grad.data for p in model_params]) - ) - else: - for model, master in zip(model_params, master_params): - if model.device.type == "cpu": - continue - if model.grad is not None: - if master.grad is None: - if params_have_main_grad: - # If gradient_as_bucket_view is False, this will be a copy - master.grad = model.grad.float() - else: - master.grad = Variable(master.data.new(*master.data.size())) - else: - master.grad = None - model_grads = [p.grad for p in model_params if p.grad is not None] - master_grads = [p.grad for p in master_params if p.grad is not None] - if len(model_grads) == 0 or len(master_grads) == 0: - return - _overflow_buf = torch.cuda.IntTensor([0]) - multi_tensor_applier( - amp_C.multi_tensor_scale, _overflow_buf, [model_grads, master_grads], 1.0 / loss_scale - ) - - -def master_params_to_model_params(model_params, master_params, flat_master=False): - """ - Copy master parameters to model parameters. - - Args: - model_params: List of model parameters created by :func:`prep_param_lists`. - master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`master_params_to_model_params`. - """ - if flat_master: - for model, master in zip( - model_params, _unflatten_dense_tensors(master_params[0].data, model_params) - ): - model.data.copy_(master) - else: - for model, master in zip(model_params, master_params): - if model.device.type == "cpu": - continue - model.data.copy_(master.data) - - -def model_params_to_master_params(model_params, master_params, flat_master=False): - """ - Copy model params to master params - """ - if flat_master: - raise ValueError("Not supported") - else: - for model, master in zip(model_params, master_params): - if model.device.type == "cpu": - continue - master.data.copy_(model.data) - - -# Backward compatibility fixes - - -def to_python_float(t): - if hasattr(t, "item"): - return t.item() - else: - return t[0] - - -TORCH_MAJOR = int(torch.__version__.split(".")[0]) -TORCH_MINOR = int(torch.__version__.split(".")[1]) - - -def get_tp_merged_fp32_from_fp16_param_groups(optimizer, cpu_fp32_from_fp16_groups): - def _merge_param_group_tp_group(group_idx, param_group): - result_fp32_from_fp16_param_group = [] - param_name_group = {} - for i, param in enumerate(param_group): - # for each param, obtain param_name from param using two dicts above for tp_rank 0 - param_index = param_id_to_index_tp_group[rank_0][ - fp32_from_fp16_paramid_groups_tp_group[rank_0][group_idx][i] - ] - param_name = param_index_to_name_tp_group[rank_0][param_index] - # obtain distribution axis for the param and check if its distributed - # axis = master_distribution_axis_tp_rank_0[fp32_from_fp16_paramid_groups_tp_group[rank_0][group_idx][i]] - axis = master_distribution_axis_tp_rank_0.get( - fp32_from_fp16_paramid_groups_tp_group[rank_0][group_idx][i], None - ) - if axis is not None: - tensors = [] - for r in range(smp.tp_size()): - # if distributed, for each rank, obtain param id from index using above two dicts - param_index_r = param_name_to_index_tp_group[r][param_name] - param_id_r = param_index_to_id_tp_group[r][param_index_r] - - # search param id in fp32_from_fp16_groups_param_ids and find the index. - group_param_idx = fp32_from_fp16_paramid_groups_tp_group[r][group_idx].index( - param_id_r - ) - # use the param corresponding to the index from fp32_from_fp16_groups for concatenation along axis - tensors.append( - fp32_from_fp16_param_groups_tp_group[r][group_idx][group_param_idx] - ) - result_fp32_from_fp16_param_group.append(torch.cat(tensors, axis)) - else: - # if not distributed set tp_rank 0 param as the param - result_fp32_from_fp16_param_group.append(param) - param_name_group[param_name] = i - return result_fp32_from_fp16_param_group, param_name_group - - # get param_index_to_name all and param_name_to_index_all - param_index_to_name_tp_group = smp_state.param_index_to_name_tp_group - param_name_to_index_tp_group = smp_state.param_name_to_index_tp_group - # get mapping of param_id_to_index_all and param_index_to_id_all - param_id_to_index = optimizer._param_id_to_index() - param_id_to_index_tp_group = smp.allgather(param_id_to_index, smp.TP_GROUP) - param_index_to_id_tp_group = _get_param_index_to_id(param_id_to_index_tp_group) - # allgather all param ids and all params for fp32_from_fp16_groups - fp32_from_fp16_paramid_groups = optimizer.fp32_from_fp16_paramid_groups - fp32_from_fp16_paramid_groups_tp_group = smp.allgather( - fp32_from_fp16_paramid_groups, smp.TP_GROUP - ) - fp32_from_fp16_param_groups_tp_group = smp.allgather(cpu_fp32_from_fp16_groups, smp.TP_GROUP) - # broadcast distribution axis from tp_rank 0 to all tp_ranks - master_distribution_axis_tp_rank_0 = None - if smp.tp_rank() == 0: - master_distribution_axis_tp_rank_0 = optimizer.master_distribution_axis - smp.broadcast(master_distribution_axis_tp_rank_0, smp.TP_GROUP) - else: - master_distribution_axis_tp_rank_0 = smp.recv_from(0, smp.RankType.TP_RANK) - - result_fp32_from_fp16_param_groups = [] - param_name_groups = [] - rank_0 = 0 - # iterate through all the params for tp_group_fp32_from_fp16_groups[rank_0] - for group_idx, param_group in enumerate(fp32_from_fp16_param_groups_tp_group[rank_0]): - result_fp32_from_fp16_param_group, param_name_group = _merge_param_group_tp_group( - group_idx, param_group - ) - result_fp32_from_fp16_param_groups.append(result_fp32_from_fp16_param_group) - param_name_groups.append(param_name_group) - return result_fp32_from_fp16_param_groups, param_name_groups - - -def get_pp_merged_fp32_from_fp16_param_groups( - optimizer, fp32_from_fp16_groups, param_name_groups=None -): - pp_group_fp32_from_fp16_groups = smp.allgather(fp32_from_fp16_groups, smp.PP_GROUP) - if param_name_groups is not None: - index_to_param_name_groups = [] - # obtain index_to_param_name mapping across tp_group - for param_name_group in param_name_groups: - index_to_param_name = {} - for param_name, index in param_name_group.items(): - index_to_param_name[index] = param_name - index_to_param_name_groups.append(index_to_param_name) - # allgather the index_to_param_name_groups across the pp_group - pp_index_to_param_name_groups = smp.allgather(index_to_param_name_groups, smp.PP_GROUP) - else: - raise ValueError("Merging is not supported when param_name_groups is None") - - pp_merged_fp32_from_fp16_groups = [] - result_param_groups = [] - - # iterate through all the groups for rank 0 - for group_idx in range(len(pp_group_fp32_from_fp16_groups[0])): - merged = [] - start_idx = 0 - result_param_group = {} - # for each group iterate through all ranks and merge the param groups across pp_ranks - for rank, group in enumerate(pp_group_fp32_from_fp16_groups): - cur_g = group[group_idx] - start_idx += len(merged) - for i, _ in enumerate(cur_g): - param_name = pp_index_to_param_name_groups[rank][group_idx][i] - if param_name in result_param_group: - raise ValueError( - "same param_name present in the param_groups of different pipeline parallel partitions" - ) - result_param_group[param_name] = i + start_idx - merged.extend(cur_g) - pp_merged_fp32_from_fp16_groups.append(merged) - result_param_groups.append(result_param_group) - return pp_merged_fp32_from_fp16_groups, result_param_groups - - -def _get_param_index_to_id(param_id_to_index_tp_group): - param_index_to_id_tp_group = [] - for param_id_to_index_map in param_id_to_index_tp_group: - param_index_to_id_map = {} - for param_id, param_index in param_id_to_index_map.items(): - param_index_to_id_map[param_index] = param_id - param_index_to_id_tp_group.append(param_index_to_id_map) - return param_index_to_id_tp_group - - -def register_optimizer_hooks(model): - def param_name_to_index(self): - param_id_to_index = self._param_id_to_index() - name_to_index = {} - if self.redefined_params: - param_gen = model.virtual_named_parameters() - else: - param_gen = model.named_parameters() - for name, param in param_gen: - fp16_param_id = id(param) - if fp16_param_id in self.fp32paramid_from_fp16paramid: - param_id = self.fp32paramid_from_fp16paramid[fp16_param_id] - else: - param_id = fp16_param_id - if param_id in param_id_to_index: - name_to_index[name] = param_id_to_index[param_id] - return name_to_index - - def _param_index_to_param_local(self): - param_id_to_index = self._param_id_to_index() - param_index_to_param = {} - - if not model: - return param_index_to_param - - if self.redefined_params: - param_gen = model.virtual_named_parameters() - else: - param_gen = model.named_parameters() - for name, param in param_gen: - fp16_param_id = id(param) - if fp16_param_id in self.fp32paramid_from_fp16paramid: - param_id = self.fp32paramid_from_fp16paramid[fp16_param_id] - else: - param_id = fp16_param_id - if param_id in param_id_to_index: - param_index_to_param[param_id_to_index[param_id]] = param - - return param_index_to_param - - def hook_fn(model, optimizer): - from functools import partial - - optimizer.param_name_to_index = partial(param_name_to_index, optimizer) - optimizer._param_index_to_param_local = partial(_param_index_to_param_local, optimizer) - - model.register_post_partition_hook(hook_fn) diff --git a/training/distributed_training/pytorch/model_parallel/gpt-j/fp16/loss_scaler.py b/training/distributed_training/pytorch/model_parallel/gpt-j/fp16/loss_scaler.py deleted file mode 100755 index be579d27cd..0000000000 --- a/training/distributed_training/pytorch/model_parallel/gpt-j/fp16/loss_scaler.py +++ /dev/null @@ -1,274 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -# Modifications Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch - -from apex.multi_tensor_apply import multi_tensor_applier -import amp_C - -import smdistributed.modelparallel.torch as smp - - -def to_python_float(t): - if hasattr(t, "item"): - return t.item() - else: - return t[0] - - -class LossScaler: - """ - Class that manages a static loss scale. This class is intended to interact with - :class:`FP16_Optimizer`, and should not be directly manipulated by the user. - - Use of :class:`LossScaler` is enabled via the ``static_loss_scale`` argument to - :class:`FP16_Optimizer`'s constructor. - - Args: - scale (float, optional, default=1.0): The loss scale. - """ - - def __init__(self, model, shard_optimizer_state, scale=1, use_smp=False): - self.cur_scale = scale - self.model = model - self.use_smp = use_smp - self.shard_optimizer_state = shard_optimizer_state - - # `params` is a list / generator of torch.Variable - def has_overflow(self, params): - return False - - # `x` is a torch.Tensor - def _has_inf_or_nan(x): - return False - - def update_scale(self, overflow): - pass - - @property - def loss_scale(self): - return self.cur_scale - - def scale_gradient(self, module, grad_in, grad_out): - _overflow_buf = torch.cuda.IntTensor([0]) - - multi_tensor_applier( - amp_C.multi_tensor_scale, _overflow_buf, [grad_in, grad_in], self.loss_scale - ) - return grad_in - - def backward(self, loss, retain_graph=False): - scaled_loss = loss * self.loss_scale - if self.use_smp: - self.model.backward(scaled_loss) - else: - scaled_loss.backward() - - -class DynamicLossScaler: - """ - Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler` - indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of - :class:`FP16_Optimizer`. However, it's important to understand how :class:`DynamicLossScaler` - operates, because the default options can be changed using the - the ``dynamic_loss_args`` argument to :class:`FP16_Optimizer`'s constructor. - - Loss scaling is designed to combat the problem of underflowing gradients encountered at long - times when training fp16 networks. Dynamic loss scaling begins by attempting a very high loss - scale. Ironically, this may result in OVERflowing gradients. If overflowing gradients are - encountered, :class:`DynamicLossScaler` informs :class:`FP16_Optimizer` that an overflow has - occurred. - :class:`FP16_Optimizer` then skips the update step for this particular iteration/minibatch, - and :class:`DynamicLossScaler` adjusts the loss scale to a lower value. - If a certain number of iterations occur without overflowing gradients detected, - :class:`DynamicLossScaler` increases the loss scale once more. - In this way :class:`DynamicLossScaler` attempts to "ride the edge" of - always using the highest loss scale possible without incurring overflow. - - Args: - init_scale (float, optional, default=2**32): Initial loss scale attempted by :class:`DynamicLossScaler.` - scale_factor (float, optional, default=2.0): Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to loss scale/``scale_factor``. If ``scale_window`` consecutive iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``. - scale_window (int, optional, default=1000): Number of consecutive iterations without an overflow to wait before increasing the loss scale. - """ - - def __init__( - self, - model, - shard_optimizer_state, - init_scale=2**32, - scale_factor=2.0, - scale_window=1000, - min_scale=1, - delayed_shift=1, - consecutive_hysteresis=False, - use_smp=False, - ): - self.model = model - self.shard_optimizer_state = shard_optimizer_state - self.cur_scale = init_scale - self.cur_iter = 0 - self.last_overflow_iter = -1 - self.scale_factor = scale_factor - self.scale_window = scale_window - self.min_scale = min_scale - self.delayed_shift = delayed_shift - self.cur_hysteresis = delayed_shift - self.consecutive_hysteresis = consecutive_hysteresis - self.use_smp = use_smp - - # `params` is a list / generator of torch.Variable - def has_overflow_serial(self, params): - for p in params: - if p.grad is not None and DynamicLossScaler._has_inf_or_nan(p.grad.data): - return True - - return False - - def has_overflow(self, params): - overflow = self.has_overflow_serial(params) - # Since each model parallel GPU carries only part of the model, - # make sure overflow flag is synced across all the model parallel GPUs - overflow_gpu = torch.cuda.ByteTensor([overflow]) - group = ( - smp.get_world_process_group() - if self.shard_optimizer_state - else smp.get_mp_process_group() - ) - torch.distributed.all_reduce(overflow_gpu, op=torch.distributed.ReduceOp.MAX, group=group) - overflow = overflow_gpu[0].item() - return bool(overflow) - - # `x` is a torch.Tensor - - def _has_inf_or_nan(x): - try: - # if x is half, the .float() incurs an additional deep copy, but it's necessary if - # Pytorch's .sum() creates a one-element tensor of the same type as x - # (which is true for some recent version of pytorch). - cpu_sum = float(x.float().sum()) - # More efficient version that can be used if .sum() returns a Python scalar - # cpu_sum = float(x.sum()) - except RuntimeError as instance: - # We want to check if inst is actually an overflow exception. - # RuntimeError could come from a different error. - # If so, we still want the exception to propagate. - if "value cannot be converted" not in instance.args[0]: - raise - return True - else: - if cpu_sum == float("inf") or cpu_sum == -float("inf") or cpu_sum != cpu_sum: - return True - return False - - # `overflow` is boolean indicating whether the gradient overflowed - def update_scale(self, overflow): - - if not hasattr(self, "min_scale"): - self.min_scale = 1 - if not hasattr(self, "delayed_shift"): - self.delayed_shift = 1 - if not hasattr(self, "cur_hysteresis"): - self.cur_hysteresis = 1 - if not hasattr(self, "consecutive_hysteresis"): - self.consecutive_hysteresis = True - if overflow: - # self.cur_scale /= self.scale_factor - if self.delayed_shift == 1 or self.cur_hysteresis == 1: - self.cur_scale = max(self.cur_scale / self.scale_factor, self.min_scale) - else: - self.cur_hysteresis -= 1 - self.last_overflow_iter = self.cur_iter - else: - if self.consecutive_hysteresis: - self.cur_hysteresis = self.delayed_shift - if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0: - if not self.consecutive_hysteresis: - self.cur_hysteresis = self.delayed_shift - self.cur_scale *= self.scale_factor - self.cur_iter += 1 - - @property - def loss_scale(self): - return self.cur_scale - - def scale_gradient(self, module, grad_in, grad_out): - _overflow_buf = torch.cuda.IntTensor([0]) - multi_tensor_applier( - amp_C.multi_tensor_scale, _overflow_buf, [grad_in, grad_in], self.loss_scale - ) - return grad_in - - def backward(self, loss, retain_graph=False): - scaled_loss = loss * self.loss_scale - if self.use_smp: - self.model.backward(scaled_loss) - else: - scaled_loss.backward() - - -############################################################## -# Example usage below here -- assuming it's in a separate file -############################################################## -""" -TO-DO separate out into an example. -if __name__ == "__main__": - import torch - from torch.autograd import Variable - from dynamic_loss_scaler import DynamicLossScaler - - # N is batch size; D_in is input dimension; - # H is hidden dimension; D_out is output dimension. - N, D_in, H, D_out = 64, 1000, 100, 10 - - # Create random Tensors to hold inputs and outputs, and wrap them in Variables. - x = Variable(torch.randn(N, D_in), requires_grad=False) - y = Variable(torch.randn(N, D_out), requires_grad=False) - - w1 = Variable(torch.randn(D_in, H), requires_grad=True) - w2 = Variable(torch.randn(H, D_out), requires_grad=True) - parameters = [w1, w2] - - learning_rate = 1e-6 - optimizer = torch.optim.SGD(parameters, lr=learning_rate) - loss_scaler = DynamicLossScaler() - - for t in range(500): - y_pred = x.mm(w1).clamp(min=0).mm(w2) - loss = (y_pred - y).pow(2).sum() * loss_scaler.loss_scale - print('Iter {} loss scale: {}'.format(t, loss_scaler.loss_scale)) - print('Iter {} scaled loss: {}'.format(t, loss.data[0])) - print('Iter {} unscaled loss: {}'.format(t, loss.data[0] / loss_scaler.loss_scale)) - - # Run backprop - optimizer.zero_grad() - loss.backward() - - # Check for overflow - has_overflow = DynamicLossScaler.has_overflow(parameters) - - # If no overflow, unscale grad and update as usual - if not has_overflow: - for param in parameters: - param.grad.data.mul_(1. / loss_scaler.loss_scale) - optimizer.step() - # Otherwise, don't do anything -- ie, skip iteration - else: - print('OVERFLOW!') - - # Update loss scale for next iteration - loss_scaler.update_scale(has_overflow) - -""" diff --git a/training/distributed_training/pytorch/model_parallel/gpt-j/fp16/megatron/clip_grads.py b/training/distributed_training/pytorch/model_parallel/gpt-j/fp16/megatron/clip_grads.py deleted file mode 100644 index 191dd91f3e..0000000000 --- a/training/distributed_training/pytorch/model_parallel/gpt-j/fp16/megatron/clip_grads.py +++ /dev/null @@ -1,156 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -# Modifications Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Gradient clipping.""" - -import torch -from torch._six import inf -import smdistributed.modelparallel.torch as smp - -from apex.multi_tensor_apply import multi_tensor_applier -import amp_C - - -def clip_grad_norm_fp32(parameters, param_is_distributed, max_norm, norm_type=2): - """Clips gradient norm of an iterable of parameters whose gradients - are in fp32. - - This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and - added functionality to handle model parallel parameters. Note that - the gradients are modified in place. - - Arguments: - parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a - single Tensor that will have gradients normalized - max_norm (float or int): max norm of the gradients - norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for - infinity norm. - - Returns: - Total norm of the parameters (viewed as a single vector). - """ - - if isinstance(parameters, torch.Tensor): - parameters = [parameters] - - # Filter parameters based on: - # - grad should not be none - # - parameter should not be shared - # - should not be a replica due to tensor model parallelism - torch.cuda.set_device(smp.local_rank()) - grads = [] - grads_for_norm = [] - for param in parameters: - grad_not_none = param.grad is not None - is_not_shared = not hasattr(param, "shared") or not param.shared - is_not_tp_duplicate = smp.tp_rank() == 0 or ( - param in param_is_distributed and param_is_distributed[param] - ) - if grad_not_none: - grad = param.grad.detach() - # Make sure the grads are in fp32 - assert param.grad.type() == "torch.cuda.FloatTensor" - grads.append(grad) - if grad_not_none and is_not_shared and is_not_tp_duplicate: - grads_for_norm.append(grad) - - # Norm parameters. - max_norm = float(max_norm) - norm_type = float(norm_type) - total_norm = torch.tensor(0.0, device=torch.device("cuda")) - - # Calculate norm. - if norm_type == inf: - if len(grads_for_norm) > 0: - total_norm = max(grad.abs().max() for grad in grads_for_norm) - total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) - # Take max across all model-parallel GPUs. - torch.distributed.all_reduce( - total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=smp.get_mp_process_group() - ) - total_norm = total_norm_cuda[0].item() - - else: - if norm_type == 2.0: - dummy_overflow_buf = torch.cuda.IntTensor( - [0], device=torch.device("cuda", smp.local_rank()) - ) - # Use apex's multi-tensor applier for efficiency reasons. - # Multi-tensor applier takes a function and a list of list - # and performs the operation on that list all in one kernel. - if len(grads_for_norm) > 0: - grad_norm, _ = multi_tensor_applier( - amp_C.multi_tensor_l2norm, - dummy_overflow_buf, - [grads_for_norm], - False, # no per-parameter norm - ) - # Since we will be summing across data parallel groups, - # we need the pow(norm-type). - total_norm = grad_norm**norm_type - else: - for grad in grads_for_norm: - grad_norm = torch.norm(grad, norm_type) - total_norm += grad_norm**norm_type - - # Sum across all model-parallel GPUs. - torch.distributed.all_reduce( - total_norm, op=torch.distributed.ReduceOp.SUM, group=smp.get_mp_process_group() - ) - total_norm = total_norm.item() ** (1.0 / norm_type) - - # Scale. - if len(grads) > 0: - clip_coeff = max_norm / (total_norm + 1.0e-6) - if clip_coeff < 1.0: - dummy_overflow_buf = torch.cuda.IntTensor( - [0], device=torch.device("cuda", smp.local_rank()) - ) - multi_tensor_applier( - amp_C.multi_tensor_scale, dummy_overflow_buf, [grads, grads], clip_coeff - ) - - return total_norm - - -def count_zeros_fp32(parameters): - - if isinstance(parameters, torch.Tensor): - parameters = [parameters] - - # Filter parameters based on: - # - grad should not be none - # - parameter should not be shared - # - should not be a replica due to tensor model parallelism - total_num_zeros = 0.0 - for param in parameters: - grad_not_none = param.grad is not None - is_not_shared = not hasattr(param, "shared") or not param.shared - is_not_tp_duplicate = smp.tp_rank() == 0 or ( - param in param_is_distributed and param_is_distributed[param] - ) - if grad_not_none and is_not_shared and is_not_tp_duplicate: - grad = param.grad.detach() - num_zeros = grad.numel() - torch.count_nonzero(grad) - total_num_zeros = num_zeros + total_num_zeros - - # Sum across all model-parallel GPUs. - torch.distributed.all_reduce( - total_num_zeros, op=torch.distributed.ReduceOp.SUM, group=smp.get_mp_process_group() - ) - total_num_zeros = total_num_zeros.item() - - return total_num_zeros diff --git a/training/distributed_training/pytorch/model_parallel/gpt-j/fp16/megatron/fp16.py b/training/distributed_training/pytorch/model_parallel/gpt-j/fp16/megatron/fp16.py deleted file mode 100644 index 94755f5e67..0000000000 --- a/training/distributed_training/pytorch/model_parallel/gpt-j/fp16/megatron/fp16.py +++ /dev/null @@ -1,698 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -# Modifications Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Megatron optimizer.""" -from abc import ABC, abstractmethod -from contextlib import contextmanager - -import amp_C -import humanize -import smdistributed.modelparallel.torch as smp -import torch -import torch.nn as nn -from apex.multi_tensor_apply import multi_tensor_applier -from smdistributed.modelparallel.torch.state_mod import state as smp_state -from smdistributed.modelparallel.torch.utils import get_distribution_axis - -from ..fp16util import ( - get_pp_merged_fp32_from_fp16_param_groups, - get_tp_merged_fp32_from_fp16_param_groups, - register_optimizer_hooks, -) -from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32 - - -def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None): - """Use multi-tensor-applier to copy values from one list to another. - We don't have a blfoat16 implementation so for now if the overflow_buf - is not provided, we default back to simple loop copy to be compatible - with bfloat16.""" - if overflow_buf: - overflow_buf.fill_(0) - # Scaling with factor `1.0` is equivalent to copy. - multi_tensor_applier(amp_C.multi_tensor_scale, overflow_buf, [this, that], 1.0) - else: - for this_, that_ in zip(this, that): - that_.copy_(this_) - - -def _zero_grad_group_helper(group, set_to_none): - """Zero out the gradient for a group of parameters. - Note: copied from torch.optim.optimizer.""" - for param in group: - if param.grad is not None: - if set_to_none: - param.grad = None - else: - if param.grad.grad_fn is not None: - param.grad.detach_() - else: - param.grad.requires_grad_(False) - param.grad.zero_() - - -@contextmanager -def measure_additional_mem_context(): - smp.barrier() - mem_before = torch.cuda.memory_allocated(device=smp.local_rank()) - yield - import gc - - gc.collect() - gc.collect() - gc.collect() - mem_after = torch.cuda.memory_allocated(device=smp.local_rank()) - print( - f"rank is {smp.local_rank()}, memory usage is {humanize.naturalsize(mem_after - mem_before)}" - ) - smp.barrier() - - -def save_fp16_optimizer(args, model, optimizer, partial=True): - state_dict = {} - # state_dict['optimizer'] = optimizer.state_dict() - if optimizer.grad_scaler: - state_dict["grad_scaler"] = optimizer.grad_scaler.state_dict() - # state_dict['fp32_from_fp16_params'] = self.fp32_from_float16_groups - - cpu_fp32_from_fp16_groups = [ - [param.cpu() for param in group] for group in optimizer.fp32_from_float16_groups - ] - if optimizer.master_params_created: - register_optimizer_hooks(model) - if partial: - state_dict["optimizer_state_dict"] = optimizer.local_state_dict() - if args.shard_optimizer_state: - if smp.rdp_rank() == 0: - print( - "With shard_optimizer_state=True, gather full fp32_from_fp16_groups for the rdp_group on rdp rank 0" - ) - gathered_cpu_fp32_from_fp16_groups = [cpu_fp32_from_fp16_groups] - for src in range(1, smp.rdp_size()): - gathered_cpu_fp32_from_fp16_groups.append( - smp.recv_from(src, smp.RankType.RDP_RANK) - ) - state_dict["fp32_from_fp16"] = gathered_cpu_fp32_from_fp16_groups - else: - smp.send(cpu_fp32_from_fp16_groups, 0, smp.RankType.RDP_RANK) - state_dict["fp32_from_fp16"] = cpu_fp32_from_fp16_groups - else: - state_dict["fp32_from_fp16"] = cpu_fp32_from_fp16_groups - if smp.pp_size() > 1: - print( - "WARNING: Ensure that partition decision doesnt change between runs (you can ensure this by setting use_times=False in smp config)." - "If you want to save and load with partition decision changing between runs, use full save and load instead." - ) - else: - state_dict["optimizer_state_dict"] = optimizer.state_dict() - if smp.tp_size() > 1 and not args.shard_optimizer_state: - ( - tp_merged_fp32_from_fp16_groups, - param_name_groups, - ) = get_tp_merged_fp32_from_fp16_param_groups(optimizer, cpu_fp32_from_fp16_groups) - ( - pp_merged_fp32_from_fp16_groups, - param_name_groups, - ) = get_pp_merged_fp32_from_fp16_param_groups( - optimizer, tp_merged_fp32_from_fp16_groups, param_name_groups - ) - else: - raise ValueError( - "Loading full optimizer state is not supported, when TP is not enabled or shard_optimizer_state is enabled" - ) - state_dict["fp32_from_fp16"] = pp_merged_fp32_from_fp16_groups - state_dict["param_name_groups"] = param_name_groups - return state_dict - - -def load_fp16_optimizer(args, model, optimizer, state_dict, partial=True): - opt_state_dict = state_dict["optimizer"] - - if optimizer.master_params_created: - register_optimizer_hooks(model) - - def hook_fn(model, optimizer): - optimizer.load_state_dict(opt_state_dict["optimizer_state_dict"]) - if partial: - if args.shard_optimizer_state: - assert isinstance( - opt_state_dict["fp32_from_fp16"], list - ), "Loading with shard_optimizer_state=True must use the checkpoint that was trained with shard_optimizer_state=True!" - optimizer.fp32_from_fp16 = opt_state_dict["fp32_from_fp16"][smp.rdp_rank()] - else: - optimizer.fp32_from_fp16 = opt_state_dict["fp32_from_fp16"] - - for current_group, saved_group in zip( - optimizer.fp32_from_float16_groups, optimizer.fp32_from_fp16 - ): - for current, saved in zip(current_group, saved_group): - current.data.copy_(saved.data) - - else: - optimizer.fp32_from_fp16 = opt_state_dict["fp32_from_fp16"] - param_name_groups = opt_state_dict["param_name_groups"] - param_id_to_index = optimizer._param_id_to_index() - param_index_to_name_tp_group = smp_state.param_index_to_name_tp_group - param_index_to_name = param_index_to_name_tp_group[smp.tp_rank()] - for group_idx, (current_group, saved_group) in enumerate( - zip(optimizer.fp32_from_float16_groups, optimizer.fp32_from_fp16) - ): - for current in current_group: - param_id = id(current) - param_index = param_id_to_index[param_id] - param_name = param_index_to_name[param_index] - arr_index = param_name_groups[group_idx][param_name] - saved = saved_group[arr_index] - if optimizer.master_distribution_axis[param_id] is not None: - axis = optimizer.master_distribution_axis[param_id] - slice_size = saved.size(axis) // smp.tp_size() - saved = torch.narrow( - saved.data, axis, slice_size * smp.tp_rank(), slice_size - ).contiguous() - else: - saved = saved.data - current.data.copy_(saved) - - optimizer.grad_scaler.load_state_dict(opt_state_dict["grad_scaler"]) - - model.register_post_partition_hook(hook_fn) - - -class MegatronOptimizer(ABC): - def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad, params_have_main_grad): - """Input optimizer is the base optimizer for example Adam.""" - self.optimizer = optimizer - assert self.optimizer, "no optimizer is provided." - # Set gradient clipping and logging params. - self.clip_grad = clip_grad - self.log_num_zeros_in_grad = log_num_zeros_in_grad - self.params_have_main_grad = params_have_main_grad - - def get_parameters(self): - params = [] - for param_group in self.optimizer.param_groups: - for param in param_group["params"]: - params.append(param) - return params - - def clip_grad_norm(self, clip_grad): - params = self.get_parameters() - return clip_grad_norm_fp32(params, self.master_is_distributed, clip_grad) - - def count_zeros(self): - params = self.get_parameters() - return count_zeros_fp32(params) - - @abstractmethod - def zero_grad(self, set_grads_to_None=True): - pass - - @abstractmethod - def get_loss_scale(self): - """The output should be a cuda tensor of size 1.""" - - def scale_loss(self, loss): - """Simple scaling.""" - return self.get_loss_scale() * loss - - @abstractmethod - def step(self): - pass - - @abstractmethod - def reload_model_params(self): - """Refreshes any internal state from the current model parameters. - Call whenever the parameters are changed outside of the optimizer. - For example, when we load a model from a checkpoint without loading - the optimizer, the model parameters are updated but for fp16 optimizer - with main parameters, the main parameters need to also be updated.""" - - @abstractmethod - def state_dict(self): - pass - - @abstractmethod - def load_state_dict(self, state_dict): - pass - - # Promote state so it can be retrieved or set via - # "optimizer_instance.state" - def _get_state(self): - return self.optimizer.state - - def _set_state(self, value): - self.optimizer.state = value - - state = property(_get_state, _set_state) - - # Promote param_groups so it can be retrieved or set via - # "optimizer_instance.param_groups" - # (for example, to adjust the learning rate) - def _get_param_groups(self): - return self.optimizer.param_groups - - def _set_param_groups(self, value): - self.optimizer.param_groups = value - - param_groups = property(_get_param_groups, _set_param_groups) - - -class FP16_Module(nn.Module): - def __init__(self, module): - super(FP16_Module, self).__init__() - self.add_module("module", module.half()) - - def forward(self, *inputs, **kwargs): - return fp16_to_fp32(self.module(*(fp32_to_fp16(inputs)), **kwargs)) - - def state_dict(self, destination=None, prefix="", keep_vars=False): - return self.module.state_dict(destination, prefix, keep_vars) - - def state_dict_for_save_checkpoint(self, destination=None, prefix="", keep_vars=False): - return self.module.state_dict_for_save_checkpoint(destination, prefix, keep_vars) - - def load_state_dict(self, state_dict, strict=True): - self.module.load_state_dict(state_dict, strict=strict) - - -class Float16OptimizerWithFloat16Params(MegatronOptimizer): - """Float16 optimizer for fp16 and bf16 data types. - - Arguments: - optimizer: base optimizer such as Adam or SGD - clip_grad: clip gradeints with this global L2 norm. Note - that clipping is ignored if clip_grad == 0 - log_num_zeros_in_grad: return number of zeros in the gradients. - params_have_main_grad: flag indicating if parameters have - a `main_grad` field. If this is set, we are assuming - that the model parameters are store in the `main_grad` - field instead of the typical `grad` field. This happens - for the DDP cases where there is a contihuous buffer - holding the gradients. For example for bfloat16, we want - to do gradient accumulation and all-reduces in float32 - and as a result we store those gradients in the main_grad. - Note that main grad is not necessarily in float32. - bf16: if true, the model is running in bfloat16. - grad_scaler: used for scaling gradients. Note that this can be - None. This case happens when `bf16 = True` and we don't - use any loss scale. Note that for `bf16 = True`, we can have - a constnat gradient scaler. Also for `bf16 = False`, we - always require a grad scaler. - """ - - def __init__( - self, - model, - optimizer, - clip_grad, - log_num_zeros_in_grad, - params_have_main_grad, - bf16, - grad_scaler, - use_smp=False, - shard_optimizer_state=False, - verbose=False, - ): - - super(Float16OptimizerWithFloat16Params, self).__init__( - optimizer, clip_grad, log_num_zeros_in_grad, params_have_main_grad - ) - - self.model = model - self.verbose = verbose - self.use_smp = use_smp - self.shard_optimizer_state = shard_optimizer_state - self.bf16 = bf16 - self.grad_scaler = grad_scaler - - # None grad scaler is only supported for bf16. - if self.grad_scaler is None: - assert self.bf16, "fp16 expects a grad scaler." - - # Three groups of parameters: - # float16_groups: original float16 parameters - # fp32_from_float16_groups: fp32 copy of float16 parameters - # fp32_from_fp32_groups: original fp32 parameters - self.float16_groups = [] - self.fp32_from_float16_groups = [] - self.fp32_from_fp32_groups = [] - self.master_is_distributed = {} - self.fp32_from_fp16_paramid_groups = [] - self.master_distribution_axis = {} - self.master_params_created = False - self.warned_set_grads_to_none = False - if not self.use_smp: - self.init_master_params() - - def measure_additional_mem(f): - def wrapper(*args, **kwargs): - mem_before = torch.cuda.memory_allocated(device=smp.local_rank()) - f(*args, **kwargs) - import gc - - gc.collect() - gc.collect() - gc.collect() - mem_after = torch.cuda.memory_allocated(device=smp.local_rank()) - print( - f"rank is {smp.local_rank()}, function name is {f.__name__}, memory usage is {humanize.naturalsize(mem_after - mem_before)}" - ) - - return wrapper - - def init_master_params(self): - if self.use_smp: - torch.cuda.set_device(smp.local_rank()) - register_optimizer_hooks(self.model) - self.fp32paramid_from_fp16paramid = {} - # Tensor used to determine if a nan/if has happend. - # Any non-zero value indicates inf/nan. - # Note that we keep this for the cases that grad scaler is none. - # We still record nan/inf if we have a bfloat16 with a grad scaler. - if self.grad_scaler: - self.found_inf = torch.cuda.FloatTensor([0.0]) - - # Dummy tensor needed for apex multi-apply tensor. - # For bfloat, we don't have multi-tensor apply and for now - # we set it to none so the multi-tensor apply gets ignored. - if self.bf16: - self._dummy_overflow_buf = None - else: - self._dummy_overflow_buf = torch.cuda.IntTensor([0]) - - # In case grad scaler is not passed, define the unity scale. - if self.grad_scaler is None: - self._scale_one = torch.cuda.FloatTensor([1.0]) - - # ====================== - # main parameter stuff - # ====================== - - # only need to create contiguous buffer for fp16 params which require grads - contig_buffer_size = 0 - for param_group in self.optimizer.param_groups: - for param in param_group["params"]: - if param.requires_grad and param.type() in [ - "torch.cuda.HalfTensor", - "torch.cuda.BFloat16Tensor", - ]: - contig_buffer_size += param.numel() - - self.fp32_param_buffer = torch.empty( - contig_buffer_size, - device=torch.device("cuda", smp.local_rank()), - dtype=torch.float32, - requires_grad=True, - ) - offset = 0 - - # only need to create contiguous buffer for fp16 params which require grads - contig_buffer_size = 0 - for param_group in self.optimizer.param_groups: - for param in param_group["params"]: - if param.requires_grad and param.type() in [ - "torch.cuda.HalfTensor", - "torch.cuda.BFloat16Tensor", - ]: - contig_buffer_size += param.numel() - - self.fp32_param_buffer = torch.empty( - contig_buffer_size, - device=torch.device("cuda", smp.local_rank()), - dtype=torch.float32, - requires_grad=True, - ) - offset = 0 - - # For all the groups in the original optimizer: - for param_group in self.optimizer.param_groups: - float16_params_this_group = [] - fp32_params_this_group = [] - fp32_from_float16_params_this_group = [] - fp32_from_fp16_paramids_this_group = [] - # For all the parameters in this group: - for i, param in enumerate(param_group["params"]): - if param.requires_grad: - # float16 params: - if param.type() in ["torch.cuda.HalfTensor", "torch.cuda.BFloat16Tensor"]: - float16_params_this_group.append(param) - # Create a copy - with torch.no_grad(): - master_param_buffer = self.fp32_param_buffer.narrow( - 0, offset, param.numel() - ).view_as(param) - master_param_buffer.copy_(param.float()) - offset += param.numel() - - main_param = nn.Parameter( - master_param_buffer, requires_grad=param.requires_grad - ) - self.master_is_distributed[ - main_param - ] = self.model.is_distributed_parameter(param) - self.master_distribution_axis[id(main_param)] = get_distribution_axis(param) - fp32_from_fp16_paramids_this_group.append(id(main_param)) - if hasattr(param, "shared"): - main_param.shared = param.shared - - # Replace the optimizer params with the new fp32 copy. - param_group["params"][i] = main_param - fp32_from_float16_params_this_group.append(main_param) - # Reset existing state dict key to the new main param. - if param in self.optimizer.state: - self.optimizer.state[main_param] = self.optimizer.state.pop(param) - self.fp32paramid_from_fp16paramid[id(param)] = id(main_param) - - # fp32 params. - elif param.type() == "torch.cuda.FloatTensor": - fp32_params_this_group.append(param) - param_group["params"][i] = param - - else: - raise TypeError( - "Wrapped parameters must be one of " - "torch.cuda.FloatTensor, " - "torch.cuda.HalfTensor, or " - "torch.cuda.BFloat16Tensor. " - "Received {}".format(param.type()) - ) - - self.float16_groups.append(float16_params_this_group) - self.fp32_from_float16_groups.append(fp32_from_float16_params_this_group) - self.fp32_from_fp16_paramid_groups.append(fp32_from_fp16_paramids_this_group) - self.fp32_from_fp32_groups.append(fp32_params_this_group) - - # Leverage state_dict() and load_state_dict() to - # recast preexisting per-param state tensors - self.optimizer.load_state_dict(self.optimizer.state_dict()) - self.master_params_created = True - - def maybe_print(self, msg): - if self.verbose: - print(msg) - - def zero_grad(self, set_grads_to_None=True): - """We only need to zero the model related parameters, i.e., - float16_groups & fp32_from_fp32_groups.""" - - if self.shard_optimizer_state and set_grads_to_None and not self.warned_set_grads_to_none: - print( - "WARNING: Will not set fp16 gradients to None since shard_optimizer_state is enabled." - ) - self.warned_set_grads_to_none = True - - for group in self.float16_groups: - _zero_grad_group_helper(group, set_grads_to_None and not self.shard_optimizer_state) - for group in self.fp32_from_fp32_groups: - _zero_grad_group_helper(group, set_grads_to_None) - for group in self.optimizer.param_groups: - for p in group["params"]: - if set_grads_to_None: - p.grad = None - else: - if p.grad is not None: - p.grad.detach_() - p.grad.zero_() - - def get_loss_scale(self): - if self.grad_scaler is None: - return self._scale_one - return self.grad_scaler.scale - - def _copy_model_grads_to_main_grads(self): - # This only needs to be done for the float16 group. - for model_group, main_group in zip(self.float16_groups, self.fp32_from_float16_groups): - for model_param, main_param in zip(model_group, main_group): - if model_param.grad is not None: - # If gradient_as_bucket_view is True for DistributedModel, the grads will be in FP32 - # thus below line wont create a copy of grads - # Otherwise below line will create a copy of grads - main_param.grad = model_param.grad.float() - - def _unscale_main_grads_and_check_for_nan(self): - main_grads = [] - # fp32 params fromm float16 ones. - for main_group in self.fp32_from_float16_groups: - for main_param in main_group: - if main_param.grad is not None: - main_grads.append(main_param.grad.data) - # Append fp32 parameters. - for main_group in self.fp32_from_fp32_groups: - for main_param in main_group: - if main_param.grad is not None: - main_grads.append(main_param.grad.data) - # Reset found inf. - self.found_inf.fill_(0.0) - # Unscale and set found inf/nan - if hasattr(torch, "_amp_foreach_non_finite_check_and_unscale_"): - torch._amp_foreach_non_finite_check_and_unscale_( - main_grads, self.found_inf, self.grad_scaler.inv_scale - ) - else: - if self.grad_scaler.inv_scale != 1.0: - grads = [main_grad for main_grad in main_grads if main_grad is not None] - _overflow_buf = torch.cuda.IntTensor([0]) - multi_tensor_applier( - amp_C.multi_tensor_scale, - _overflow_buf, - [grads, grads], - self.grad_scaler.inv_scale, - ) - self.found_inf[0] = _overflow_buf[0] - - # Update across all model parallel instances. - """ - torch.distributed.all_reduce(self.found_inf, - op=torch.distributed.ReduceOp.MAX, - group=mpu.get_model_parallel_group()) - """ - torch.distributed.all_reduce( - self.found_inf, op=torch.distributed.ReduceOp.MAX, group=smp.get_mp_process_group() - ) - - # Check for nan. - found_inf_flag = self.found_inf.item() > 0 - return found_inf_flag - - def _get_model_and_main_params_data_float16(self): - model_data = [] - main_data = [] - for model_group, main_group in zip(self.float16_groups, self.fp32_from_float16_groups): - for model_param, main_param in zip(model_group, main_group): - model_data.append(model_param.data) - main_data.append(main_param.data) - return model_data, main_data - - def _copy_main_params_to_model_params(self): - # Only needed for the float16 params. - model_data, main_data = self._get_model_and_main_params_data_float16() - _multi_tensor_copy_this_to_that( - this=main_data, that=model_data, overflow_buf=self._dummy_overflow_buf - ) - - def _copy_model_params_to_main_params(self): - # Only needed for the float16 params. - model_data, main_data = self._get_model_and_main_params_data_float16() - _multi_tensor_copy_this_to_that( - this=model_data, that=main_data, overflow_buf=self._dummy_overflow_buf - ) - - def reload_model_params(self): - self._copy_model_params_to_main_params() - - @torch.no_grad() - def step(self): - - # Copy gradients from model params to main params. - self._copy_model_grads_to_main_grads() - - # Do unscale, check for inf, and update grad scaler only for - # the case that grad scaler is provided. - if self.grad_scaler: - - # Unscale and check for inf/nan. - found_inf_flag = self._unscale_main_grads_and_check_for_nan() - - # We are done with scaling gradients - # so we can update the loss scale. - self.grad_scaler.update(found_inf_flag) - - # If we found inf/nan, skip the update. - if found_inf_flag: - print("Found, inf, skipping step") - return False, None, None - - # Clip the main gradients. - grad_norm = None - if self.clip_grad > 0.0: - grad_norm = self.clip_grad_norm(self.clip_grad) - - # count the zeros in the grads - num_zeros_in_grad = self.count_zeros() if self.log_num_zeros_in_grad else None - - # Step the optimizer. - self.optimizer.step() - - # Update params from main params. - self._copy_main_params_to_model_params() - - # Successful update. - return True, grad_norm, num_zeros_in_grad - - def state_dict(self): - if not self.use_smp: - state_dict = {} - state_dict["optimizer"] = self.optimizer.state_dict() - if self.grad_scaler: - state_dict["grad_scaler"] = self.grad_scaler.state_dict() - state_dict["fp32_from_fp16_params"] = self.fp32_from_float16_groups - return state_dict - else: - return self.optimizer.state_dict() - - def load_state_dict(self, state_dict): - if not self.use_smp: - # Optimizer. - optimizer_key = "optimizer" - if optimizer_key not in state_dict: - optimizer_key = "optimizer_state_dict" - print("***WARNING*** loading optimizer from " "an old checkpoint ...") - self.optimizer.load_state_dict(state_dict[optimizer_key]) - - # Grad scaler. - if "grad_scaler" not in state_dict: - print("***WARNING*** found an old checkpoint, will not " "load grad scaler ...") - else: - if self.grad_scaler: - self.grad_scaler.load_state_dict(state_dict["grad_scaler"]) - else: - print( - "***WARNING*** fould the grad scaler in the " - "checkpoint but it is None in the class. " - "Skipping loading grad scaler ..." - ) - - # Copy data for the main params. - fp32_from_float16_params_key = "fp32_from_fp16_params" - if fp32_from_float16_params_key not in state_dict: - fp32_from_float16_params_key = "fp32_from_fp16" - for current_group, saved_group in zip( - self.fp32_from_float16_groups, state_dict[fp32_from_float16_params_key] - ): - for current_param, saved_param in zip(current_group, saved_group): - current_param.data.copy_(saved_param.data) - else: - self.optimizer.load_state_dict(state_dict) diff --git a/training/distributed_training/pytorch/model_parallel/gpt-j/fp16/megatron/grad_scaler.py b/training/distributed_training/pytorch/model_parallel/gpt-j/fp16/megatron/grad_scaler.py deleted file mode 100644 index add28ebddb..0000000000 --- a/training/distributed_training/pytorch/model_parallel/gpt-j/fp16/megatron/grad_scaler.py +++ /dev/null @@ -1,130 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -# Modifications Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Megatron grad scaler.""" - -from abc import ABC -from abc import abstractmethod - -import torch - - -class MegatronGradScaler(ABC): - def __init__(self, initial_scale): - """Initialize scale value with the input initial scale.""" - assert initial_scale > 0.0 - self._scale = torch.cuda.FloatTensor([initial_scale]) - - @property - def scale(self): - return self._scale - - @property - def inv_scale(self): - return self._scale.double().reciprocal().float() - - @abstractmethod - def update(self, found_inf): - pass - - @abstractmethod - def state_dict(self): - pass - - @abstractmethod - def load_state_dict(self, state_dict): - pass - - -class ConstantGradScaler(MegatronGradScaler): - def update(self, found_inf): - pass - - def state_dict(self): - return dict() - - def load_state_dict(self, state_dict): - pass - - -class DynamicGradScaler(MegatronGradScaler): - def __init__( - self, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis - ): - """ "Grad scaler with dynamic scale that gets adjusted - during training.""" - super(DynamicGradScaler, self).__init__(initial_scale) - - # Lower bound on the scale. - assert min_scale > 0.0 - assert min_scale <= initial_scale - self.min_scale = torch.cuda.FloatTensor([min_scale]) - # Growth and backoff factors for the scale. - assert growth_factor > 1.0 - self.growth_factor = torch.cuda.FloatTensor([growth_factor]) - self.cur_iter = 0 - self.last_overflow_iter = -1 - assert backoff_factor < 1.0 - assert backoff_factor > 0.0 - self.backoff_factor = torch.cuda.FloatTensor([backoff_factor]) - # Interval over which if we don't see any inf/nan, - # we will scale the grad scale by the growth factor. - assert growth_interval > 0 - self.growth_interval = growth_interval - # Number of inf/nans we should see before scaling down - # the grad scale by the backoff factor. - assert hysteresis > 0 - self.hysteresis = hysteresis - - # Trackers. - self._growth_tracker = 0 - self._hysteresis_tracker = self.hysteresis - - def update(self, found_inf): - - # If we have an inf/nan, growth tracker is set to 0 - # and hysterisis tracker is reduced by 1. - if found_inf: - self._growth_tracker = 0 - self._hysteresis_tracker -= 1 - # Now if we are out of hysteresis count, scale down the loss. - if self._hysteresis_tracker <= 0: - self._scale = torch.max(self._scale * self.backoff_factor, self.min_scale) - self.last_overflow_iter = self.cur_iter - else: - # If there is no nan/inf, increment the growth tracker. - self._growth_tracker += 1 - # If we have had enough consequitive intervals with no nan/inf: - if self._growth_tracker == self.growth_interval: - # if (self.cur_iter - self.last_overflow_iter) % self.growth_interval == 0: - # Reset the tracker and hysteresis trackers, - self._growth_tracker = 0 - self._hysteresis_tracker = self.hysteresis - # and scale up the loss scale. - self._scale = self._scale * self.growth_factor - self.cur_iter += 1 - - def state_dict(self): - state_dict = {} - state_dict["scale"] = self._scale - state_dict["growth_tracker"] = self._growth_tracker - state_dict["hysteresis_tracker"] = self._hysteresis_tracker - return state_dict - - def load_state_dict(self, state_dict): - self._scale = state_dict["scale"].cuda(torch.cuda.current_device()) - self._growth_tracker = state_dict["growth_tracker"] - self._hysteresis_tracker = state_dict["hysteresis_tracker"] diff --git a/training/distributed_training/pytorch/model_parallel/gpt-j/memory_tracker.py b/training/distributed_training/pytorch/model_parallel/gpt-j/memory_tracker.py index 2b0236b836..329926a26e 100644 --- a/training/distributed_training/pytorch/model_parallel/gpt-j/memory_tracker.py +++ b/training/distributed_training/pytorch/model_parallel/gpt-j/memory_tracker.py @@ -1,6 +1,29 @@ +import psutil +import os + import smdistributed.modelparallel.torch as smp import torch +try: + from py3nvml import py3nvml +except ImportError: + py3nvml = None + +dtype_to_bit = { +torch.float32 : 32, +torch.float64 : 64, +torch.float16: 16, +torch.bfloat16: 16, +torch.uint8: 8, +torch.int8: 8, +torch.int16: 16, +torch.int32: 32, +torch.int64: 64, +torch.bool: 1 +} +process = psutil.Process(os.getpid()) +base_mem_usage = process.memory_info().data +last_mem_usage = base_mem_usage def memory_status(msg="", reset_max=True, sync=True): @@ -16,6 +39,15 @@ def memory_status(msg="", reset_max=True, sync=True): if rdp_rank != 0: return + if py3nvml != None: + py3nvml.nvmlInit() + handle = py3nvml.nvmlDeviceGetHandleByIndex(local_rank) + info = py3nvml.nvmlDeviceGetMemoryInfo(handle) + total_used = info.used / 1024**3 + total_used_str = f"Totally used GPU memory: {total_used}" + else: + total_used_str = "" + alloced = torch.cuda.memory_allocated(device=local_rank) max_alloced = torch.cuda.max_memory_allocated(device=local_rank) cached = torch.cuda.memory_reserved(device=local_rank) @@ -28,11 +60,52 @@ def memory_status(msg="", reset_max=True, sync=True): max_cached /= 1024**3 print( - f"[{msg}] rank {rank} tp_rank {tp_rank} pp_rank {pp_rank} TORCH {torch.__version__}", - f"device={local_rank} " - f"alloc {alloced:0.4f} max_alloced {max_alloced:0.4f} " - f"cache {cached:0.4f} max_cached {max_cached:0.4f}", + f'[{msg}] rank {rank} tp_rank {tp_rank} pp_rank {pp_rank} TORCH {torch.__version__}', + f'device={local_rank} ' + f'alloc {alloced:0.4f} max_alloced {max_alloced:0.4f} ' + f'cache {cached:0.4f} max_cached {max_cached:0.4f} ' + f'{total_used_str}' ) if reset_max: torch.cuda.reset_max_memory_cached() torch.cuda.reset_max_memory_allocated() + if py3nvml != None: + py3nvml.nvmlShutdown() + +def memory_status_cpu(msg=""): + import gc + global last_mem_usage + global base_mem_usage + rdp_rank = smp.rdp_rank() + gc.collect() + gc.collect() + gc.collect() + objects = gc.get_objects() + tensors = [obj for obj in objects if isinstance(obj, torch.Tensor) and not obj.is_cuda] + torch_usage = 0 + for t in tensors: + torch_usage += t.numel() * dtype_to_bit[t.dtype] + #total_usage = psutil.virtual_memory()[3] # This will get the total usage for all processes + current_usage = process.memory_info().data + total_usage = current_usage - base_mem_usage + usage_change = current_usage - last_mem_usage + last_mem_usage = current_usage + + torch_usage /= 1024**3 + total_usage /= 1024**3 + usage_change /= 1024**3 + base_usage = base_mem_usage / 1024**3 + + rank = smp.rank() + tp_rank = smp.tp_rank() + pp_rank = smp.pp_rank() + rdp_rank = smp.rdp_rank() + local_rank = smp.local_rank() + if rdp_rank != 0: + return + + print( + f'[{msg}] rank {rank} tp_rank {tp_rank} pp_rank {pp_rank} TORCH {torch.__version__}', + f'device={local_rank} ' + f'torch cpu tensor usage {torch_usage:0.4f} cpu mem usage {total_usage:0.4f} change since last measurement {usage_change:0.4f} base cpu mem usage {base_usage:0.4f}' + ) \ No newline at end of file diff --git a/training/distributed_training/pytorch/model_parallel/gpt-j/requirements.txt b/training/distributed_training/pytorch/model_parallel/gpt-j/requirements.txt index 88b524b21f..ea34021370 100644 --- a/training/distributed_training/pytorch/model_parallel/gpt-j/requirements.txt +++ b/training/distributed_training/pytorch/model_parallel/gpt-j/requirements.txt @@ -4,7 +4,7 @@ sagemaker sagemaker-experiments scipy torchnet -transformers +transformers==4.21.0 smdebug humanize smart-open>=5.2.1 diff --git a/training/distributed_training/pytorch/model_parallel/gpt-j/sharded_data_parallel_checkpoint.py b/training/distributed_training/pytorch/model_parallel/gpt-j/sharded_data_parallel_checkpoint.py new file mode 100644 index 0000000000..e9e7ebd79e --- /dev/null +++ b/training/distributed_training/pytorch/model_parallel/gpt-j/sharded_data_parallel_checkpoint.py @@ -0,0 +1,240 @@ +import torch +import glob +import math +import os +import re +import gc +from collections import OrderedDict + +# load to cpu +device = torch.device('cpu') +smp_prefix = "module." + +def atoi(text): + return int(text) if text.isdigit() else text + + +def natural_keys(text): + ''' + alist.sort(key=natural_keys) sorts in human order + http://nedbatchelder.com/blog/200712/human_sorting.html + (See Toothy's implementation in the comments) + ''' + return [ atoi(c) for c in re.split(r'(\d+)', text) ] + +def get_model_state_file(checkpoint_dir): + if not os.path.isdir(checkpoint_dir): + raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist") + file = os.path.join(checkpoint_dir, "model_0.pt") + + if not os.path.exists(file): + raise FileNotFoundError(f"can't find model states file at '{file}'") + + return file + +def get_optim_files(checkpoint_dir): + optim_files = sorted(glob.glob(os.path.join(checkpoint_dir, "optimizer_*.pt")), key=natural_keys) + + if len(optim_files) == 0: + raise FileNotFoundError( + f"can't find '*_optim_states.pt' files in directory '{checkpoint_dir}'") + + return optim_files + +def get_user_content_file(checkpoint_dir): + file = os.path.join(checkpoint_dir, "user_content.pt") + if not os.path.exists(file): + raise FileNotFoundError(f"can't find user content file at '{file}'") + return file + +def parse_model_state(model_file, user_content_file, dtype): + state_dict = torch.load(model_file, map_location=device) + user_content = torch.load(user_content_file, map_location=device) + + if "buffer_names" not in user_content: + raise ValueError(f"{user_content_file} miss buffer_names to reconstruct the full state") + if "param_shapes" not in user_content: + raise ValueError(f"{user_content_file} miss param_shapes to reconstruct the full state") + buffer_names = user_content["buffer_names"] + param_shapes = user_content["param_shapes"] + + # recover just the buffers while restoring them to the specified dtype + buffers = { + k: v.to(dtype) + for k, + v in state_dict["module"].items() if k in buffer_names + } + + return buffers, param_shapes + +def parse_optim_states(files, checkpoint_dir, dtype): + total_files = len(files) + state_dicts = [] + sharded_data_parallel_size = None + # param_shapes = None + fp32_groups_key = None + for i, f in enumerate(files): + states = torch.load(f, map_location=device) + if i == 0: + sharded_data_parallel_size = states["partition_count"] + states["fp32_flat_groups"] = [group.to(dtype) for group in states["fp32_flat_groups"]] + state_dicts.append(states["fp32_flat_groups"]) + + if type(sharded_data_parallel_size) is list: + sharded_data_parallel_size = max(sharded_data_parallel_size) + + if sharded_data_parallel_size != total_files: + raise ValueError( + f"Expected {sharded_data_parallel_size} of 'optimizer_*.pt' under '{checkpoint_dir}' but found {total_files} files. " + "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes." + ) + + flat_groups = [ + torch.cat(state_dicts[i], + 0) for i in range(len(state_dicts)) + ] + + return sharded_data_parallel_size, flat_groups + +def partitioned_param_info(unpartitioned_numel, sharded_data_parallel_size): + remainder = unpartitioned_numel % sharded_data_parallel_size + padding_numel = (sharded_data_parallel_size - remainder) if remainder else 0 + partitioned_numel = math.ceil(unpartitioned_numel / sharded_data_parallel_size) + return partitioned_numel, padding_numel + +def get_full_state_dict_from_sharded_data_parallel_checkpoint(checkpoint_dir, dtype=torch.float32, tag=None, remove_smp_prefix=True): + """ + Returns full state_dict reconstructed from sharded data parallel checkpoint + + Args: + - checkpoint_dir: path to the sharded data parallel checkpoint folder (where the optimizer files are) + - dtype: the dtype of the output full checkpoint + - tag: the checkpoint tag, if not specified will read the newest checkpoint + - remove_smp_prefix: remove the "module." prefix created by smp + + """ + if tag is None: + latest_path = os.path.join(checkpoint_dir, 'newest') + if os.path.isfile(latest_path): + with open(latest_path, 'r') as fd: + tag = fd.read().strip() + else: + raise ValueError(f"Unable to find 'newest' file at {latest_path}") + + checkpoint_dir = os.path.join(checkpoint_dir, tag) + + if not os.path.isdir(checkpoint_dir): + raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist") + + print(f"Processing checkpoint '{checkpoint_dir}'") + + optim_files = get_optim_files(checkpoint_dir) + sharded_data_parallel_size, flat_groups = parse_optim_states(optim_files, checkpoint_dir, dtype) + + model_file = get_model_state_file(checkpoint_dir) + user_content_file = get_user_content_file(checkpoint_dir) + buffers, param_shapes = parse_model_state(model_file, user_content_file, dtype) + + gc.collect() + avail_numel = flat_groups[0].numel() * sharded_data_parallel_size + # merge list of dicts, preserving order + param_shapes = {k: v for d in param_shapes for k, v in d.items()} + + # params + offset = 0 + total_numel = 0 + total_params = 0 + + state_dict = OrderedDict() + state_dict.update(buffers) + + for name, shape in param_shapes.items(): + if remove_smp_prefix and name.startswith(smp_prefix): + name = name[len(smp_prefix):] + + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + total_params += 1 + + partitioned_numel, partitioned_padding_numel = partitioned_param_info(unpartitioned_numel, sharded_data_parallel_size) + + print( + f"{total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + ) + + # memory usage doubles here + state_dict[name] = torch.cat( + tuple(flat_groups[i].narrow(0, + offset, + partitioned_numel) + for i in range(sharded_data_parallel_size)), + 0).narrow(0, + 0, + unpartitioned_numel).view(shape) + offset += partitioned_numel + + offset *= sharded_data_parallel_size + + # Sanity check + if offset != avail_numel: + raise ValueError( + f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print( + f"Reconstructed state dict with {total_params} params {total_numel} elements" + ) + + return state_dict + +def get_param_shapes(model, optimizer): + """Returns a dict of name to shape mapping, only for the flattened weights saved by the + optimizer. the names are exactly as in state_dict. The order is absolutely important, since + the saved data is just flattened data with no identifiers and requires reconstruction in the + same order it was saved. + + We can't rely on module.named_parameters() to get the saved tensors, as some params + will be missing and others unsaved and then it'd be impossible to reconstruct state_dict + from the flattened weights. + """ + param_group_shapes = [] + cnt = 0 + numel = 0 + + bit16_groups = optimizer.fp16_groups + param_names = {param: name for name, param in model.module.named_parameters()} + + for bit16_group in bit16_groups: + param_shapes = OrderedDict() + for param in bit16_group: + cnt += 1 + numel += param.ds_numel if hasattr(param, "ds_numel") else param.numel() + shape = param.ds_shape if hasattr(param, "ds_shape") else param.shape + if param not in param_names: + raise ValueError(f"failed to find optimizer param in named params") + name = param_names[param] + param_shapes[name] = shape + + param_group_shapes.append(param_shapes) + + return param_group_shapes + +def get_buffer_names(model): + buffer_names = [] + + # we save buffer names so that we could extract later the real buffers from the saved + # state_dict["module"] in the non-zero checkpoint - the buffers are already there but they + # are intermixed with param placeholders + + # have to traverse the tree to be able to skip non-persistent buffers + def get_layer_named_buffers(module, prefix=""): + for name, buf in module.named_buffers(recurse=False): + if buf is not None and name not in module._non_persistent_buffers_set: + buffer_names.append(prefix + name) + + for name, child in module.named_children(): + if child is not None: + get_layer_named_buffers(child, prefix + name + ".") + + get_layer_named_buffers(model.module, prefix="") + + return buffer_names \ No newline at end of file diff --git a/training/distributed_training/pytorch/model_parallel/gpt-j/train_gptj_smp_script.py b/training/distributed_training/pytorch/model_parallel/gpt-j/train_gptj_smp_script.py index 7b8c4f07fe..6fdd64a892 100644 --- a/training/distributed_training/pytorch/model_parallel/gpt-j/train_gptj_smp_script.py +++ b/training/distributed_training/pytorch/model_parallel/gpt-j/train_gptj_smp_script.py @@ -38,7 +38,6 @@ from preprocess import Preprocess from smp_trainer import SMPTrainer -from fp16 import FP16_Module, FP16_Optimizer, load_fp16_optimizer, save_fp16_optimizer from learning_rates import AnnealingLR from smdistributed.modelparallel.torch.nn import FusedLayerNorm as LayerNorm @@ -100,7 +99,6 @@ def save( seq_length=1024, batch_idx=0, ): - save_fn = save_fp16_optimizer_megatron if args.megatron else save_fp16_optimizer save_dict = { "cli_args": args.__dict__, "num_params": num_params, @@ -125,25 +123,17 @@ def save( else model_state_dict ) - if args.fp16: - if not partial and args.skip_full_optimizer: + if partial: + save_dict["optimizer"] = optimizer.local_state_dict(gather_if_shard=args.gather_if_shard) + else: + if args.skip_full_optimizer: print("Skipping saving the final optimizer state") - else: - if args.shard_optimizer_state == 0 or partial: - save_dict["optimizer"] = save_fn(args, model, optimizer, partial=partial) - else: - print( + elif args.shard_optimizer_state > 0: + print( "Saving the full optimizer state does not work with shard_optimizer_state > 0! Skipping..." - ) - else: - # fp32 - if partial: - save_dict["optimizer"] = optimizer.local_state_dict() + ) else: - if not args.skip_full_optimizer: - save_dict["optimizer"] = optimizer.state_dict() - else: - print("Skipping saving of full optimizer state") + save_dict["optimizer"] = optimizer.state_dict() if not args.gather_if_shard or (smp.rdp_rank() == 0 and partial) or smp.rank() == 0: smp.save(save_dict, output_save_file, partial=partial, v3=not args.gather_if_shard) @@ -316,10 +306,7 @@ def main(): torch.set_default_dtype(torch.float32) - iter_model = model - # Build parameter groups (weight decay and non-decay). - while isinstance(iter_model, (DistributedDataParallel, FP16_Module)): - iter_model = iter_model.module + iter_model = model.get_module() param_groups = get_param_groups_by_weight_decay(iter_model) diff --git a/training/distributed_training/pytorch/model_parallel/gpt-j/train_gptj_smp_tensor_parallel_script.py b/training/distributed_training/pytorch/model_parallel/gpt-j/train_gptj_smp_tensor_parallel_script.py index 49361cf6f7..00e03c7f63 100644 --- a/training/distributed_training/pytorch/model_parallel/gpt-j/train_gptj_smp_tensor_parallel_script.py +++ b/training/distributed_training/pytorch/model_parallel/gpt-j/train_gptj_smp_tensor_parallel_script.py @@ -15,16 +15,12 @@ import torch.utils.data import transformers from data_pipeline import create_pretraining_dataloader -from fp16 import FP16_Module, FP16_Optimizer, load_fp16_optimizer, save_fp16_optimizer -from fp16.megatron.fp16 import Float16OptimizerWithFloat16Params -from fp16.megatron.fp16 import load_fp16_optimizer as load_fp16_optimizer_megatron -from fp16.megatron.fp16 import save_fp16_optimizer as save_fp16_optimizer_megatron -from fp16.megatron.grad_scaler import DynamicGradScaler from learning_rates import AnnealingLR -from memory_tracker import memory_status +from memory_tracker import memory_status, memory_status_cpu +from sharded_data_parallel_checkpoint import get_buffer_names, get_param_shapes from smdistributed.modelparallel.torch.nn import FusedLayerNorm as LayerNorm from smdistributed.modelparallel.torch.nn.huggingface.gptj import ( - translate_hf_gptj_state_dict_to_smdistributed, + translate_hf_state_dict_to_smdistributed_gptj, translate_state_dict_to_hf_gptj, ) from torch import optim @@ -41,6 +37,7 @@ ) from transformers.trainer_utils import is_main_process +logging.getLogger("torch.distributed.distributed_c10d").setLevel(logging.ERROR) logger = logging.getLogger(__name__) @@ -101,16 +98,12 @@ def train_step(model, optimizer, input_ids, attention_mask, args): loss = output["loss"] else: loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)["loss"] - if args.fp16: - if args.megatron: - scaled_loss = optimizer.scale_loss(loss) - model.backward(scaled_loss) - else: - optimizer.backward(loss, update_master_grads=False) - else: - model.backward(loss) + + model.backward(loss) + if args.logits_output: return output + return loss @@ -121,289 +114,6 @@ def test_step(model, input_ids, attention_mask): return loss -def save_ckptsum(args, model, optimizer, filename): - results = collections.defaultdict(dict) - model_result = collections.defaultdict(dict) - - if args.fp16: - from fp16.fp16util import register_optimizer_hooks - - register_optimizer_hooks(model) - - def _get_optimizer_result(optimizer_states): - _optimizer_result = collections.defaultdict(dict) - for param_idx, state in optimizer_states.items(): - for key, val in state.items(): - if isinstance(val, torch.Tensor): - _optimizer_result["tensors"][f"{param_idx}_{key}"] = torch.sum(val) - else: - _optimizer_result["scalars"][f"{param_idx}_{key}"] = val - return _optimizer_result - - if not args.shard_optimizer_state: - optimizer_result = _get_optimizer_result(optimizer.local_state_dict()["state"]) - else: - local_state_dict = optimizer.local_state_dict()["state"] - if smp.rdp_rank() == 0: - optimizer_result = [] - for partial_local_state_dict in local_state_dict: - optimizer_result.append(_get_optimizer_result(partial_local_state_dict)) - - for param_name, param in model.local_state_dict().items(): - if isinstance(param, torch.Tensor): - model_result["tensors"][param_name] = torch.sum(param) - else: - model_result["scalars"][param_name] = param - - if smp.rdp_rank() == 0: - results["optimizer"] = optimizer_result - results["model"] = model_result - smp.save(results, filename) - - -def load_and_verify_ckptsum(args, model, optimizer, filename): - results = smp.load(filename) - optimizer_result = ( - results["optimizer"] - if not args.shard_optimizer_state - else results["optimizer"][smp.rdp_rank()] - ) - model_result = results["model"] - - def opt_check_fn(mod, opt): - loaded_opt_states = ( - opt.orig_state_dict()["state"] - if args.shard_optimizer_state - else opt.local_state_dict()["state"] - ) - for param_idx, state in loaded_opt_states.items(): - for key, val in state.items(): - if isinstance(val, torch.Tensor): - assert torch.isclose( - torch.sum(val), optimizer_result["tensors"][f"{param_idx}_{key}"] - ), f"mismatch for param_idx: {param_idx}, key is {key}" - else: - assert ( - val == optimizer_result["scalars"][f"{param_idx}_{key}"] - ), f"mismatch for param_idx: {param_idx}, key is {key}" - print("Optimizer save/load check passed successfully") - - def model_check_fn(mod, opt): - for param_name, param in mod.local_state_dict().items(): - if isinstance(param, torch.Tensor): - assert torch.isclose( - torch.sum(param), model_result["tensors"][param_name] - ), f"mismatch for param_name: {param_name}" - else: - assert ( - param == model_result["scalars"][param_name] - ), f"mismatch for param_name: {param_name}" - print("Model save/load check passed successfully") - - model.register_post_partition_hook(model_check_fn) - model.register_post_step_hook(opt_check_fn) - - -def save( - output_save_file, - model, - optimizer, - lr_scheduler, - model_config, - num_params, - total_steps, - curr_train_path_index, - args, - partial=True, - translate_to_hf=False, - seq_length=1024, - batch_idx=0, -): - save_fn = save_fp16_optimizer_megatron if args.megatron else save_fp16_optimizer - save_dict = { - "cli_args": args.__dict__, - "num_params": num_params, - "total_steps": total_steps, - "curr_train_path_index": curr_train_path_index, - "model_config": model_config, - "batch_idx": batch_idx, - } - - if lr_scheduler is not None: - save_dict["lr_scheduler"] = lr_scheduler.state_dict() - if partial: - if args.gather_if_shard > 0 or smp.rdp_rank() == 0: - # if not gather the opt checkpoint, only save the model for rdp rank 0 - save_dict["model"] = model.local_state_dict() - else: - model_state_dict = model.state_dict(gather_to_rank0=True) - if smp.rank() == 0: - save_dict["model"] = ( - translate_state_dict_to_hf_gptj(model_state_dict, seq_length) - if translate_to_hf - else model_state_dict - ) - - if args.fp16: - if not partial and args.skip_full_optimizer: - print("Skipping saving the final optimizer state") - else: - if args.shard_optimizer_state == 0 or partial: - save_dict["optimizer"] = save_fn(args, model, optimizer, partial=partial) - else: - print( - "Saving the full optimizer state does not work with shard_optimizer_state > 0! Skipping..." - ) - else: - # fp32 - if partial: - save_dict["optimizer"] = optimizer.local_state_dict() - else: - if not args.skip_full_optimizer: - save_dict["optimizer"] = optimizer.state_dict() - else: - print("Skipping saving of full optimizer state") - - if not args.gather_if_shard or (smp.rdp_rank() == 0 and partial) or smp.rank() == 0: - smp.save(save_dict, output_save_file, partial=partial, v3=not args.gather_if_shard) - - print(f"Finished checkpointing after {total_steps} steps: {output_save_file}") - - -def load_model_and_optimizer( - output_dir, - model, - optimizer, - lr_scheduler, - partial, - args, - translate_from_hf=False, - seq_length=1024, - load_model=True, - load_optimizer=True, - num_params=0, -): - # Find longest-trained checkpoint - re_pattern = f"trained_gpt_nparams-{num_params}_steps-(?P\d+)\.pt" - if partial: - re_pattern += "_(?P\d+)" - else: - re_pattern += "$" - - ckpt_paths = sorted( - [ - (int(re.match(re_pattern, p).group("total_steps")), os.path.join(output_dir, p)) - for p in os.listdir(output_dir) - if re.match(re_pattern, p) - ], - reverse=True, - ) - if not ckpt_paths: - raise Exception( - f'No checkpoints could be found in "{output_dir}". Candidates: {os.listdir(output_dir)}' - ) - - local_ckpt_path = ckpt_paths[0][1] - - if partial: - # need to pass prefix without ranks to smp - local_ckpt_path = local_ckpt_path.split(".pt")[0] + ".pt" - - if args.gather_if_shard > 0: - # Should expect v2 checkpoint here - checkpoint = smp.load(local_ckpt_path, partial=partial) - else: - # Loading separately for model and opt - checkpoint = torch.load(f"{local_ckpt_path}_{smp.pp_rank()}_{smp.tp_rank()}_0") - if smp.rdp_rank() != 0: - opt_checkpoint = torch.load( - f"{local_ckpt_path}_{smp.pp_rank()}_{smp.tp_rank()}_{smp.rdp_rank()}" - ) - - if load_model: - checkpointed_model = ( - translate_hf_gptj_state_dict_to_smdistributed(checkpoint["model"], seq_length) - if translate_from_hf - else checkpoint["model"] - ) - model.load_state_dict(checkpointed_model, same_partition_load=args.same_partition_load > 0) - if lr_scheduler is not None: - lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) - - if load_optimizer: - checkpoint = ( - checkpoint if args.gather_if_shard > 0 or smp.rdp_rank() == 0 else opt_checkpoint - ) - # Loading loss scale eagerly - if not args.megatron: - opt_state_dict = checkpoint["optimizer"] - optimizer.loss_scaler = opt_state_dict["loss_scaler"] - optimizer.loss_scaler.model = model - optimizer.dynamic_loss_scale = opt_state_dict["dynamic_loss_scale"] - optimizer.overflow = opt_state_dict["overflow"] - optimizer.first_closure_call_this_step = opt_state_dict["first_closure_call_this_step"] - - def opt_load_hook(mod, opt): - load_fn = load_fp16_optimizer_megatron if args.megatron else load_fp16_optimizer - if args.fp16: - if not partial and args.skip_full_optimizer: - print( - "Skipping loading the final optimizer state, and reloading master_params from model_params" - ) - opt.reload_model_params() - else: - load_fn(args, mod, opt, checkpoint, partial=partial) - else: - # fp32 - if not partial and args.skip_full_optimizer: - print("Skipping loading the final optimizer state") - else: - opt.load_state_dict(checkpoint["optimizer"]) - - model.register_post_step_hook(opt_load_hook) - - print(f'Loaded model from "{local_ckpt_path}"') - - batch_idx = 0 - if "batch_idx" in checkpoint: - batch_idx = checkpoint["batch_idx"] - - return ( - model, - optimizer, - checkpoint["total_steps"], - checkpoint["curr_train_path_index"], - batch_idx, - ) - - -def delete_oldest_ckpt(args, delete_on_rank0_only=False): - to_delete = smp.rank() == 0 if delete_on_rank0_only else smp.local_rank() == 0 - if to_delete: - re_pattern = "trained_gpt_nparams-(?P\d+)_steps-(?P\d+)\.pt" - - # partial - re_pattern += "_(?P\d+)_(?P\d+)" - - paths_per_step = collections.defaultdict(list) - - for p in os.listdir(args.checkpoint_dir): - if re.match(re_pattern, p): - step = int(re.match(re_pattern, p).group("total_steps")) - path = os.path.join(args.checkpoint_dir, p) - paths_per_step[step].append(path) - - if paths_per_step: - oldest_step = sorted(paths_per_step.keys())[0] - num_parts = len(paths_per_step[oldest_step]) - if len(paths_per_step) >= args.num_kept_checkpoints: - # delete oldest step to save the new one - for p in paths_per_step[oldest_step]: - os.remove(p) - # else We still haven't reached maximum number of checkpoints -- no need to delete older ones - return None - - def eval_model(model, dataloader, num_batches, use_bert_data): model = model.eval() n_batches = 0 @@ -445,6 +155,8 @@ def train( total_steps, args, ): + if args.enable_memory_profiling > 0: + memory_status_cpu(msg="before train step") model.train() if args.parallel_proc_data_processing: pool = ProcessPoolExecutor(1) @@ -532,6 +244,9 @@ def train( to_save = {"loss": [], "val_loss": []} loss_metric = 0 + def grad_accumulation_boundary(batch_idx): + return batch_idx % args.gradient_accumulation == args.gradient_accumulation - 1 + def should_record(): # only record the ranks that in the tp group that contains global rank 0 if smp.tp_size() > 1: @@ -595,10 +310,8 @@ def should_record(): step_start = time.time() - if args.fp16: - optimizer.zero_grad(set_grads_to_None=True) - else: - optimizer.zero_grad() + if grad_accumulation_boundary(batch_idx - 1): + optimizer.zero_grad(set_to_none=True) if args.logits_output: train_output = train_step(model, optimizer, input_ids, attention_mask, args) @@ -616,40 +329,39 @@ def should_record(): loss = loss_mb.reduce_mean() if not args.validation_freq: loss_metric = loss.item() - + if args.enable_memory_profiling > 0: + memory_status_cpu("After_train_step_cpu") memory_status(msg="After_train_step") if args.clean_cache > 0: # empty the cache to avoid OOM torch.cuda.empty_cache() - if args.fp16: - if args.megatron: - success, _, _ = optimizer.step() - overflow = not success - else: - optimizer.update_master_grads() + + if grad_accumulation_boundary(batch_idx): + if args.fp16: optimizer.clip_master_grads(args.grad_clip) - optimizer.step() - overflow = optimizer.overflow - else: + optimizer.step() + if not (args.fp16 and optimizer.overflow): + lr_scheduler.step() - if not (args.fp16 and overflow): - lr_scheduler.step() - - if args.enable_memory_profiling > 0: - memory_status(msg="After_opt_step") + if args.enable_memory_profiling > 0: + memory_status(msg="After_opt_step") total_steps += 1 time_elapsed = time.time() - start step_time = time.time() - step_start sample_processed = input_ids.shape[0] * dp_size throughput = sample_processed / step_time + tokens_per_gpu = input_ids.shape[0] * input_ids.shape[1] + + # Based on the formula in https://developer.nvidia.com/blog/scaling-language-model-training-to-a-trillion-parameters-using-megatron/ + tflops_per_gpu = 8 * num_params * tokens_per_gpu / step_time / 1e12 if smp.rank() == 0 and not total_steps % args.logging_freq: print( - f"({int(time_elapsed)}s), Batch {total_steps - 1} Loss: {loss.item()}, Speed: {throughput} samples/sec" + f"({int(time_elapsed)}s), Batch {total_steps - 1} Loss: {loss.item()}, Speed: {throughput} samples/sec, TFLOPS/GPU: {tflops_per_gpu}" ) # evaluate on validation @@ -675,37 +387,26 @@ def should_record(): # checkpoint if not (total_steps % args.checkpoint_freq): - base_path = f"trained_gpt_nparams-{num_params}_steps-{total_steps}.pt" - out_path = os.path.join(args.checkpoint_dir, base_path) - total_ckpts = total_steps // args.checkpoint_freq - - delete_oldest_ckpt(args, delete_on_rank0_only=args.use_fsx > 0) - - # save_or_verify_ckptsum if this is the last checkpoint - if (args.save_or_verify_ckptsum and total_steps >= args.max_steps) or ( - (total_ckpts + 1) * args.checkpoint_freq - ) > args.max_steps: - # Save optimizer and model tensor sums and scalars before saving - save_ckptsum( - args, - model, - optimizer, - filename=os.path.join(args.model_dir, "saved_partial_sum"), - ) - - save( - out_path, - model, - optimizer, - lr_scheduler, - model_config, - num_params, - total_steps, - curr_train_path_index, - args, + user_content = { + "cli_args": args.__dict__, + "num_params": num_params, + "total_steps": total_steps, + "start_train_path_index": curr_train_path_index, + "model_config": model_config, + "start_batch_index": batch_idx+1, + } + # to reconstruct the full model + if args.sharded_data_parallel_degree > 1: + user_content["buffer_names"] = get_buffer_names(model) + user_content["param_shapes"] = get_param_shapes(model, optimizer) + user_content["lr_scheduler"] = lr_scheduler.state_dict() + smp.save_checkpoint(args.checkpoint_dir, + tag=f"total_steps{total_steps}", partial=True, - batch_idx=batch_idx + 1, - ) + model=model, + optimizer=optimizer, + user_content=user_content, + num_kept_partial_checkpoints=args.num_kept_checkpoints) if args.logits_output: to_save["loss"].append(loss.item()) @@ -766,10 +467,8 @@ def parse_args(): opt_grp.add_argument("--same_seed", type=int, default=0) opt_grp.add_argument("--n_gpus", type=str, default=os.environ["SM_NUM_GPUS"]) opt_grp.add_argument("--fp16", default=0, type=int, help="automatic mixed precision training") - opt_grp.add_argument( - "--fp32_grad_accumulation", default=0, type=int, help="Enable FP32 Grad accumulation" - ) - opt_grp.add_argument("--megatron", default=0, type=int, help="use megatron fp16 optimizer") + opt_grp.add_argument("--bf16", default=0, type=int, help="automatic mixed precision training") + opt_grp.add_argument("--sharded_data_parallel_degree", default=1, type=int) opt_grp.add_argument("--grad_clip", default=1.0, type=float, help="gradient clipping") opt_grp.add_argument("--weight_decay", default=0.01, type=float, help="weight decay") opt_grp.add_argument( @@ -790,11 +489,10 @@ def parse_args(): # I/O io_grp = parser.add_argument_group(title="io", description="location for input and output") - io_grp.add_argument("--use_bert_data", type=int, default=0, help="use bert data for training") - # change to 0 original 1 - io_grp.add_argument("--zipped_data", type=int, default=0, help="input data is zipped files") + io_grp.add_argument("--use_bert_data", type=int, default=0, help="use wiki corpus data for training") + io_grp.add_argument("--zipped_data", type=int, default=1, help="input data is zipped files") io_grp.add_argument( - "--epochs", type=int, default=30, help="times of iterating over the training dataset" + "--epochs", type=int, default=3, help="times of iterating over the training dataset" ) io_grp.add_argument("--output-data-dir", type=str, default=os.environ["SM_OUTPUT_DATA_DIR"]) io_grp.add_argument( @@ -823,12 +521,6 @@ def parse_args(): default=0, help="Enabling this will save a combined model only at the end", ) - io_grp.add_argument( - "--skip_full_optimizer", - type=int, - default=1, - help="Disabling this will also save the full optimizer state", - ) io_grp.add_argument("--load_partial", type=int, default=0, help="Load from partial checkpoints") io_grp.add_argument("--load_full", type=int, default=0, help="Load from full checkpoints") io_grp.add_argument( @@ -850,9 +542,11 @@ def parse_args(): model_grp.add_argument("--attn_pdrop", type=float, default=0.1) model_grp.add_argument("--summary_first_pdrop", type=float, default=0.1) model_grp.add_argument("--use_adamw", type=int, default=0, help="Use adamw optimizer") + model_grp.add_argument("--use_distributed_transformer", type=int, default=1, help="Use distributed transformer") + model_grp.add_argument("--checkpoint_sublayers", type=int, default=0, help="Apply activation checkpointing to submodules of each transformer layer") smp_grp = parser.add_argument_group(title="smp", description="smp") - smp_grp.add_argument("--tensor_parallel_degree", type=int, default=8) + smp_grp.add_argument("--tensor_parallel_degree", type=int, default=1) smp_grp.add_argument("--pipeline_parallel_degree", type=int, default=1) smp_grp.add_argument("--microbatches", type=int, default=1) smp_grp.add_argument("--active_microbatches", type=int, default=None) @@ -870,7 +564,9 @@ def parse_args(): smp_grp.add_argument("--skip_tracing", type=int, default=0) smp_grp.add_argument("--query_key_layer_scaling", type=int, default=1) smp_grp.add_argument("--fused_softmax", type=int, default=1) + smp_grp.add_argument("--fused_dropout", type=int, default=0) smp_grp.add_argument("--fused_bias_gelu", type=int, default=1) + smp_grp.add_argument("--gradient_accumulation", type=int, default=1) parser.add_argument( "--num_kept_checkpoints", @@ -908,9 +604,6 @@ def parse_args(): default="", help="number of transformer layers assigned to each partition", ) - parser.add_argument( - "--match_weights", type=int, default=0, help="Get weights from the original model" - ) parser.add_argument( "--preserve_np_state", type=int, @@ -983,22 +676,25 @@ def parse_args(): ci_grp.add_argument("--time_to_train", type=int, help="time to train threshold") ci_grp.add_argument("--throughput", type=float, help="throughput threshold") ci_grp.add_argument("--loss", type=float, help="loss threshold") - ci_grp.add_argument( - "--save_or_verify_ckptsum", default=False, action="store_true", help="Whether to save sum" - ) - args, _ = parser.parse_known_args() return args +def compute_num_params(model): + num_params = 0 + seen = set() + for p in model.parameters(): + if p not in seen: + seen.add(p) + if hasattr(p, "ds_shape"): + num_params += np.prod(p.ds_shape) + else: + num_params += np.prod(p.size()) + + return num_params def main(): args = parse_args() - if args.shard_optimizer_state > 0 and not args.skip_full_optimizer: - raise ValueError( - "If shard_optimizer_state is enabled, skip_full_optimizer must also be enabled. Full optimizer saving is currently not supported under optimizer state sharding." - ) - if args.partition_assignment != "" and args.manual_partition == 0: print("[Warning] partition_assignment is set, enable manual_partition") args.manual_partition = 1 @@ -1009,22 +705,21 @@ def main(): "tensor_parallel_degree": args.tensor_parallel_degree, "pipeline_parallel_degree": args.pipeline_parallel_degree, "microbatches": args.microbatches, - # if activation_checkpointing true checkpoints transformer layers below - "checkpoint_attentions": False if args.activation_checkpointing else True, "shard_optimizer_state": args.shard_optimizer_state > 0, "prescaled_batch": args.prescaled_batch > 0, - "_match_weights": args.match_weights > 0, - "fp16_params": args.fp16 > 0, + "fp16": args.fp16 > 0, + "bf16": args.bf16 > 0, "offload_activations": args.offload_activations > 0, + "delayed_parameter_initialization": args.delayed_param > 0, "optimize": args.optimize, "placement_strategy": args.placement_strategy, "activation_loading_horizon": args.activation_loading_horizon, "skip_tracing": args.skip_tracing > 0, "auto_partition": False if args.manual_partition else True, "default_partition": 0, - "_fp32_grad_accumulation": args.fp32_grad_accumulation > 0, "static_mode": args.static_mode > 0, "fast_mode": args.fast_mode > 0, + "sharded_data_parallel_degree": args.sharded_data_parallel_degree, } if args.active_microbatches is not None: smp_config["active_microbatches"] = args.active_microbatches @@ -1048,11 +743,6 @@ def main(): len(partition_assignment) == smp.pp_size() ), f"partition_assignment must have the same size as pipeline parallel degree, but getting {len(partition_assignment)} vs {smp.pp_size()}" - if smp.rank() == 0 or (smp.local_rank() == 0 and args.use_fsx == 0): - for path in [args.model_dir, args.checkpoint_dir]: - if not os.path.exists(path): - os.makedirs(path, exist_ok=True) - model_config = GPTJConfig( vocab_size=args.vocab_size, n_positions=args.max_context_width, @@ -1071,43 +761,49 @@ def main(): summary_activation=None, summary_proj_to_labels=True, summary_first_dropout=args.summary_first_pdrop, - use_cache=True, + # gradient_checkpointing=args.gradient_checkpointing > 0, + use_cache=False, bos_token_id=50256, eos_token_id=50256, return_dict=True, - rotary_dim=64, ) # the following improves start-up time by skipping proper initialization # of weights in the original model. this is not a problem because DistributedModel - # will override those weights anyway when tensor_parallel_degree > 1. - if smp.tp_size() > 1 and args.match_weights < 1: + # will override those weights anyway when we use distributed transformer. + if args.use_distributed_transformer > 0: from transformers.modeling_utils import PreTrainedModel PreTrainedModel.init_weights = lambda x: None set_seed(args.seed) - if args.fp16: - torch.set_default_dtype(torch.float16) - with smp.tensor_parallelism( - enabled=smp.tp_size() > 1, + if args.enable_memory_profiling > 0: + memory_status_cpu(msg="before model creation") + + if args.fp16 and args.bf16: + raise ValueError("FP16 and BF16 cannot be simultaneously enabled.") + elif args.fp16: + dtype = torch.float16 + elif args.bf16: + dtype = torch.bfloat16 + else: + dtype = torch.get_default_dtype() + + with smp.model_creation( + tensor_parallelism=smp.tp_size() > 1 or args.use_distributed_transformer > 0, + dtype=dtype, attention_in_fp32=args.attention_in_fp32 > 0, - query_key_layer_scaling=args.query_key_layer_scaling > 0, + query_key_layer_scaling=args.query_key_layer_scaling > 0 and args.bf16 < 1, fused_softmax=args.fused_softmax > 0, + fused_dropout=args.fused_dropout > 0, fused_bias_gelu=args.fused_bias_gelu > 0, - ): - with smp.delay_param_initialization( - enabled=(smp.tp_size() > 1 and args.match_weights < 1 and args.delayed_param > 0) ): model = AutoModelForCausalLM.from_config(model_config) + if args.enable_memory_profiling > 0: + memory_status_cpu(msg="after model creation") - torch.set_default_dtype(torch.float32) - - if args.fp16: - model = FP16_Module(model) - - num_params = sum([np.prod(p.size()) for p in model.parameters()]) + num_params = compute_num_params(model) if smp.rank() == 0: print(f"# total parameters: {num_params}") @@ -1124,19 +820,17 @@ def main(): # to be partitioned across different ranks. For the rest of the script, # the returned DistributedModel object should be used in place of # the model provided for DistributedModel class instantiation. - if args.fp16: - torch.set_default_dtype(torch.float16) - model = smp.DistributedModel(model, trace_device="gpu") - - if args.fp16: - m = model.module - else: - m = model + if args.enable_memory_profiling > 0: + memory_status_cpu(msg="before dist model creation") + model = smp.DistributedModel(model, trace_device="gpu", backward_passes_per_step=args.gradient_accumulation) + if args.enable_memory_profiling > 0: + memory_status_cpu(msg="after dist model creation") - if smp.tp_size() > 1: - transformer_layers = m.module.module.transformer.seq_layers + m = model.get_module() + if args.use_distributed_transformer > 0: + transformer_layers = m.transformer.seq_layers else: - transformer_layers = m.module.module.transformer.h + transformer_layers = m.transformer.h if args.manual_partition: print(f"Manual partition enabled") @@ -1151,9 +845,6 @@ def main(): div, rem = divmod(args.num_layers, smp.pp_size()) get_num_layers = lambda x: (div + 1 if x >= smp.pp_size() - rem else div) assignments = [] - # (TODO) This is required for 175B otherwise a hang for partition "8,17,17,18,18,18" - # Need further investigation - # for pp_rank in reversed(range(smp.pp_size())): for pp_rank in range(smp.pp_size()): nl = get_num_layers(pp_rank) print(f"{nl} layers assigned to partition {pp_rank}") @@ -1162,13 +853,7 @@ def main(): for i, c in enumerate(transformer_layers.children()): smp.set_partition(c, assignments[i]) - torch.set_default_dtype(torch.float32) - - iter_model = model - # Build parameter groups (weight decay and non-decay). - while isinstance(iter_model, (DistributedDataParallel, FP16_Module)): - iter_model = iter_model.module - param_groups = get_param_groups_by_weight_decay(iter_model) + param_groups = get_param_groups_by_weight_decay(m) if args.use_adamw > 0: optimizer = optim.AdamW( @@ -1180,51 +865,29 @@ def main(): ) if args.activation_checkpointing: - kwargs = {} - if isinstance(transformer_layers, nn.Sequential): - kwargs["pack_args_as_tuple"] = True - kwargs["strategy"] = args.activation_strategy - smp.set_activation_checkpointing(transformer_layers, **kwargs) - - if args.fp16: - if args.megatron: - grad_scaler = DynamicGradScaler( - initial_scale=2**32, - min_scale=1, - growth_interval=1000, - growth_factor=2.0, - backoff_factor=0.5, - hysteresis=2, - ) - optimizer = Float16OptimizerWithFloat16Params( - model, - optimizer, - clip_grad=1.0, - log_num_zeros_in_grad=False, - params_have_main_grad=args.fp32_grad_accumulation > 0, - bf16=False, - grad_scaler=grad_scaler, - use_smp=True, - shard_optimizer_state=args.shard_optimizer_state > 0, - ) + if args.use_distributed_transformer or smp.tp_size() > 1: + if args.checkpoint_sublayers: + for c in transformer_layers.children(): + smp.set_activation_checkpointing(c.attention) + smp.set_activation_checkpointing(c.output) + else: + smp.set_activation_checkpointing(transformer_layers, strategy=args.activation_strategy) else: - optimizer = FP16_Optimizer( - model, - optimizer, - static_loss_scale=None, - dynamic_loss_scale=True, - use_smp=True, - dynamic_loss_args={"scale_window": 1000, "min_scale": 1, "delayed_shift": 2}, - params_have_main_grad=args.fp32_grad_accumulation > 0, - shard_optimizer_state=args.shard_optimizer_state > 0, - ) + for c in transformer_layers.children(): + if args.checkpoint_sublayers: + smp.set_activation_checkpointing(c.attn) + smp.set_activation_checkpointing(c.mlp) + else: + smp.set_activation_checkpointing(c) - optimizer = smp.DistributedOptimizer(optimizer) + optimizer = smp.DistributedOptimizer( + optimizer, + static_loss_scale=None, + dynamic_loss_scale=True, + dynamic_loss_args={"scale_window": 1000, "min_scale": 1, "delayed_shift": 2}, + ) lr_scheduler = get_learning_rate_scheduler(optimizer, args) - if args.fp16: - model.register_post_step_hook(lambda model, optimizer: optimizer.init_master_params()) - if args.enable_memory_profiling > 0: model.register_post_partition_hook( lambda model, optimizer: memory_status(msg="After_partition") @@ -1239,31 +902,13 @@ def main(): ) partial = not args.load_full path = args.checkpoint_dir if partial else args.model_dir - translate_from_hf = not partial - ( - model, - optimizer, - total_steps, - start_train_path_index, - start_batch_index, - ) = load_model_and_optimizer( - path, - model, - optimizer, - lr_scheduler, - partial, - args, - translate_from_hf=translate_from_hf, - seq_length=args.max_context_width, - load_model=True, - load_optimizer=args.load_partial > 0, - num_params=num_params, - ) - if args.save_or_verify_ckptsum: - filename = "saved_sum" if args.load_full else "saved_partial_sum" - load_and_verify_ckptsum( - args, model, optimizer, filename=os.path.join(args.model_dir, filename) - ) + tag = None if partial else "fullmodel.pt" + user_content = smp.resume_from_checkpoint(path, tag=tag, partial=partial) + total_steps = user_content["total_steps"] if partial else 0 + start_train_path_index = user_content.get("start_train_path_index", 0) + start_batch_index = user_content.get("start_batch_index", 0) + if "lr_scheduler" in user_content: + lr_scheduler.load_state_dict(user_content["lr_scheduler"]) else: total_steps = 0 start_train_path_index = 0 @@ -1282,7 +927,6 @@ def main(): args, ) time_to_train = time.time() - start - print("time to train: {}".format(time_to_train)) if args.ci: print(f"[SMP_METRIC]__GPTJ__Time_to_train__{time_to_train}") print(f"[SMP_METRIC]__GPTJ__samples/second__{throughput}") @@ -1295,28 +939,30 @@ def main(): if args.save_final_full_model: # saves full model at the end - - base_path = f"trained_gpt_nparams-{num_params}_steps-{total_steps}.pt" - out_path = os.path.join(args.model_dir, base_path) - if args.save_or_verify_ckptsum: - # Save optimizer and model tensor sums and scalars before saving - save_ckptsum(args, model, optimizer, filename=os.path.join(args.model_dir, "saved_sum")) - - if smp.rdp_rank() == 0: - save( - out_path, - model, - optimizer, - lr_scheduler, - model_config, - num_params, - total_steps, - -1, - args, - partial=False, - translate_to_hf=smp.tp_size() > 1, - seq_length=args.max_context_width, - ) + user_content = { + "cli_args": args.__dict__, + "num_params": num_params, + "total_steps": total_steps, + "model_config": model_config, + } + if args.sharded_data_parallel_degree > 1: + # When sharded_data_parallel_degree > 1, saving full model is not supported, saving partial instead + # To get the full model, one can use the following API + # > from sharded_data_parallel_checkpoint import get_full_state_dict_from_sharded_data_parallel_checkpoint + # > full_model = get_full_state_dict_from_sharded_data_parallel_checkpoint(args.model_dir, tag=f"sharded_data_parallel_final_full_{num_params}", dtype=torch.float32) + # > if args.use_distributed_transformer > 0: # translate the state_dict to hf format if distributed transformer is used + # > full_model = smp.nn.huggingface.gpt2.translate_state_dict_to_hf_gpt2(full_model, max_seq_len=args.max_context_width) + # Note: the shared parameter will not be reflected so during loading you might need to load with strict=False + user_content["buffer_names"] = get_buffer_names(model) + user_content["param_shapes"] = get_param_shapes(model, optimizer) + smp.save_checkpoint(args.model_dir, + tag=f"sharded_data_parallel_final_full_{num_params}", + partial=True, + model=model, + optimizer=optimizer, + user_content=user_content) + else: + smp.save_checkpoint(args.model_dir, tag="fullmodel.pt", partial=False, model=model, user_content=user_content) smp.barrier() if smp.rank() == 0: diff --git a/training/distributed_training/pytorch/model_parallel/gpt2/fp16/__init__.py b/training/distributed_training/pytorch/model_parallel/gpt2/fp16/__init__.py deleted file mode 100644 index 99d09e08bc..0000000000 --- a/training/distributed_training/pytorch/model_parallel/gpt2/fp16/__init__.py +++ /dev/null @@ -1,30 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. -# Modifications Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from .fp16util import ( - BN_convert_float, - network_to_half, - prep_param_lists, - model_grads_to_master_grads, - master_params_to_model_params, - tofp16, - to_python_float, - convert_module, - convert_network, - FP16Model, -) - -from .fp16 import * -from .loss_scaler import * diff --git a/training/distributed_training/pytorch/model_parallel/gpt2/fp16/fp16.py b/training/distributed_training/pytorch/model_parallel/gpt2/fp16/fp16.py deleted file mode 100755 index f5b7c3461f..0000000000 --- a/training/distributed_training/pytorch/model_parallel/gpt2/fp16/fp16.py +++ /dev/null @@ -1,1027 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. -# Modifications Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Stable version of apex FP16 Optimizer""" -import copy - -import amp_C -import smdistributed.modelparallel.torch as smp -import torch -from apex.multi_tensor_apply import multi_tensor_applier -from smdistributed.modelparallel.torch.state_mod import state as smp_state -from smdistributed.modelparallel.torch.utils import get_distribution_axis -from torch import nn -from torch._six import inf -from torch.autograd import Variable -from torch.nn.parameter import Parameter - -from .fp16util import ( - get_pp_merged_fp32_from_fp16_param_groups, - get_tp_merged_fp32_from_fp16_param_groups, - master_params_to_model_params, - model_grads_to_master_grads, - model_params_to_master_params, - register_optimizer_hooks, -) -from .loss_scaler import DynamicLossScaler, LossScaler - -FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor) -HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor) - - -def load_fp16_optimizer_finetuning(model, optimizer, state_dict): - opt_state_dict = state_dict["optimizer"] - - def param_name_to_index(self): - param_id_to_index = self._param_id_to_index() - name_to_index = {} - for name, param in model.named_parameters(): - fp16_param_id = id(param) - if fp16_param_id in self.fp32paramid_from_fp16paramid: - param_id = self.fp32paramid_from_fp16paramid[fp16_param_id] - else: - param_id = fp16_param_id - if param_id in param_id_to_index: - name_to_index[name] = param_id_to_index[param_id] - return name_to_index - - def _param_index_to_param_local(self): - param_id_to_index = self._param_id_to_index() - param_index_to_param = {} - - if not model: - return param_index_to_param - - for param in model.local_parameters(): - fp16_param_id = id(param) - if fp16_param_id in self.fp32paramid_from_fp16paramid: - param_id = self.fp32paramid_from_fp16paramid[fp16_param_id] - else: - param_id = fp16_param_id - if param_id in param_id_to_index: - param_index_to_param[param_id_to_index[param_id]] = param - - return param_index_to_param - - def hook_fn(model, optimizer): - print(f"Inside hook_fn, loading for finetuning") - from functools import partial - - optimizer.param_name_to_index = partial(param_name_to_index, optimizer) - optimizer._param_index_to_param_local = partial(_param_index_to_param_local, optimizer) - optimizer.fp32_from_fp16 = opt_state_dict["fp32_from_fp16"] - - for current_group, saved_group in zip( - optimizer.fp32_from_fp16_groups, optimizer.fp32_from_fp16 - ): - for current, saved in zip(current_group, saved_group): - current.data.copy_(saved.data) - - model.register_post_partition_hook(hook_fn) - - -def _get_param_index_to_id(param_id_to_index_tp_group): - param_index_to_id_tp_group = [] - for param_id_to_index_map in param_id_to_index_tp_group: - param_index_to_id_map = {} - for param_id, param_index in param_id_to_index_map.items(): - param_index_to_id_map[param_index] = param_id - param_index_to_id_tp_group.append(param_index_to_id_map) - return param_index_to_id_tp_group - - -def save_fp16_optimizer(args, model, optimizer, partial=True): - optimizer_state_dict = {} - loss_scaler = optimizer.loss_scaler - _model = loss_scaler.model - loss_scaler.model = None - _loss_scaler = copy.deepcopy(loss_scaler) - loss_scaler.model = _model - optimizer_state_dict["loss_scaler"] = _loss_scaler - optimizer_state_dict["dynamic_loss_scale"] = optimizer.dynamic_loss_scale - optimizer_state_dict["overflow"] = optimizer.overflow - optimizer_state_dict["first_closure_call_this_step"] = optimizer.first_closure_call_this_step - cpu_fp32_from_fp16_groups = [ - [param.cpu() for param in group] for group in optimizer.fp32_from_fp16_groups - ] - if optimizer.master_params_created: - register_optimizer_hooks(model) - if partial: - optimizer_state_dict["optimizer_state_dict"] = optimizer.local_state_dict(gather_if_shard=args.gather_if_shard > 0) - if args.shard_optimizer_state and args.gather_if_shard > 0: - if smp.rdp_rank() == 0: - print("With shard_optimizer_state=True, gather full fp32_from_fp16_groups for the rdp_group on rdp rank 0") - gathered_cpu_fp32_from_fp16_groups = [cpu_fp32_from_fp16_groups] - for src in range(1, smp.rdp_size()): - gathered_cpu_fp32_from_fp16_groups.append(smp.recv_from(src, smp.RankType.RDP_RANK)) - optimizer_state_dict["fp32_from_fp16"] = gathered_cpu_fp32_from_fp16_groups - else: - smp.send(cpu_fp32_from_fp16_groups, 0, smp.RankType.RDP_RANK) - optimizer_state_dict["fp32_from_fp16"] = cpu_fp32_from_fp16_groups - else: - optimizer_state_dict["fp32_from_fp16"] = cpu_fp32_from_fp16_groups - if smp.pp_size() > 1: - print( - "WARNING: Ensure that partition decision doesnt change between runs (you can ensure this by setting use_times=False in smp config)." - "If you want to save and load with partition decision changing between runs, use full save and load instead." - ) - else: - optimizer_state_dict["optimizer_state_dict"] = optimizer.state_dict() - if smp.tp_size() > 1 and not args.shard_optimizer_state: - tp_merged_fp32_from_fp16_groups, param_name_groups = get_tp_merged_fp32_from_fp16_param_groups( - optimizer, cpu_fp32_from_fp16_groups - ) - pp_merged_fp32_from_fp16_groups, param_name_groups = get_pp_merged_fp32_from_fp16_param_groups( - optimizer, tp_merged_fp32_from_fp16_groups, param_name_groups - ) - else: - raise ValueError( - "Loading full optimizer state is not supported, when TP is not enabled or shard_optimizer_state is enabled" - ) - optimizer_state_dict["fp32_from_fp16"] = pp_merged_fp32_from_fp16_groups - optimizer_state_dict["param_name_groups"] = param_name_groups - return optimizer_state_dict - - -def load_fp16_optimizer(args, model, optimizer, state_dict, partial=True): - opt_state_dict = state_dict["optimizer"] - - if optimizer.master_params_created: - register_optimizer_hooks(model) - - def hook_fn(model, optimizer): - optimizer.load_state_dict(opt_state_dict["optimizer_state_dict"]) - if partial: - if args.shard_optimizer_state and args.gather_if_shard > 0: - optimizer.fp32_from_fp16 = opt_state_dict["fp32_from_fp16"][smp.rdp_rank()] - else: - optimizer.fp32_from_fp16 = opt_state_dict["fp32_from_fp16"] - - for current_group, saved_group in zip( - optimizer.fp32_from_fp16_groups, optimizer.fp32_from_fp16 - ): - for current, saved in zip(current_group, saved_group): - current.data.copy_(saved.data) - - else: - optimizer.fp32_from_fp16 = opt_state_dict["fp32_from_fp16"] - param_name_groups = opt_state_dict["param_name_groups"] - param_id_to_index = optimizer._param_id_to_index() - param_index_to_name_tp_group = smp_state.param_index_to_name_tp_group - param_index_to_name = param_index_to_name_tp_group[smp.tp_rank()] - for group_idx, (current_group, saved_group) in enumerate( - zip(optimizer.fp32_from_fp16_groups, optimizer.fp32_from_fp16) - ): - for current in current_group: - param_id = id(current) - param_index = param_id_to_index[param_id] - param_name = param_index_to_name[param_index] - arr_index = param_name_groups[group_idx][param_name] - saved = saved_group[arr_index] - if optimizer.master_distribution_axis[param_id] is not None: - axis = optimizer.master_distribution_axis[param_id] - slice_size = saved.size(axis) // smp.tp_size() - saved = torch.narrow( - saved.data, axis, slice_size * smp.tp_rank(), slice_size - ).contiguous() - else: - saved = saved.data - current.data.copy_(saved) - - model.register_post_partition_hook(hook_fn) - - -def clip_grad_norm_fp32(parameters, param_is_distributed, shard_optimizer_state, max_norm, norm_type=2): - """Clips gradient norm of an iterable of parameters whose gradients - are in fp32. - This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and - added functionality to handle model parallel parameters. Note that - the gradients are modified in place. - Arguments: - parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a - single Tensor that will have gradients normalized - max_norm (float or int): max norm of the gradients - norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for - infinity norm. - Returns: - Total norm of the parameters (viewed as a single vector). - """ - - if isinstance(parameters, torch.Tensor): - parameters = [parameters] - - # Filter parameters based on: - # - grad should not be none - # - parameter should not be shared - # - should not be a replica due to tensor model parallelism - torch.cuda.set_device(smp.local_rank()) - grads = [] - grads_for_norm = [] - for param in parameters: - grad_not_none = param.grad is not None - is_not_shared = not hasattr(param, "shared") or not param.shared - is_not_tp_duplicate = smp.tp_rank() == 0 or ( - param in param_is_distributed and param_is_distributed[param] - ) - if grad_not_none: - grad = param.grad.detach() - # Make sure the grads are in fp32 - assert param.grad.type() == "torch.cuda.FloatTensor" - grads.append(grad) - if is_not_shared and is_not_tp_duplicate: - grads_for_norm.append(grad) - - # Norm parameters. - max_norm = float(max_norm) - norm_type = float(norm_type) - total_norm = torch.tensor(0.0, device=torch.device("cuda")) - - # Calculate norm. - if norm_type == inf: - if len(grads_for_norm) > 0: - total_norm = max(grad.abs().max() for grad in grads_for_norm) - total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) - # Take max across all model-parallel GPUs. - # Reducing across all ranks since gradients may be different across data parallel ranks - # when optimizer state sharding is enabled. - group = smp.get_world_process_group() if shard_optimizer_state else smp.get_mp_process_group() - torch.distributed.all_reduce( - total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=group - ) - total_norm = total_norm_cuda[0].item() - - else: - if norm_type == 2.0: - dummy_overflow_buf = torch.cuda.IntTensor( - [0], device=torch.device("cuda", smp.local_rank()) - ) - # Use apex's multi-tensor applier for efficiency reasons. - # Multi-tensor applier takes a function and a list of list - # and performs the operation on that list all in one kernel. - if len(grads_for_norm) > 0: - grad_norm, _ = multi_tensor_applier( - amp_C.multi_tensor_l2norm, - dummy_overflow_buf, - [grads_for_norm], - False, # no per-parameter norm - ) - # Since we will be summing across data parallel groups, - # we need the pow(norm-type). - total_norm = grad_norm ** norm_type - - else: - for grad in grads_for_norm: - grad_norm = torch.norm(grad, norm_type) - total_norm += grad_norm ** norm_type - - # Sum across all model-parallel GPUs. - group = smp.get_world_process_group() if shard_optimizer_state else smp.get_mp_process_group() - torch.distributed.all_reduce( - total_norm, op=torch.distributed.ReduceOp.SUM, group=group - ) - total_norm = total_norm.item() ** (1.0 / norm_type) - - # Scale. - if len(grads) > 0: - clip_coeff = max_norm / (total_norm + 1.0e-6) - if clip_coeff < 1.0: - dummy_overflow_buf = torch.cuda.IntTensor( - [0], device=torch.device("cuda", smp.local_rank()) - ) - multi_tensor_applier( - amp_C.multi_tensor_scale, dummy_overflow_buf, [grads, grads], clip_coeff - ) - - return total_norm - - -def conversion_helper(val, conversion): - """Apply conversion to val. Recursively apply conversion if `val` is a nested tuple/list structure.""" - if not isinstance(val, (tuple, list)): - return conversion(val) - rtn = [conversion_helper(v, conversion) for v in val] - if isinstance(val, tuple): - rtn = tuple(rtn) - return rtn - - -def fp32_to_fp16(val): - """Convert fp32 `val` to fp16""" - - def half_conversion(val): - val_typecheck = val - if isinstance(val_typecheck, (Parameter, Variable)): - val_typecheck = val.data - if isinstance(val_typecheck, FLOAT_TYPES): - val = val.half() - return val - - return conversion_helper(val, half_conversion) - - -def fp16_to_fp32(val): - """Convert fp16 `val` to fp32""" - - def float_conversion(val): - val_typecheck = val - if isinstance(val_typecheck, (Parameter, Variable)): - val_typecheck = val.data - if isinstance(val_typecheck, HALF_TYPES): - val = val.float() - return val - - return conversion_helper(val, float_conversion) - - -class FP16_Module(nn.Module): - def __init__(self, module): - super(FP16_Module, self).__init__() - self.add_module("module", module.half()) - - def forward(self, *inputs, **kwargs): - return fp16_to_fp32(self.module(*(fp32_to_fp16(inputs)), **kwargs)) - - def state_dict(self, destination=None, prefix="", keep_vars=False): - return self.module.state_dict(destination, prefix, keep_vars) - - def state_dict_for_save_checkpoint(self, destination=None, prefix="", keep_vars=False): - return self.module.state_dict_for_save_checkpoint(destination, prefix, keep_vars) - - def load_state_dict(self, state_dict, strict=True): - self.module.load_state_dict(state_dict, strict=strict) - - -class FP16_Optimizer(object): - """ - :class:`FP16_Optimizer` is designed to wrap an existing PyTorch optimizer, - and manage static or dynamic loss scaling and master weights in a manner transparent to the user. - For standard use, only two lines must be changed: creating the :class:`FP16_Optimizer` instance, - and changing the call to ``backward``. - - Example:: - - model = torch.nn.Linear(D_in, D_out).cuda().half() - optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) - # Name the FP16_Optimizer instance to replace the existing optimizer - # (recommended but not required): - optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) - ... - # loss.backward() becomes: - optimizer.backward(loss) - ... - - Example with dynamic loss scaling:: - - ... - optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True) - # optional arg to control dynamic loss scaling behavior - # dynamic_loss_args={'scale_window' : 500}) - # Usually, dynamic_loss_args is not necessary. - - Args: - init_optimizer (torch.optim.optimizer): Existing optimizer created with the parameters to optimize. Internally, :class:`FP16_Optimizer` replaces the passed optimizer's fp16 parameters, if any, with fp32 master parameters copied from the original ones. :class:`FP16_Optimizer` also stores references to the original fp16 parameters, and updates these fp16 parameters from the master fp32 copy at the end of each :attr:`step`. - static_loss_scale (float, optional, default=1.0): Loss scale used internally to scale gradients computed by the model. Any fp16 gradients will be copied to fp32, then downscaled before being applied to the fp32 master params, so ``static_loss_scale`` should not affect learning rate. - dynamic_loss_scale (bool, optional, default=False): Use dynamic loss scaling. If True, this will override any ``static_loss_scale`` option. - dynamic_loss_args (dict, optional, default=None): Dict of kwargs that will be forwarded to the internal :class:`DynamicLossScaler` instance's constructor. Keys of this dict must match kwargs accepted by :class:`DynamicLossScaler`'s constructor. If ``dynamic_loss_args`` is unspecified, :class:`DynamicLossScaler`'s defaults will be used. - verbose (bool, optional, default=True): By default, FP16_Optimizer's constructor prints out the parameters and parameter groups it is ingesting, as a sanity check. If this becomes annoying (e.g. for large models), it can be disabled by passing ``verbose=False``. ``verbose=False`` will not disable printing when the loss scale is readjusted during dynamic loss scaling. - - ``init_optimizer`` is expected to have been constructed in the ordinary way. - It is recommended (although not required) that the newly constructed :class:`FP16_Optimizer` instance be - named to replace ``init_optimizer``, for two reasons: - First, it means that references to the same name - later in the file will not have to change. - Second, :class:`FP16_Optimizer` reserves the right (as an implementation detail) to - modify ``init_optimizer``. If you do choose a unique name for the new - :class:`FP16_Optimizer` instance, you should only work with this new instance, - because the preexisting optimizer might no longer behave as expected. - - ``init_optimizer`` may be any Pytorch optimizer. - It may contain a mixture of fp16 and fp32 parameters organized into any number of - ``param_groups`` with different hyperparameters. The :class:`FP16_Optimizer` constructor will - ingest these ``param_groups`` and remember them. - - Calls to :: - - loss.backward() - - must be replaced with :: - - optimizer.backward(loss) - - because :class:`FP16_Optimizer` requires ownership of the backward pass to implement - loss scaling and copies to master gradients. - - .. note:: - Loss scaling, either static or dynamic, is orthogonal to learning rate, because gradients - are downscaled before being applied. This means that adjusting the loss scale, or using - dynamic loss scaling, should not require retuning the learning rate or any other - hyperparameters. - - - **Advanced options** - - **Closures**: :class:`FP16_Optimizer` can wrap a Pytorch optimizer that receives a closure. - See docstring for :attr:`step`. - - **Gradient clipping**: Use :attr:`clip_master_grads`. - - **Multiple losses**: If your model accumulates gradients from multiple losses, - this can be made more efficient by supplying ``update_master_grads=False`` - to :attr:`backward`. See docstring for :attr:`backward`. - - **Manually adjusting loss scale**: The current loss scale can be retrieved or set via :: - - print(optimizer.loss_scale) - optimizer.loss_scale = new_loss_scale - - For static loss scaling, manually adjusting the loss scale over time is a reasonable - thing to do. During later epochs, gradients may become smaller, and a - higher loss scale may be required, analogous to scheduling the learning rate. Dynamic loss - scaling is more subtle (see :class:`DynamicLossScaler`) and in this case, manually adjusting - the loss scale is not recommended. - - **Multi_GPU training**: If the wrapped ``init_optimizer`` was created from a model wrapped in - Pytorch DistributedDataParallel or Apex DistributedDataParallel, :class:`FP16_Optimizer` - should still work as intended. - """ - - def __init__( - self, - model, - init_optimizer, - static_loss_scale=1.0, - dynamic_loss_scale=False, - dynamic_loss_args=None, - use_smp=False, - verbose=False, - params_have_main_grad=False, - shard_optimizer_state=False, - ): - if not torch.cuda.is_available: - raise SystemError("Cannot use fp16 without CUDA.") - - self.verbose = verbose - self.model = model - - self.optimizer = init_optimizer - # init_state_dict sets up an alternative way to cast per-param state tensors. - # Stashing here in case https://github.com/pytorch/pytorch/issues/7733 makes it necessary. - # init_state_dict = init_optimizer.state_dict() - - self.fp16_groups = [] - self.fp32_from_fp16_groups = [] - self.fp32_from_fp32_groups = [] - self.fp32_from_fp16_paramid_groups = [] - self.static_loss_scale = static_loss_scale - self.dynamic_loss_scale = dynamic_loss_scale - self.dynamic_loss_args = dynamic_loss_args - self.use_smp = use_smp - self.master_params_created = False - self.shard_optimizer_state = shard_optimizer_state - self.warned_set_grads_to_none = False - if not self.use_smp: - self.init_master_params() - - self.master_is_distributed = {} - self.master_distribution_axis = {} - self.params_have_main_grad = params_have_main_grad - - if self.dynamic_loss_scale: - if self.dynamic_loss_args is not None: - self.dynamic_loss_args["use_smp"] = self.use_smp - self.loss_scaler = DynamicLossScaler(self.model, self.shard_optimizer_state, **self.dynamic_loss_args) - else: - self.loss_scaler = DynamicLossScaler(self.model, self.shard_optimizer_state, use_smp=self.use_smp) - else: - self.loss_scaler = LossScaler(self.model, self.shard_optimizer_state, self.static_loss_scale, use_smp=self.use_smp) - - - def init_master_params(self): - - if self.use_smp: - torch.cuda.set_device(smp.local_rank()) - register_optimizer_hooks(self.model) - self.fp32paramid_from_fp16paramid = {} - - # only need to create contiguous buffer for fp16 params which require grads - contig_buffer_size = 0 - for param_group in self.optimizer.param_groups: - for param in param_group["params"]: - if param.requires_grad and param.type() == "torch.cuda.HalfTensor": - contig_buffer_size += param.numel() - - self.fp32_param_buffer = torch.empty( - contig_buffer_size, - device=torch.device("cuda", smp.local_rank()), - dtype=torch.float32, - requires_grad=True, - ) - offset = 0 - for i, param_group in enumerate(self.optimizer.param_groups): - self.maybe_print("FP16_Optimizer processing param group {}:".format(i)) - fp16_params_this_group = [] - fp32_params_this_group = [] - fp32_from_fp16_params_this_group = [] - fp32_from_fp16_paramids_this_group = [] - for i, param in enumerate(param_group["params"]): - if param.requires_grad: - if param.type() == "torch.cuda.HalfTensor": - self.maybe_print( - "FP16_Optimizer received torch.cuda.HalfTensor with {}".format( - param.size() - ) - ) - fp16_params_this_group.append(param) - - with torch.no_grad(): - master_param_buffer = self.fp32_param_buffer.narrow( - 0, offset, param.numel() - ).view_as(param) - master_param_buffer.copy_(param.float()) - offset += param.numel() - - master_param = nn.Parameter( - master_param_buffer, requires_grad=param.requires_grad - ) - - self.master_is_distributed[ - master_param - ] = self.model.is_distributed_parameter(param) - self.master_distribution_axis[id(master_param)] = get_distribution_axis( - param - ) - param_group["params"][i] = master_param - fp32_from_fp16_params_this_group.append(master_param) - fp32_from_fp16_paramids_this_group.append(id(master_param)) - # Reset existing state dict key to the new master param. - # We still need to recast per-param state tensors, if any, to FP32. - if param in self.optimizer.state: - self.optimizer.state[master_param] = self.optimizer.state.pop(param) - self.fp32paramid_from_fp16paramid[id(param)] = id(master_param) - elif param.type() == "torch.cuda.FloatTensor": - self.maybe_print( - "FP16_Optimizer received torch.cuda.FloatTensor with {}".format( - param.size() - ) - ) - fp32_params_this_group.append(param) - param_group["params"][i] = param - else: - raise TypeError( - "Wrapped parameters must be either " - "torch.cuda.FloatTensor or torch.cuda.HalfTensor. " - "Received {}".format(param.type()) - ) - self.fp16_groups.append(fp16_params_this_group) - self.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group) - self.fp32_from_fp16_paramid_groups.append(fp32_from_fp16_paramids_this_group) - self.fp32_from_fp32_groups.append(fp32_params_this_group) - - # Leverage state_dict() and load_state_dict() to recast preexisting per-param state tensors - self.optimizer.load_state_dict(self.optimizer.state_dict()) - # alternative way to cast per-param state tensors: - # self.optimizer.load_state_dict(init_state_dict) - - self.overflow = False - self.first_closure_call_this_step = True - self.master_params_created = True - - def maybe_print(self, msg): - if self.verbose: - print(msg) - - def __getstate__(self): - raise RuntimeError("FP16_Optimizer should be serialized using state_dict().") - - def __setstate__(self, state): - raise RuntimeError("FP16_Optimizer should be deserialized using load_state_dict().") - - def zero_grad(self, set_grads_to_None=False): - """ - Zero fp32 and fp16 parameter grads. - """ - # In principle, only the .grad attributes of the model params need to be zeroed, - # because gradients are copied into the FP32 master params. However, we zero - # all gradients owned by the optimizer, just to be safe: - if self.shard_optimizer_state and set_grads_to_None and not self.warned_set_grads_to_none: - print("WARNING: Will not set fp16 gradients to None since shard_optimizer_state is enabled.") - self.warned_set_grads_to_none = True - - for group in self.optimizer.param_groups: - for p in group["params"]: - if set_grads_to_None: - p.grad = None - else: - if p.grad is not None: - if p.grad.grad_fn is not None: - p.grad.detach_() - else: - p.grad.requires_grad_(False) - p.grad.zero_() - - # Zero fp16 gradients owned by the model: - for fp16_group in self.fp16_groups: - for param in fp16_group: - # if shard_optimizer_state is true, do not set fp16 grads to None since - # it will be part of the contiguous buffer - if set_grads_to_None and not self.shard_optimizer_state: - param.grad = None - else: - if param.grad is not None: - if param.grad.grad_fn is not None: - param.grad.detach_() - else: - param.grad.requires_grad_(False) - param.grad.zero_() - - def _check_overflow(self): - params = [] - for group in self.fp16_groups: - for param in group: - params.append(param) - for group in self.fp32_from_fp32_groups: - for param in group: - params.append(param) - self.overflow = self.loss_scaler.has_overflow(params) - - def _update_scale(self, has_overflow=False): - self.loss_scaler.update_scale(has_overflow) - - def _master_params_to_model_params(self): - for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups): - master_params_to_model_params(fp16_group, fp32_from_fp16_group) - - def _model_params_to_master_params(self): - for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups): - model_params_to_master_params(fp16_group, fp32_from_fp16_group) - - # To consider: Integrate distributed with this wrapper by registering a hook on each variable - # that does the overflow check, gradient copy + downscale, and fp32 - # allreduce in a different stream. - def _model_grads_to_master_grads(self, loss_scale=1.0): - for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups): - model_grads_to_master_grads( - fp16_group, - fp32_from_fp16_group, - loss_scale=loss_scale, - params_have_main_grad=self.params_have_main_grad, - ) - - def _downscale_master(self): - if self.loss_scale != 1.0: - for group in self.optimizer.param_groups: - grads = [p.grad for p in group["params"] if p.grad is not None] - _overflow_buf = torch.cuda.IntTensor([0]) - multi_tensor_applier( - amp_C.multi_tensor_scale, _overflow_buf, [grads, grads], 1.0 / self.loss_scale - ) - - def clip_master_grads(self, max_norm, norm_type=2): - """ - Clips fp32 master gradients via ``torch.nn.utils.clip_grad_norm``. - - Args: - max_norm (float or int): max norm of the gradients - norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for - infinity norm. - - Returns: - Total norm of the current fp32 gradients (viewed as a single vector). - - .. warning:: - Returns -1 if the most recently computed fp16 gradients overflowed (that is, if ``self.overflow`` is ``True``). - """ - if not self.overflow: - fp32_params = [] - for param_group in self.optimizer.param_groups: - for param in param_group["params"]: - fp32_params.append(param) - return clip_grad_norm_fp32(fp32_params, self.master_is_distributed, self.shard_optimizer_state, max_norm, norm_type) - else: - return -1 - - def state_dict(self): - """ - Returns a dict containing the current state of this :class:`FP16_Optimizer` instance. - This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict - of the contained Pytorch optimizer. - Example:: - - checkpoint = {} - checkpoint['model'] = model.state_dict() - checkpoint['optimizer'] = optimizer.state_dict() - torch.save(checkpoint, "saved.pth") - """ - if not self.use_smp: - state_dict = {} - state_dict["loss_scaler"] = self.loss_scaler - state_dict["dynamic_loss_scale"] = self.dynamic_loss_scale - state_dict["overflow"] = self.overflow - state_dict["first_closure_call_this_step"] = self.first_closure_call_this_step - state_dict["optimizer_state_dict"] = self.optimizer.state_dict() - state_dict["fp32_from_fp16"] = self.fp32_from_fp16_groups - return state_dict - else: - return self.optimizer.state_dict() - - def load_state_dict(self, state_dict): - """ - Loads a state_dict created by an earlier call to state_dict(). - If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``, - whose parameters in turn came from ``model``, it is expected that the user - will call ``model.load_state_dict()`` before - ``fp16_optimizer_instance.load_state_dict()`` is called. - - Example:: - - model = torch.nn.Linear(D_in, D_out).cuda().half() - optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) - optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) - ... - checkpoint = torch.load("saved.pth") - model.load_state_dict(checkpoint['model']) - optimizer.load_state_dict(checkpoint['optimizer']) - """ - if not self.use_smp: - # I think it should actually be ok to reload the optimizer before the model. - self.loss_scaler = state_dict["loss_scaler"] - self.dynamic_loss_scale = state_dict["dynamic_loss_scale"] - self.overflow = state_dict["overflow"] - self.first_closure_call_this_step = state_dict["first_closure_call_this_step"] - self.optimizer.load_state_dict(state_dict["optimizer_state_dict"]) - # At this point, the optimizer's references to the model's fp32 parameters are up to date. - # The optimizer's hyperparameters and internal buffers are also up to date. - # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still - # out of date. There are two options. - # 1: Refresh the master params from the model's fp16 params. - # This requires less storage but incurs precision loss. - # 2: Save and restore the fp32 master copies separately. - # We choose option 2. - # - # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device - # of their associated parameters, because it's possible those buffers might not exist yet in - # the current optimizer instance. In our case, as long as the current FP16_Optimizer has been - # constructed in the same way as the one whose state_dict we are loading, the same master params - # are guaranteed to exist, so we can just copy_() from the saved master params. - for current_group, saved_group in zip( - self.fp32_from_fp16_groups, state_dict["fp32_from_fp16"] - ): - for current, saved in zip(current_group, saved_group): - current.data.copy_(saved.data) - else: - self.optimizer.load_state_dict(state_dict) - - def reload_model_params(self): - self._model_params_to_master_params() - - def step(self, closure=None): # could add clip option. - """ - If no closure is supplied, :attr:`step` should be called after - ``fp16_optimizer_obj.backward(loss)``. - :attr:`step` updates the fp32 master copy of parameters using the optimizer supplied to - :class:`FP16_Optimizer`'s constructor, then copies the updated fp32 params into the fp16 params - originally referenced by :class:`FP16_Optimizer`'s constructor, so the user may immediately run - another forward pass using their model. - - If a closure is supplied, :attr:`step` may be called without a prior call to - :attr:`backward(loss)`. - This control flow is identical to `ordinary Pytorch optimizer use`_ with closures. - However, the user should take care that any ``loss.backward()`` call within the closure - has been replaced by ``fp16_optimizer_obj.backward(loss)``. - - Args: - closure (optional): Closure that will be supplied to the underlying optimizer originally passed to :class:`FP16_Optimizer`'s constructor. closure should call :attr:`zero_grad()` on the :class:`FP16_Optimizer` object, compute the loss, call :attr:`backward(loss)`, and return the loss. - - Example with closure:: - - # optimizer is assumed to be an FP16_Optimizer object, previously constructed from an - # existing pytorch optimizer. - for input, target in dataset: - def closure(): - optimizer.zero_grad() - output = model(input) - loss = loss_fn(output, target) - # loss.backward() becomes: - optimizer.backward(loss) - return loss - optimizer.step(closure) - - .. warning:: - Currently, calling :attr:`step` with a closure is not compatible with dynamic loss scaling. - - .. _`ordinary Pytorch optimizer use`: - http://pytorch.org/docs/master/optim.html#optimizer-step-closure - """ - - scale = self.loss_scaler.loss_scale - self._update_scale(self.overflow) - - if self.overflow: - self.maybe_print( - "OVERFLOW! Skipping step. Attempted loss scale: {}, reducing to {}".format( - scale, self.loss_scale - ) - ) - return - - if closure is not None: - retval = self._step_with_closure(closure) - else: - retval = self.optimizer.step() - - self._master_params_to_model_params() - - return retval - - def _step_with_closure(self, closure): - def wrapped_closure(): - # helpful for debugging - # print("Calling wrapped_closure, first_closure_call_this_step = {}" - # .format(self.first_closure_call_this_step)) - if self.first_closure_call_this_step: - # We expect that the fp16 params are initially fresh on entering self.step(), - # so _master_params_to_model_params() is unnecessary the first time wrapped_closure() - # is called within self.optimizer.step(). - self.first_closure_call_this_step = False - else: - # If self.optimizer.step() internally calls wrapped_closure more than once, - # it may update the fp32 params after each call. However, self.optimizer - # doesn't know about the fp16 params at all. If the fp32 params get updated, - # we can't rely on self.optimizer to refresh the fp16 params. We need - # to handle that manually: - self._master_params_to_model_params() - # Our API expects the user to give us ownership of the backward() call by - # replacing all calls to loss.backward() with optimizer.backward(loss). - # This requirement holds whether or not the call to backward() is made within a closure. - # If the user is properly calling optimizer.backward(loss) within "closure," - # calling closure() here will give the fp32 master params fresh gradients - # for the optimizer to play with, so all wrapped_closure needs to do is call - # closure() and return the loss. - temp_loss = closure() - while self.overflow: - scale = self.loss_scaler.loss_scale - self._update_scale(self.overflow) - self.maybe_print( - "OVERFLOW within closure! Skipping step. Attempted loss scale: {}, " - "reducing to {}".format(scale, self.loss_scale) - ) - temp_loss = closure() - return temp_loss - - retval = self.optimizer.step(wrapped_closure) - - self.first_closure_call_this_step = True - - return retval - - def backward(self, loss, update_master_grads=True, retain_graph=False): - """ - :attr:`backward` performs the following conceptual steps: - - 1. fp32_loss = loss.float() (see first Note below) - 2. scaled_loss = fp32_loss*loss_scale - 3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's leaves (which may be fp16, fp32, or a mixture, depending how your model was defined). - 4. fp16 grads are then copied to the master params' ``.grad`` attributes (see second Note), which are guaranteed to be fp32. - 5. Finally, master grads are divided by loss_scale. - - In this way, after :attr:`backward`, the master params have fresh gradients, - and :attr:`step` may be called. - - .. note:: - :attr:`backward` internally converts the loss to fp32 before applying the loss scale. - This provides some additional safety against overflow if the user has supplied an - fp16 loss value. - However, for maximum overflow safety, the user should - compute the loss criterion (MSE, cross entropy, etc) in fp32 before supplying it to - :attr:`backward`. - - .. warning:: - The gradients found in a model's leaves after the call to - :attr:`backward` should not be regarded as valid in general, - because it's possible - they have been scaled (and in the case of dynamic loss scaling, - the scale factor may change over time). - If the user wants to inspect gradients after a call to :attr:`backward`, - only the master gradients should be regarded as valid. These can be retrieved via - :attr:`inspect_master_grad_data()`. - - Args: - loss: The loss output by the user's model. loss may be either float or half (but see first Note above). - update_master_grads (bool, optional, default=True): Option to copy fp16 grads to fp32 grads on this call. By setting this to False, the user can delay the copy, which is useful to eliminate redundant fp16->fp32 grad copies if :attr:`backward` is being called on multiple losses in one iteration. If set to False, the user becomes responsible for calling :attr:`update_master_grads` before calling :attr:`step`. - retain_graph (bool, optional, default=False): Forwards the usual ``retain_graph=True`` option to the internal call to ``loss.backward``. If ``retain_graph`` is being used to accumulate gradient values from multiple backward passes before calling ``optimizer.step``, passing ``update_master_grads=False`` is also recommended (see Example below). - - Example:: - - # Ordinary operation: - optimizer.backward(loss) - - # Naive operation with multiple losses (technically valid, but less efficient): - # fp32 grads will be correct after the second call, but - # the first call incurs an unnecessary fp16->fp32 grad copy. - optimizer.backward(loss1) - optimizer.backward(loss2) - - # More efficient way to handle multiple losses: - # The fp16->fp32 grad copy is delayed until fp16 grads from all - # losses have been accumulated. - optimizer.backward(loss1, update_master_grads=False) - optimizer.backward(loss2, update_master_grads=False) - optimizer.update_master_grads() - """ - # To consider: try multiple backward passes using retain_grad=True to find - # a loss scale that works. After you find a loss scale that works, do a final dummy - # backward pass with retain_graph=False to tear down the graph. Doing this would avoid - # discarding the iteration, but probably wouldn't improve overall efficiency. - self.loss_scaler.backward(loss.float(), retain_graph=retain_graph) - if update_master_grads: - self.update_master_grads() - - def update_master_grads(self): - """ - Copy the ``.grad`` attribute from stored references to fp16 parameters to - the ``.grad`` attribute of the fp32 master parameters that are directly - updated by the optimizer. :attr:`update_master_grads` only needs to be called if - ``fp16_optimizer_obj.backward`` was called with ``update_master_grads=False``. - """ - if self.dynamic_loss_scale: - self._check_overflow() - if self.overflow: - return - self._model_grads_to_master_grads(self.loss_scale) - # self._downscale_master() - - def inspect_master_grad_data(self): - """ - When running with :class:`FP16_Optimizer`, - ``.grad`` attributes of a model's fp16 leaves should not be - regarded as truthful, because they might be scaled. - After a call to :attr:`fp16_optimizer_obj.backward(loss)`, if no overflow was encountered, - the fp32 master params' ``.grad`` - attributes will contain valid gradients properly divided by the loss scale. However, - because :class:`FP16_Optimizer` flattens some parameters, accessing them may be - nonintuitive. :attr:`inspect_master_grad_data` - allows those gradients to be viewed with shapes corresponding to their associated model leaves. - - Returns: - List of lists (one list for each parameter group). The list for each parameter group - is a list of the ``.grad.data`` attributes of the fp32 master params belonging to that group. - """ - if self.overflow: - print( - "Warning: calling FP16_Optimizer.inspect_master_grad_data while in an overflow state. " - "Gradients are currently invalid (may be inf, nan, or stale). Returning None." - ) - return None - else: - # The optimizer owns only references to master params. - master_grads_data = [] - for param_group in self.optimizer.param_groups: - master_grads_this_group = [] - for param in param_group["params"]: - if param.grad is not None: - master_grads_this_group.append(param.grad.data) - else: - master_grads_this_group.append(None) - master_grads_data.append(master_grads_this_group) - return master_grads_data - - # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale" - - def _get_loss_scale(self): - return self.loss_scaler.loss_scale - - def _set_loss_scale(self, value): - self.loss_scaler.cur_scale = value - - loss_scale = property(_get_loss_scale, _set_loss_scale) - - # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state" - def _get_state(self): - return self.optimizer.state - - def _set_state(self, value): - self.optimizer.state = value - - state = property(_get_state, _set_state) - - # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups" - # (for example, to adjust the learning rate) - def _get_param_groups(self): - return self.optimizer.param_groups - - def _set_param_groups(self, value): - self.optimizer.param_groups = value - - param_groups = property(_get_param_groups, _set_param_groups) diff --git a/training/distributed_training/pytorch/model_parallel/gpt2/fp16/fp16util.py b/training/distributed_training/pytorch/model_parallel/gpt2/fp16/fp16util.py deleted file mode 100644 index 3b3c91866e..0000000000 --- a/training/distributed_training/pytorch/model_parallel/gpt2/fp16/fp16util.py +++ /dev/null @@ -1,406 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. -# Modifications Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.nn as nn -from torch.autograd import Variable -from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors - -from apex.multi_tensor_apply import multi_tensor_applier -import amp_C -import smdistributed.modelparallel.torch as smp -from smdistributed.modelparallel.torch.state_mod import state as smp_state - - -class tofp16(nn.Module): - """ - Utility module that implements:: - - def forward(self, input): - return input.half() - """ - - def __init__(self): - super(tofp16, self).__init__() - - def forward(self, input): - return input.half() - - -def BN_convert_float(module): - """ - Utility function for network_to_half(). - - Retained for legacy purposes. - """ - if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.affine is True: - module.float() - for child in module.children(): - BN_convert_float(child) - return module - - -def network_to_half(network): - """ - Convert model to half precision in a batchnorm-safe way. - - Retained for legacy purposes. It is recommended to use FP16Model. - """ - return nn.Sequential(tofp16(), BN_convert_float(network.half())) - - -def convert_module(module, dtype): - """ - Converts a module's immediate parameters and buffers to dtype. - """ - for param in module.parameters(recurse=False): - if param is not None: - if param.data.dtype.is_floating_point: - param.data = param.data.to(dtype=dtype) - if param._grad is not None and param._grad.data.dtype.is_floating_point: - param._grad.data = param._grad.data.to(dtype=dtype) - - for buf in module.buffers(recurse=False): - if buf is not None and buf.data.dtype.is_floating_point: - buf.data = buf.data.to(dtype=dtype) - - -def convert_network(network, dtype): - """ - Converts a network's parameters and buffers to dtype. - """ - for module in network.modules(): - if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.affine is True: - continue - convert_module(module, dtype) - return network - - -class FP16Model(nn.Module): - """ - Convert model to half precision in a batchnorm-safe way. - """ - - def __init__(self, network): - super(FP16Model, self).__init__() - self.network = convert_network(network, dtype=torch.half) - - def forward(self, *inputs): - inputs = tuple(t.half() for t in inputs) - return self.network(*inputs) - - -def backwards_debug_hook(grad): - raise RuntimeError("master_params recieved a gradient in the backward pass!") - - -def prep_param_lists(model, flat_master=False): - """ - Creates a list of FP32 master parameters for a given model, as in - `Training Neural Networks with Mixed Precision: Real Examples`_. - - Args: - model (torch.nn.Module): Existing Pytorch model - flat_master (bool, optional, default=False): Flatten the master parameters into a single tensor, as a performance optimization. - Returns: - A tuple (``model_params``, ``master_params``). ``model_params`` is a list of the model's parameters for later use with :func:`model_grads_to_master_grads` and :func:`master_params_to_model_params`. ``master_params`` is a list of FP32 master gradients. If ``flat_master=True``, ``master_params`` will be a list with one element. - - Example:: - - model_params, master_params = prep_param_lists(model) - - .. warning:: - Currently, if ``flat_master=True``, all the model's parameters must be the same type. If the model has parameters of different types, use ``flat_master=False``, or use :class:`FP16_Optimizer`. - - .. _`Training Neural Networks with Mixed Precision: Real Examples`: - http://on-demand.gputechconf.com/gtc/2018/video/S81012/ - """ - model_params = [param for param in model.parameters() if param.requires_grad] - - if flat_master: - # Give the user some more useful error messages - try: - # flatten_dense_tensors returns a contiguous flat array. - # http://pytorch.org/docs/master/_modules/torch/_utils.html - master_params = _flatten_dense_tensors([param.data for param in model_params]).float() - except BaseException: - print("Error in prep_param_lists: model may contain a mixture of parameters " - "of different types. Use flat_master=False, or use F16_Optimizer.") - raise - master_params = torch.nn.Parameter(master_params) - master_params.requires_grad = True - # master_params.register_hook(backwards_debug_hook) - if master_params.grad is None: - master_params.grad = master_params.new(*master_params.size()) - return model_params, [master_params] - else: - master_params = [param.clone().float().detach() for param in model_params] - for param in master_params: - param.requires_grad = True - return model_params, master_params - - -def model_grads_to_master_grads(model_params, master_params, flat_master=False, loss_scale=1.0, params_have_main_grad=False): - """ - Copy model gradients to master gradients. - - Args: - model_params: List of model parameters created by :func:`prep_param_lists`. - master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`model_grads_to_master_grads`. - """ - if flat_master: - # The flattening may incur one more deep copy than is necessary. - master_params[0].grad.data.copy_( - _flatten_dense_tensors([p.grad.data for p in model_params])) - else: - for model, master in zip(model_params, master_params): - if model.device.type == "cpu": - continue - if model.grad is not None: - if master.grad is None: - if params_have_main_grad: - # If gradient_as_bucket_view is False, this will be a copy - master.grad = model.grad.float() - else: - master.grad = Variable(master.data.new(*master.data.size())) - else: - master.grad = None - model_grads = [p.grad for p in model_params if p.grad is not None] - master_grads = [p.grad for p in master_params if p.grad is not None] - if len(model_grads) == 0 or len(master_grads) == 0: - return - _overflow_buf = torch.cuda.IntTensor([0]) - multi_tensor_applier(amp_C.multi_tensor_scale, - _overflow_buf, - [model_grads, master_grads], - 1.0/loss_scale) - - -def master_params_to_model_params(model_params, master_params, flat_master=False): - """ - Copy master parameters to model parameters. - - Args: - model_params: List of model parameters created by :func:`prep_param_lists`. - master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`master_params_to_model_params`. - """ - if flat_master: - for model, master in zip(model_params, - _unflatten_dense_tensors(master_params[0].data, model_params)): - model.data.copy_(master) - else: - for model, master in zip(model_params, master_params): - if model.device.type == "cpu": - continue - model.data.copy_(master.data) - -def model_params_to_master_params(model_params, master_params, flat_master=False): - """ - Copy model params to master params - """ - if flat_master: - raise ValueError("Not supported") - else: - for model, master in zip(model_params, master_params): - if model.device.type == "cpu": - continue - master.data.copy_(model.data) - - -# Backward compatibility fixes - - -def to_python_float(t): - if hasattr(t, 'item'): - return t.item() - else: - return t[0] - - -TORCH_MAJOR = int(torch.__version__.split('.')[0]) -TORCH_MINOR = int(torch.__version__.split('.')[1]) - - -def get_tp_merged_fp32_from_fp16_param_groups(optimizer, cpu_fp32_from_fp16_groups): - def _merge_param_group_tp_group(group_idx, param_group): - result_fp32_from_fp16_param_group = [] - param_name_group = {} - for i, param in enumerate(param_group): - # for each param, obtain param_name from param using two dicts above for tp_rank 0 - param_index = param_id_to_index_tp_group[rank_0][ - fp32_from_fp16_paramid_groups_tp_group[rank_0][group_idx][i] - ] - param_name = param_index_to_name_tp_group[rank_0][param_index] - # obtain distribution axis for the param and check if its distributed - # axis = master_distribution_axis_tp_rank_0[fp32_from_fp16_paramid_groups_tp_group[rank_0][group_idx][i]] - axis = master_distribution_axis_tp_rank_0.get( - fp32_from_fp16_paramid_groups_tp_group[rank_0][group_idx][i], None - ) - if axis is not None: - tensors = [] - for r in range(smp.tp_size()): - # if distributed, for each rank, obtain param id from index using above two dicts - param_index_r = param_name_to_index_tp_group[r][param_name] - param_id_r = param_index_to_id_tp_group[r][param_index_r] - - # search param id in fp32_from_fp16_groups_param_ids and find the index. - group_param_idx = fp32_from_fp16_paramid_groups_tp_group[r][group_idx].index( - param_id_r - ) - # use the param corresponding to the index from fp32_from_fp16_groups for concatenation along axis - tensors.append( - fp32_from_fp16_param_groups_tp_group[r][group_idx][group_param_idx] - ) - result_fp32_from_fp16_param_group.append(torch.cat(tensors, axis)) - else: - # if not distributed set tp_rank 0 param as the param - result_fp32_from_fp16_param_group.append(param) - param_name_group[param_name] = i - return result_fp32_from_fp16_param_group, param_name_group - - # get param_index_to_name all and param_name_to_index_all - param_index_to_name_tp_group = smp_state.param_index_to_name_tp_group - param_name_to_index_tp_group = smp_state.param_name_to_index_tp_group - # get mapping of param_id_to_index_all and param_index_to_id_all - param_id_to_index = optimizer._param_id_to_index() - param_id_to_index_tp_group = smp.allgather(param_id_to_index, smp.TP_GROUP) - param_index_to_id_tp_group = _get_param_index_to_id(param_id_to_index_tp_group) - # allgather all param ids and all params for fp32_from_fp16_groups - fp32_from_fp16_paramid_groups = optimizer.fp32_from_fp16_paramid_groups - fp32_from_fp16_paramid_groups_tp_group = smp.allgather( - fp32_from_fp16_paramid_groups, smp.TP_GROUP - ) - fp32_from_fp16_param_groups_tp_group = smp.allgather(cpu_fp32_from_fp16_groups, smp.TP_GROUP) - # broadcast distribution axis from tp_rank 0 to all tp_ranks - master_distribution_axis_tp_rank_0 = None - if smp.tp_rank() == 0: - master_distribution_axis_tp_rank_0 = optimizer.master_distribution_axis - smp.broadcast(master_distribution_axis_tp_rank_0, smp.TP_GROUP) - else: - master_distribution_axis_tp_rank_0 = smp.recv_from(0, smp.RankType.TP_RANK) - - result_fp32_from_fp16_param_groups = [] - param_name_groups = [] - rank_0 = 0 - # iterate through all the params for tp_group_fp32_from_fp16_groups[rank_0] - for group_idx, param_group in enumerate(fp32_from_fp16_param_groups_tp_group[rank_0]): - result_fp32_from_fp16_param_group, param_name_group = _merge_param_group_tp_group( - group_idx, param_group - ) - result_fp32_from_fp16_param_groups.append(result_fp32_from_fp16_param_group) - param_name_groups.append(param_name_group) - return result_fp32_from_fp16_param_groups, param_name_groups - - -def get_pp_merged_fp32_from_fp16_param_groups( - optimizer, fp32_from_fp16_groups, param_name_groups=None -): - pp_group_fp32_from_fp16_groups = smp.allgather(fp32_from_fp16_groups, smp.PP_GROUP) - if param_name_groups is not None: - index_to_param_name_groups = [] - # obtain index_to_param_name mapping across tp_group - for param_name_group in param_name_groups: - index_to_param_name = {} - for param_name, index in param_name_group.items(): - index_to_param_name[index] = param_name - index_to_param_name_groups.append(index_to_param_name) - # allgather the index_to_param_name_groups across the pp_group - pp_index_to_param_name_groups = smp.allgather(index_to_param_name_groups, smp.PP_GROUP) - else: - raise ValueError("Merging is not supported when param_name_groups is None") - - pp_merged_fp32_from_fp16_groups = [] - result_param_groups = [] - - # iterate through all the groups for rank 0 - for group_idx in range(len(pp_group_fp32_from_fp16_groups[0])): - merged = [] - start_idx = 0 - result_param_group = {} - # for each group iterate through all ranks and merge the param groups across pp_ranks - for rank, group in enumerate(pp_group_fp32_from_fp16_groups): - cur_g = group[group_idx] - start_idx += len(merged) - for i, _ in enumerate(cur_g): - param_name = pp_index_to_param_name_groups[rank][group_idx][i] - if param_name in result_param_group: - raise ValueError( - "same param_name present in the param_groups of different pipeline parallel partitions" - ) - result_param_group[param_name] = i + start_idx - merged.extend(cur_g) - pp_merged_fp32_from_fp16_groups.append(merged) - result_param_groups.append(result_param_group) - return pp_merged_fp32_from_fp16_groups, result_param_groups - - -def _get_param_index_to_id(param_id_to_index_tp_group): - param_index_to_id_tp_group = [] - for param_id_to_index_map in param_id_to_index_tp_group: - param_index_to_id_map = {} - for param_id, param_index in param_id_to_index_map.items(): - param_index_to_id_map[param_index] = param_id - param_index_to_id_tp_group.append(param_index_to_id_map) - return param_index_to_id_tp_group - - -def register_optimizer_hooks(model): - def param_name_to_index(self): - param_id_to_index = self._param_id_to_index() - name_to_index = {} - if self.redefined_params: - param_gen = model.virtual_named_parameters() - else: - param_gen = model.named_parameters() - for name, param in param_gen: - fp16_param_id = id(param) - if fp16_param_id in self.fp32paramid_from_fp16paramid: - param_id = self.fp32paramid_from_fp16paramid[fp16_param_id] - else: - param_id = fp16_param_id - if param_id in param_id_to_index: - name_to_index[name] = param_id_to_index[param_id] - return name_to_index - - def _param_index_to_param_local(self): - param_id_to_index = self._param_id_to_index() - param_index_to_param = {} - - if not model: - return param_index_to_param - - if self.redefined_params: - param_gen = model.virtual_named_parameters() - else: - param_gen = model.named_parameters() - for name, param in param_gen: - fp16_param_id = id(param) - if fp16_param_id in self.fp32paramid_from_fp16paramid: - param_id = self.fp32paramid_from_fp16paramid[fp16_param_id] - else: - param_id = fp16_param_id - if param_id in param_id_to_index: - param_index_to_param[param_id_to_index[param_id]] = param - - return param_index_to_param - - def hook_fn(model, optimizer): - from functools import partial - - optimizer.param_name_to_index = partial(param_name_to_index, optimizer) - optimizer._param_index_to_param_local = partial(_param_index_to_param_local, optimizer) - - model.register_post_partition_hook(hook_fn) diff --git a/training/distributed_training/pytorch/model_parallel/gpt2/fp16/loss_scaler.py b/training/distributed_training/pytorch/model_parallel/gpt2/fp16/loss_scaler.py deleted file mode 100644 index fdae2f7241..0000000000 --- a/training/distributed_training/pytorch/model_parallel/gpt2/fp16/loss_scaler.py +++ /dev/null @@ -1,271 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. -# Modifications Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch - -from apex.multi_tensor_apply import multi_tensor_applier -import amp_C - -import smdistributed.modelparallel.torch as smp - - -def to_python_float(t): - if hasattr(t, 'item'): - return t.item() - else: - return t[0] - - -class LossScaler: - """ - Class that manages a static loss scale. This class is intended to interact with - :class:`FP16_Optimizer`, and should not be directly manipulated by the user. - - Use of :class:`LossScaler` is enabled via the ``static_loss_scale`` argument to - :class:`FP16_Optimizer`'s constructor. - - Args: - scale (float, optional, default=1.0): The loss scale. - """ - - def __init__(self, model, shard_optimizer_state, scale=1, use_smp=False): - self.cur_scale = scale - self.model = model - self.use_smp = use_smp - self.shard_optimizer_state = shard_optimizer_state - - # `params` is a list / generator of torch.Variable - def has_overflow(self, params): - return False - - # `x` is a torch.Tensor - def _has_inf_or_nan(x): - return False - - def update_scale(self, overflow): - pass - - @property - def loss_scale(self): - return self.cur_scale - - def scale_gradient(self, module, grad_in, grad_out): - _overflow_buf = torch.cuda.IntTensor([0]) - - multi_tensor_applier(amp_C.multi_tensor_scale, - _overflow_buf, - [grad_in, grad_in], - self.loss_scale) - return grad_in - - def backward(self, loss, retain_graph=False): - scaled_loss = loss * self.loss_scale - if self.use_smp: - self.model.backward(scaled_loss) - else: - scaled_loss.backward() - - -class DynamicLossScaler: - """ - Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler` - indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of - :class:`FP16_Optimizer`. However, it's important to understand how :class:`DynamicLossScaler` - operates, because the default options can be changed using the - the ``dynamic_loss_args`` argument to :class:`FP16_Optimizer`'s constructor. - - Loss scaling is designed to combat the problem of underflowing gradients encountered at long - times when training fp16 networks. Dynamic loss scaling begins by attempting a very high loss - scale. Ironically, this may result in OVERflowing gradients. If overflowing gradients are - encountered, :class:`DynamicLossScaler` informs :class:`FP16_Optimizer` that an overflow has - occurred. - :class:`FP16_Optimizer` then skips the update step for this particular iteration/minibatch, - and :class:`DynamicLossScaler` adjusts the loss scale to a lower value. - If a certain number of iterations occur without overflowing gradients detected, - :class:`DynamicLossScaler` increases the loss scale once more. - In this way :class:`DynamicLossScaler` attempts to "ride the edge" of - always using the highest loss scale possible without incurring overflow. - - Args: - init_scale (float, optional, default=2**32): Initial loss scale attempted by :class:`DynamicLossScaler.` - scale_factor (float, optional, default=2.0): Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to loss scale/``scale_factor``. If ``scale_window`` consecutive iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``. - scale_window (int, optional, default=1000): Number of consecutive iterations without an overflow to wait before increasing the loss scale. - """ - - def __init__(self, - model, - shard_optimizer_state, - init_scale=2**32, - scale_factor=2., - scale_window=1000, - min_scale=1, - delayed_shift=1, - consecutive_hysteresis=False, use_smp=False): - self.model = model - self.shard_optimizer_state = shard_optimizer_state - self.cur_scale = init_scale - self.cur_iter = 0 - self.last_overflow_iter = -1 - self.scale_factor = scale_factor - self.scale_window = scale_window - self.min_scale = min_scale - self.delayed_shift = delayed_shift - self.cur_hysteresis = delayed_shift - self.consecutive_hysteresis = consecutive_hysteresis - self.use_smp = use_smp - - # `params` is a list / generator of torch.Variable - def has_overflow_serial(self, params): - for p in params: - if p.grad is not None and DynamicLossScaler._has_inf_or_nan(p.grad.data): - return True - - return False - - def has_overflow(self, params): - overflow = self.has_overflow_serial(params) - # Since each model parallel GPU carries only part of the model, - # make sure overflow flag is synced across all the model parallel GPUs - overflow_gpu = torch.cuda.ByteTensor([overflow]) - group = smp.get_world_process_group() if self.shard_optimizer_state else smp.get_mp_process_group() - torch.distributed.all_reduce(overflow_gpu, - op=torch.distributed.ReduceOp.MAX, - group=group) - overflow = overflow_gpu[0].item() - return bool(overflow) - - # `x` is a torch.Tensor - - def _has_inf_or_nan(x): - try: - # if x is half, the .float() incurs an additional deep copy, but it's necessary if - # Pytorch's .sum() creates a one-element tensor of the same type as x - # (which is true for some recent version of pytorch). - cpu_sum = float(x.float().sum()) - # More efficient version that can be used if .sum() returns a Python scalar - # cpu_sum = float(x.sum()) - except RuntimeError as instance: - # We want to check if inst is actually an overflow exception. - # RuntimeError could come from a different error. - # If so, we still want the exception to propagate. - if "value cannot be converted" not in instance.args[0]: - raise - return True - else: - if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum: - return True - return False - - # `overflow` is boolean indicating whether the gradient overflowed - def update_scale(self, overflow): - - if not hasattr(self, 'min_scale'): - self.min_scale = 1 - if not hasattr(self, 'delayed_shift'): - self.delayed_shift = 1 - if not hasattr(self, 'cur_hysteresis'): - self.cur_hysteresis = 1 - if not hasattr(self, 'consecutive_hysteresis'): - self.consecutive_hysteresis = True - if overflow: - # self.cur_scale /= self.scale_factor - if self.delayed_shift == 1 or self.cur_hysteresis == 1: - self.cur_scale = max(self.cur_scale / self.scale_factor, self.min_scale) - else: - self.cur_hysteresis -= 1 - self.last_overflow_iter = self.cur_iter - else: - if self.consecutive_hysteresis: - self.cur_hysteresis = self.delayed_shift - if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0: - if not self.consecutive_hysteresis: - self.cur_hysteresis = self.delayed_shift - self.cur_scale *= self.scale_factor - self.cur_iter += 1 - - @property - def loss_scale(self): - return self.cur_scale - - def scale_gradient(self, module, grad_in, grad_out): - _overflow_buf = torch.cuda.IntTensor([0]) - multi_tensor_applier(amp_C.multi_tensor_scale, - _overflow_buf, - [grad_in, grad_in], - self.loss_scale) - return grad_in - - def backward(self, loss, retain_graph=False): - scaled_loss = loss * self.loss_scale - if self.use_smp: - self.model.backward(scaled_loss) - else: - scaled_loss.backward() - - -############################################################## -# Example usage below here -- assuming it's in a separate file -############################################################## -""" -TO-DO separate out into an example. -if __name__ == "__main__": - import torch - from torch.autograd import Variable - from dynamic_loss_scaler import DynamicLossScaler - - # N is batch size; D_in is input dimension; - # H is hidden dimension; D_out is output dimension. - N, D_in, H, D_out = 64, 1000, 100, 10 - - # Create random Tensors to hold inputs and outputs, and wrap them in Variables. - x = Variable(torch.randn(N, D_in), requires_grad=False) - y = Variable(torch.randn(N, D_out), requires_grad=False) - - w1 = Variable(torch.randn(D_in, H), requires_grad=True) - w2 = Variable(torch.randn(H, D_out), requires_grad=True) - parameters = [w1, w2] - - learning_rate = 1e-6 - optimizer = torch.optim.SGD(parameters, lr=learning_rate) - loss_scaler = DynamicLossScaler() - - for t in range(500): - y_pred = x.mm(w1).clamp(min=0).mm(w2) - loss = (y_pred - y).pow(2).sum() * loss_scaler.loss_scale - print('Iter {} loss scale: {}'.format(t, loss_scaler.loss_scale)) - print('Iter {} scaled loss: {}'.format(t, loss.data[0])) - print('Iter {} unscaled loss: {}'.format(t, loss.data[0] / loss_scaler.loss_scale)) - - # Run backprop - optimizer.zero_grad() - loss.backward() - - # Check for overflow - has_overflow = DynamicLossScaler.has_overflow(parameters) - - # If no overflow, unscale grad and update as usual - if not has_overflow: - for param in parameters: - param.grad.data.mul_(1. / loss_scaler.loss_scale) - optimizer.step() - # Otherwise, don't do anything -- ie, skip iteration - else: - print('OVERFLOW!') - - # Update loss scale for next iteration - loss_scaler.update_scale(has_overflow) - -""" diff --git a/training/distributed_training/pytorch/model_parallel/gpt2/memory_tracker.py b/training/distributed_training/pytorch/model_parallel/gpt2/memory_tracker.py index 39a3550ab0..329926a26e 100644 --- a/training/distributed_training/pytorch/model_parallel/gpt2/memory_tracker.py +++ b/training/distributed_training/pytorch/model_parallel/gpt2/memory_tracker.py @@ -1,5 +1,29 @@ +import psutil +import os + import smdistributed.modelparallel.torch as smp import torch +try: + from py3nvml import py3nvml +except ImportError: + py3nvml = None + +dtype_to_bit = { +torch.float32 : 32, +torch.float64 : 64, +torch.float16: 16, +torch.bfloat16: 16, +torch.uint8: 8, +torch.int8: 8, +torch.int16: 16, +torch.int32: 32, +torch.int64: 64, +torch.bool: 1 +} + +process = psutil.Process(os.getpid()) +base_mem_usage = process.memory_info().data +last_mem_usage = base_mem_usage def memory_status(msg="", reset_max=True, sync=True): @@ -15,6 +39,15 @@ def memory_status(msg="", reset_max=True, sync=True): if rdp_rank != 0: return + if py3nvml != None: + py3nvml.nvmlInit() + handle = py3nvml.nvmlDeviceGetHandleByIndex(local_rank) + info = py3nvml.nvmlDeviceGetMemoryInfo(handle) + total_used = info.used / 1024**3 + total_used_str = f"Totally used GPU memory: {total_used}" + else: + total_used_str = "" + alloced = torch.cuda.memory_allocated(device=local_rank) max_alloced = torch.cuda.max_memory_allocated(device=local_rank) cached = torch.cuda.memory_reserved(device=local_rank) @@ -30,8 +63,49 @@ def memory_status(msg="", reset_max=True, sync=True): f'[{msg}] rank {rank} tp_rank {tp_rank} pp_rank {pp_rank} TORCH {torch.__version__}', f'device={local_rank} ' f'alloc {alloced:0.4f} max_alloced {max_alloced:0.4f} ' - f'cache {cached:0.4f} max_cached {max_cached:0.4f}' + f'cache {cached:0.4f} max_cached {max_cached:0.4f} ' + f'{total_used_str}' ) if reset_max: torch.cuda.reset_max_memory_cached() - torch.cuda.reset_max_memory_allocated() \ No newline at end of file + torch.cuda.reset_max_memory_allocated() + if py3nvml != None: + py3nvml.nvmlShutdown() + +def memory_status_cpu(msg=""): + import gc + global last_mem_usage + global base_mem_usage + rdp_rank = smp.rdp_rank() + gc.collect() + gc.collect() + gc.collect() + objects = gc.get_objects() + tensors = [obj for obj in objects if isinstance(obj, torch.Tensor) and not obj.is_cuda] + torch_usage = 0 + for t in tensors: + torch_usage += t.numel() * dtype_to_bit[t.dtype] + #total_usage = psutil.virtual_memory()[3] # This will get the total usage for all processes + current_usage = process.memory_info().data + total_usage = current_usage - base_mem_usage + usage_change = current_usage - last_mem_usage + last_mem_usage = current_usage + + torch_usage /= 1024**3 + total_usage /= 1024**3 + usage_change /= 1024**3 + base_usage = base_mem_usage / 1024**3 + + rank = smp.rank() + tp_rank = smp.tp_rank() + pp_rank = smp.pp_rank() + rdp_rank = smp.rdp_rank() + local_rank = smp.local_rank() + if rdp_rank != 0: + return + + print( + f'[{msg}] rank {rank} tp_rank {tp_rank} pp_rank {pp_rank} TORCH {torch.__version__}', + f'device={local_rank} ' + f'torch cpu tensor usage {torch_usage:0.4f} cpu mem usage {total_usage:0.4f} change since last measurement {usage_change:0.4f} base cpu mem usage {base_usage:0.4f}' + ) \ No newline at end of file diff --git a/training/distributed_training/pytorch/model_parallel/gpt2/requirements.txt b/training/distributed_training/pytorch/model_parallel/gpt2/requirements.txt index 2a1fd1a21e..67d35f169d 100644 --- a/training/distributed_training/pytorch/model_parallel/gpt2/requirements.txt +++ b/training/distributed_training/pytorch/model_parallel/gpt2/requirements.txt @@ -4,5 +4,6 @@ sagemaker sagemaker-experiments scipy torchnet +transformers==4.21.0 smdebug humanize diff --git a/training/distributed_training/pytorch/model_parallel/gpt2/sharded_data_parallel_checkpoint.py b/training/distributed_training/pytorch/model_parallel/gpt2/sharded_data_parallel_checkpoint.py new file mode 100644 index 0000000000..e9e7ebd79e --- /dev/null +++ b/training/distributed_training/pytorch/model_parallel/gpt2/sharded_data_parallel_checkpoint.py @@ -0,0 +1,240 @@ +import torch +import glob +import math +import os +import re +import gc +from collections import OrderedDict + +# load to cpu +device = torch.device('cpu') +smp_prefix = "module." + +def atoi(text): + return int(text) if text.isdigit() else text + + +def natural_keys(text): + ''' + alist.sort(key=natural_keys) sorts in human order + http://nedbatchelder.com/blog/200712/human_sorting.html + (See Toothy's implementation in the comments) + ''' + return [ atoi(c) for c in re.split(r'(\d+)', text) ] + +def get_model_state_file(checkpoint_dir): + if not os.path.isdir(checkpoint_dir): + raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist") + file = os.path.join(checkpoint_dir, "model_0.pt") + + if not os.path.exists(file): + raise FileNotFoundError(f"can't find model states file at '{file}'") + + return file + +def get_optim_files(checkpoint_dir): + optim_files = sorted(glob.glob(os.path.join(checkpoint_dir, "optimizer_*.pt")), key=natural_keys) + + if len(optim_files) == 0: + raise FileNotFoundError( + f"can't find '*_optim_states.pt' files in directory '{checkpoint_dir}'") + + return optim_files + +def get_user_content_file(checkpoint_dir): + file = os.path.join(checkpoint_dir, "user_content.pt") + if not os.path.exists(file): + raise FileNotFoundError(f"can't find user content file at '{file}'") + return file + +def parse_model_state(model_file, user_content_file, dtype): + state_dict = torch.load(model_file, map_location=device) + user_content = torch.load(user_content_file, map_location=device) + + if "buffer_names" not in user_content: + raise ValueError(f"{user_content_file} miss buffer_names to reconstruct the full state") + if "param_shapes" not in user_content: + raise ValueError(f"{user_content_file} miss param_shapes to reconstruct the full state") + buffer_names = user_content["buffer_names"] + param_shapes = user_content["param_shapes"] + + # recover just the buffers while restoring them to the specified dtype + buffers = { + k: v.to(dtype) + for k, + v in state_dict["module"].items() if k in buffer_names + } + + return buffers, param_shapes + +def parse_optim_states(files, checkpoint_dir, dtype): + total_files = len(files) + state_dicts = [] + sharded_data_parallel_size = None + # param_shapes = None + fp32_groups_key = None + for i, f in enumerate(files): + states = torch.load(f, map_location=device) + if i == 0: + sharded_data_parallel_size = states["partition_count"] + states["fp32_flat_groups"] = [group.to(dtype) for group in states["fp32_flat_groups"]] + state_dicts.append(states["fp32_flat_groups"]) + + if type(sharded_data_parallel_size) is list: + sharded_data_parallel_size = max(sharded_data_parallel_size) + + if sharded_data_parallel_size != total_files: + raise ValueError( + f"Expected {sharded_data_parallel_size} of 'optimizer_*.pt' under '{checkpoint_dir}' but found {total_files} files. " + "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes." + ) + + flat_groups = [ + torch.cat(state_dicts[i], + 0) for i in range(len(state_dicts)) + ] + + return sharded_data_parallel_size, flat_groups + +def partitioned_param_info(unpartitioned_numel, sharded_data_parallel_size): + remainder = unpartitioned_numel % sharded_data_parallel_size + padding_numel = (sharded_data_parallel_size - remainder) if remainder else 0 + partitioned_numel = math.ceil(unpartitioned_numel / sharded_data_parallel_size) + return partitioned_numel, padding_numel + +def get_full_state_dict_from_sharded_data_parallel_checkpoint(checkpoint_dir, dtype=torch.float32, tag=None, remove_smp_prefix=True): + """ + Returns full state_dict reconstructed from sharded data parallel checkpoint + + Args: + - checkpoint_dir: path to the sharded data parallel checkpoint folder (where the optimizer files are) + - dtype: the dtype of the output full checkpoint + - tag: the checkpoint tag, if not specified will read the newest checkpoint + - remove_smp_prefix: remove the "module." prefix created by smp + + """ + if tag is None: + latest_path = os.path.join(checkpoint_dir, 'newest') + if os.path.isfile(latest_path): + with open(latest_path, 'r') as fd: + tag = fd.read().strip() + else: + raise ValueError(f"Unable to find 'newest' file at {latest_path}") + + checkpoint_dir = os.path.join(checkpoint_dir, tag) + + if not os.path.isdir(checkpoint_dir): + raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist") + + print(f"Processing checkpoint '{checkpoint_dir}'") + + optim_files = get_optim_files(checkpoint_dir) + sharded_data_parallel_size, flat_groups = parse_optim_states(optim_files, checkpoint_dir, dtype) + + model_file = get_model_state_file(checkpoint_dir) + user_content_file = get_user_content_file(checkpoint_dir) + buffers, param_shapes = parse_model_state(model_file, user_content_file, dtype) + + gc.collect() + avail_numel = flat_groups[0].numel() * sharded_data_parallel_size + # merge list of dicts, preserving order + param_shapes = {k: v for d in param_shapes for k, v in d.items()} + + # params + offset = 0 + total_numel = 0 + total_params = 0 + + state_dict = OrderedDict() + state_dict.update(buffers) + + for name, shape in param_shapes.items(): + if remove_smp_prefix and name.startswith(smp_prefix): + name = name[len(smp_prefix):] + + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + total_params += 1 + + partitioned_numel, partitioned_padding_numel = partitioned_param_info(unpartitioned_numel, sharded_data_parallel_size) + + print( + f"{total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + ) + + # memory usage doubles here + state_dict[name] = torch.cat( + tuple(flat_groups[i].narrow(0, + offset, + partitioned_numel) + for i in range(sharded_data_parallel_size)), + 0).narrow(0, + 0, + unpartitioned_numel).view(shape) + offset += partitioned_numel + + offset *= sharded_data_parallel_size + + # Sanity check + if offset != avail_numel: + raise ValueError( + f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print( + f"Reconstructed state dict with {total_params} params {total_numel} elements" + ) + + return state_dict + +def get_param_shapes(model, optimizer): + """Returns a dict of name to shape mapping, only for the flattened weights saved by the + optimizer. the names are exactly as in state_dict. The order is absolutely important, since + the saved data is just flattened data with no identifiers and requires reconstruction in the + same order it was saved. + + We can't rely on module.named_parameters() to get the saved tensors, as some params + will be missing and others unsaved and then it'd be impossible to reconstruct state_dict + from the flattened weights. + """ + param_group_shapes = [] + cnt = 0 + numel = 0 + + bit16_groups = optimizer.fp16_groups + param_names = {param: name for name, param in model.module.named_parameters()} + + for bit16_group in bit16_groups: + param_shapes = OrderedDict() + for param in bit16_group: + cnt += 1 + numel += param.ds_numel if hasattr(param, "ds_numel") else param.numel() + shape = param.ds_shape if hasattr(param, "ds_shape") else param.shape + if param not in param_names: + raise ValueError(f"failed to find optimizer param in named params") + name = param_names[param] + param_shapes[name] = shape + + param_group_shapes.append(param_shapes) + + return param_group_shapes + +def get_buffer_names(model): + buffer_names = [] + + # we save buffer names so that we could extract later the real buffers from the saved + # state_dict["module"] in the non-zero checkpoint - the buffers are already there but they + # are intermixed with param placeholders + + # have to traverse the tree to be able to skip non-persistent buffers + def get_layer_named_buffers(module, prefix=""): + for name, buf in module.named_buffers(recurse=False): + if buf is not None and name not in module._non_persistent_buffers_set: + buffer_names.append(prefix + name) + + for name, child in module.named_children(): + if child is not None: + get_layer_named_buffers(child, prefix + name + ".") + + get_layer_named_buffers(model.module, prefix="") + + return buffer_names \ No newline at end of file diff --git a/training/distributed_training/pytorch/model_parallel/gpt2/smp-train-gpt-simple.ipynb b/training/distributed_training/pytorch/model_parallel/gpt2/smp-train-gpt-simple.ipynb index 0668dda4dd..b8a7ab85dc 100644 --- a/training/distributed_training/pytorch/model_parallel/gpt2/smp-train-gpt-simple.ipynb +++ b/training/distributed_training/pytorch/model_parallel/gpt2/smp-train-gpt-simple.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Train GPT-2 with PyTorch 1.8.1 and Tensor Parallelism Using the SageMaker Model Parallelism Library" + "# Train GPT-2 with PyTorch 1.12 and Tensor Parallelism Using the SageMaker Model Parallelism Library" ] }, { @@ -13,16 +13,17 @@ "source": [ "This notebook walks you through how to use the SageMaker model parallelism (SMP) library. You'll learn how to train the GPT-2 model with SageMaker's model parallelism.\n", "\n", - "The GPT-2 model was proposed by OpenAI in paper [Language Models are Unsupervised Multitask Learners](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf). The original GPT-2 is a large transformer-based language model with 1.5 billion parameters. In this notebook, you can experiment with the model parameters to achieve different model sizes. This notebook uses the [Hugging Face Transformers GPT-2](https://huggingface.co/transformers/model_doc/gpt2.html) implementation with the SMP integration. Currently, SMP only supports Hugging Face Transformers version 4.4.2.\n", + "The GPT-2 model was proposed by OpenAI in paper [Language Models are Unsupervised Multitask Learners](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf). The original GPT-2 is a large transformer-based language model with 1.5 billion parameters. In this notebook, you can experiment with the model parameters to achieve different model sizes. This notebook uses the [Hugging Face Transformers GPT-2](https://huggingface.co/transformers/model_doc/gpt2.html) implementation with the SMP integration. \n", "\n", "This notebook depends on the following files and folders:\n", "\n", "- `train_gpt_simple.py`: This is an entrypoint script that is passed to the Hugging Face estimator in the notebook instructions. This script is responsible for end to end training of the GPT-2 model with SMP. The script has additional comments at places where the SMP API is used.\n", - "- `fp16`: This folder is used for 16-bit float training, which contains a fp16 optimizer and various fp16 utilities.\n", "- `data_pipeline.py`: This contains the datapipeline function to prepare the training data.\n", "- `learining_rate.py`: This contains the functions for learning rate schedule.\n", "- `requirements.txt`: This will install the dependencies, like the right version of huggingface transformers.\n", "- `data_prep_512.py`: This will download and preprocess the openwebtext dataset.\n", + "- `memory_tracker.py`: This contains the functions to track memory usage.\n", + "- `sharded_data_parallel_checkpoint.py`: This contains checkpoint util functions for sharded data parallelism\n", "\n", "### Additional Resources\n", "If you are a new user of Amazon SageMaker, you may find the following helpful to learn more about SMP and using SageMaker with PyTorch.\n", @@ -77,7 +78,7 @@ "import boto3\n", "import sagemaker\n", "from sagemaker import get_execution_role\n", - "from sagemaker.huggingface import HuggingFace\n", + "from sagemaker.pytorch import PyTorch\n", "\n", "role = (\n", " get_execution_role()\n", @@ -322,7 +323,6 @@ " \"activation_checkpointing\": 1,\n", " \"activation_strategy\": \"each\",\n", " \"optimize\": \"speed\",\n", - "\n", " }\n", ")\n", "\n", @@ -527,9 +527,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Create a SageMaker HuggingFace Estimator\n", + "### Create a SageMaker PyTorch Estimator\n", "\n", - "The following cell constructs a HuggingFace estimator using the parameters defined above. To see how the SageMaker tensor parallelism modules and functions are applied to the script, see the `train_gpt_simple.py` file." + "The following cell constructs a PyTorch estimator using the parameters defined above. To see how the SageMaker tensor parallelism modules and functions are applied to the script, see the `train_gpt_simple.py` file." ] }, { @@ -544,7 +544,7 @@ " kwargs[\"security_group_ids\"] = [fsx_security_group_id]\n", " kwargs[\"subnets\"] = [fsx_subnet]\n", "\n", - "smp_estimator = HuggingFace(\n", + "smp_estimator = PyTorch(\n", " entry_point=\"train_gpt_simple.py\",\n", " source_dir=os.getcwd(),\n", " role=role,\n", @@ -569,18 +569,16 @@ " \"partitions\": hyperparameters[\"pipeline_parallel_degree\"],\n", " \"shard_optimizer_state\": hyperparameters[\"shard_optimizer_state\"] > 0,\n", " \"prescaled_batch\": hyperparameters[\"prescaled_batch\"] > 0,\n", - " \"fp16_params\": hyperparameters[\"fp16\"] > 0,\n", + " \"fp16\": hyperparameters[\"fp16\"] > 0,\n", " \"optimize\": hyperparameters[\"optimize\"],\n", " \"auto_partition\": False if hyperparameters[\"manual_partition\"] else True,\n", " \"default_partition\": 0,\n", - " \"fp16_params\": hyperparameters[\"fp16\"] > 0,\n", " \"optimize\": hyperparameters[\"optimize\"],\n", " },\n", " }\n", " },\n", " },\n", - " pytorch_version=\"1.10\",\n", - " transformers_version=\"4.17\",\n", + " framework_version=\"1.12\",\n", " py_version=\"py38\",\n", " output_path=s3_output_location,\n", " checkpoint_s3_uri=checkpoint_s3_uri if not use_fsx else None,\n", diff --git a/training/distributed_training/pytorch/model_parallel/gpt2/train_gpt_simple.py b/training/distributed_training/pytorch/model_parallel/gpt2/train_gpt_simple.py index 5c44c90ed0..ac2abb3a00 100644 --- a/training/distributed_training/pytorch/model_parallel/gpt2/train_gpt_simple.py +++ b/training/distributed_training/pytorch/model_parallel/gpt2/train_gpt_simple.py @@ -15,11 +15,12 @@ import torch.utils.data import transformers from data_pipeline import create_pretraining_dataloader -from fp16 import FP16_Module, FP16_Optimizer, load_fp16_optimizer, save_fp16_optimizer from learning_rates import AnnealingLR +from memory_tracker import memory_status, memory_status_cpu +from sharded_data_parallel_checkpoint import get_buffer_names, get_param_shapes from smdistributed.modelparallel.torch.nn import FusedLayerNorm as LayerNorm from smdistributed.modelparallel.torch.nn.huggingface.gpt2 import ( - translate_hf_state_dict_to_smdistributed, + translate_hf_state_dict_to_smdistributed_gpt2, translate_state_dict_to_hf_gpt2, ) from torch import optim @@ -36,6 +37,7 @@ ) from transformers.trainer_utils import is_main_process +logging.getLogger("torch.distributed.distributed_c10d").setLevel(logging.ERROR) logger = logging.getLogger(__name__) @@ -88,7 +90,7 @@ def get_param_groups_by_weight_decay(module): return weight_decay_params, no_weight_decay_params -# SMP modification: Define smp.step. Return any tensors needed outside. +# smdistributed: Define smp.step. Return any tensors needed outside. @smp.step def train_step(model, optimizer, input_ids, attention_mask, args): if args.logits_output: @@ -96,306 +98,22 @@ def train_step(model, optimizer, input_ids, attention_mask, args): loss = output["loss"] else: loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)["loss"] - if args.fp16: - optimizer.backward(loss, update_master_grads=False) - else: - model.backward(loss) + + model.backward(loss) + if args.logits_output: return output + return loss -# SMP modification: Define smp.step. Return any tensors needed outside. +# smdistributed: Define smp.step. Return any tensors needed outside. @smp.step def test_step(model, input_ids, attention_mask): loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)["loss"] return loss -def save_ckptsum(args, model, optimizer, filename): - results = collections.defaultdict(dict) - model_result = collections.defaultdict(dict) - - if args.fp16: - from fp16.fp16util import register_optimizer_hooks - - register_optimizer_hooks(model) - - def _get_optimizer_result(optimizer_states): - _optimizer_result = collections.defaultdict(dict) - for param_idx, state in optimizer_states.items(): - for key, val in state.items(): - if isinstance(val, torch.Tensor): - _optimizer_result["tensors"][f"{param_idx}_{key}"] = torch.sum(val) - else: - _optimizer_result["scalars"][f"{param_idx}_{key}"] = val - return _optimizer_result - - if not args.shard_optimizer_state: - optimizer_result = _get_optimizer_result(optimizer.local_state_dict()["state"]) - else: - local_state_dict = optimizer.local_state_dict()["state"] - if smp.rdp_rank() == 0: - optimizer_result = [] - for partial_local_state_dict in local_state_dict: - optimizer_result.append(_get_optimizer_result(partial_local_state_dict)) - - for param_name, param in model.local_state_dict().items(): - if isinstance(param, torch.Tensor): - model_result["tensors"][param_name] = torch.sum(param) - else: - model_result["scalars"][param_name] = param - - if smp.rdp_rank() == 0: - results["optimizer"] = optimizer_result - results["model"] = model_result - smp.save(results, filename) - - -def load_and_verify_ckptsum(args, model, optimizer, filename): - results = smp.load(filename) - optimizer_result = ( - results["optimizer"] - if not args.shard_optimizer_state - else results["optimizer"][smp.rdp_rank()] - ) - model_result = results["model"] - - def opt_check_fn(mod, opt): - loaded_opt_states = ( - opt.orig_state_dict()["state"] - if args.shard_optimizer_state - else opt.local_state_dict()["state"] - ) - for param_idx, state in loaded_opt_states.items(): - for key, val in state.items(): - if isinstance(val, torch.Tensor): - assert torch.isclose( - torch.sum(val), optimizer_result["tensors"][f"{param_idx}_{key}"] - ), f"mismatch for param_idx: {param_idx}, key is {key}" - else: - assert ( - val == optimizer_result["scalars"][f"{param_idx}_{key}"] - ), f"mismatch for param_idx: {param_idx}, key is {key}" - print("Optimizer save/load check passed successfully") - - def model_check_fn(mod, opt): - for param_name, param in mod.local_state_dict().items(): - if isinstance(param, torch.Tensor): - assert torch.isclose( - torch.sum(param), model_result["tensors"][param_name] - ), f"mismatch for param_name: {param_name}" - else: - assert ( - param == model_result["scalars"][param_name] - ), f"mismatch for param_name: {param_name}" - print("Model save/load check passed successfully") - - model.register_post_partition_hook(model_check_fn) - model.register_post_step_hook(opt_check_fn) - - -def save( - output_save_file, - model, - optimizer, - lr_scheduler, - model_config, - num_params, - total_steps, - curr_train_path_index, - args, - partial=True, - translate_to_hf=False, - seq_length=1024, - batch_idx=0, -): - save_fn = save_fp16_optimizer - save_dict = { - "cli_args": args.__dict__, - "num_params": num_params, - "total_steps": total_steps, - "curr_train_path_index": curr_train_path_index, - "model_config": model_config, - "batch_idx": batch_idx, - } - - if lr_scheduler is not None: - save_dict["lr_scheduler"] = lr_scheduler.state_dict() - if partial: - # SMP modification: check if using optimizer state sharding or tensor parallelism - if args.gather_if_shard > 0 or smp.rdp_rank() == 0: - # if not gather the opt checkpoint, only save the model for rdp rank 0 - save_dict["model"] = model.local_state_dict() - else: - model_state_dict = model.state_dict(gather_to_rank0=True) - if smp.rank() == 0: - save_dict["model"] = ( - translate_state_dict_to_hf_gpt2(model_state_dict, seq_length) - if translate_to_hf - else model_state_dict - ) - - if args.fp16: - if not partial and args.skip_full_optimizer: - print("Skipping saving the final optimizer state") - else: - if args.shard_optimizer_state == 0 or partial: - save_dict["optimizer"] = save_fn(args, model, optimizer, partial=partial) - else: - print( - "Saving the full optimizer state does not work with shard_optimizer_state > 0! Skipping..." - ) - else: - # fp32 - if partial: - save_dict["optimizer"] = optimizer.local_state_dict() - else: - if not args.skip_full_optimizer: - save_dict["optimizer"] = optimizer.state_dict() - else: - print("Skipping saving of full optimizer state") - - # SMP modification: criteria for checkpointing the zeroth rank for - # pipeline parallelism, checkpointing the zeroth reduced data parallel - # rank for tensor parallelism, and preventing checkpointing if optimizer - # state sharding is enabled - if not args.gather_if_shard or (smp.rdp_rank() == 0 and partial) or smp.rank() == 0: - smp.save(save_dict, output_save_file, partial=partial, v3=not args.gather_if_shard) - - print(f"Finished checkpointing after {total_steps} steps: {output_save_file}") - - -def load_model_and_optimizer( - output_dir, - model, - optimizer, - lr_scheduler, - partial, - args, - translate_from_hf=False, - seq_length=1024, - load_model=True, - load_optimizer=True, - num_params=0, -): - # Find longest-trained checkpoint - re_pattern = f"trained_gpt_nparams-{num_params}_steps-(?P\d+)\.pt" - if partial: - re_pattern += "_(?P\d+)" - else: - re_pattern += "$" - - ckpt_paths = sorted( - [ - (int(re.match(re_pattern, p).group("total_steps")), os.path.join(output_dir, p)) - for p in os.listdir(output_dir) - if re.match(re_pattern, p) - ], - reverse=True, - ) - if not ckpt_paths: - raise Exception( - f'No checkpoints could be found in "{output_dir}". Candidates: {os.listdir(output_dir)}' - ) - - local_ckpt_path = ckpt_paths[0][1] - - if partial: - # need to pass prefix without ranks to smp - local_ckpt_path = local_ckpt_path.split(".pt")[0] + ".pt" - - if args.gather_if_shard > 0: - # Should expect v2 checkpoint here - checkpoint = smp.load(local_ckpt_path, partial=partial) - else: - # Loading separately for model and opt - checkpoint = torch.load(f"{local_ckpt_path}_{smp.pp_rank()}_{smp.tp_rank()}_0") - if smp.rdp_rank() != 0: - opt_checkpoint = torch.load( - f"{local_ckpt_path}_{smp.pp_rank()}_{smp.tp_rank()}_{smp.rdp_rank()}" - ) - - if load_model: - checkpointed_model = ( - translate_hf_state_dict_to_smdistributed(checkpoint["model"], seq_length) - if translate_from_hf - else checkpoint["model"] - ) - model.load_state_dict(checkpointed_model, same_partition_load=args.same_partition_load > 0) - if lr_scheduler is not None: - lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) - - if load_optimizer: - # Loading loss scale eagerly - opt_state_dict = checkpoint["optimizer"] - optimizer.loss_scaler = opt_state_dict["loss_scaler"] - optimizer.loss_scaler.model = model - optimizer.dynamic_loss_scale = opt_state_dict["dynamic_loss_scale"] - optimizer.overflow = opt_state_dict["overflow"] - optimizer.first_closure_call_this_step = opt_state_dict["first_closure_call_this_step"] - - def opt_load_hook(mod, opt): - load_fn = load_fp16_optimizer - if args.fp16: - if not partial and args.skip_full_optimizer: - print( - "Skipping loading the final optimizer state, and reloading master_params from model_params" - ) - opt.reload_model_params() - else: - load_fn(args, mod, opt, checkpoint, partial=partial) - else: - # fp32 - if not partial and args.skip_full_optimizer: - print("Skipping loading the final optimizer state") - else: - opt.load_state_dict(checkpoint["optimizer"]) - - model.register_post_step_hook(opt_load_hook) - - print(f'Loaded model from "{local_ckpt_path}"') - - batch_idx = 0 - if "batch_idx" in checkpoint: - batch_idx = checkpoint["batch_idx"] - - return ( - model, - optimizer, - checkpoint["total_steps"], - checkpoint["curr_train_path_index"], - batch_idx, - ) - - -def delete_oldest_ckpt(args, delete_on_rank0_only=False): - to_delete = smp.rank() == 0 if delete_on_rank0_only else smp.local_rank() == 0 - if to_delete: - re_pattern = "trained_gpt_nparams-(?P\d+)_steps-(?P\d+)\.pt" - - # partial - re_pattern += "_(?P\d+)_(?P\d+)" - - paths_per_step = collections.defaultdict(list) - - for p in os.listdir(args.checkpoint_dir): - if re.match(re_pattern, p): - step = int(re.match(re_pattern, p).group("total_steps")) - path = os.path.join(args.checkpoint_dir, p) - paths_per_step[step].append(path) - - if paths_per_step: - oldest_step = sorted(paths_per_step.keys())[0] - num_parts = len(paths_per_step[oldest_step]) - if len(paths_per_step) > args.num_kept_checkpoints: - # delete oldest step - for p in paths_per_step[oldest_step]: - os.remove(p) - # else We still haven't reached maximum number of checkpoints -- no need to delete older ones - return None - - def eval_model(model, dataloader, num_batches, use_wiki_data): model = model.eval() n_batches = 0 @@ -437,6 +155,8 @@ def train( total_steps, args, ): + if args.enable_memory_profiling > 0: + memory_status_cpu(msg="before train step") model.train() if args.parallel_proc_data_processing: pool = ProcessPoolExecutor(1) @@ -524,6 +244,9 @@ def train( to_save = {"loss": [], "val_loss": []} loss_metric = 0 + def grad_accumulation_boundary(batch_idx): + return batch_idx % args.gradient_accumulation == args.gradient_accumulation - 1 + def should_record(): # only record the ranks that in the tp group that contains global rank 0 if smp.tp_size() > 1: @@ -587,10 +310,8 @@ def should_record(): step_start = time.time() - if args.fp16: - optimizer.zero_grad(set_grads_to_None=True) - else: - optimizer.zero_grad() + if grad_accumulation_boundary(batch_idx - 1): + optimizer.zero_grad(set_to_none=True) if args.logits_output: train_output = train_step(model, optimizer, input_ids, attention_mask, args) @@ -604,33 +325,43 @@ def should_record(): # Return value, loss_mb is a StepOutput object loss_mb = train_step(model, optimizer, input_ids, attention_mask, args) - # SMP modification: Average the loss across microbatches. + # smdistributed: Average the loss across microbatches. loss = loss_mb.reduce_mean() if not args.validation_freq: loss_metric = loss.item() + + if args.enable_memory_profiling > 0: + memory_status_cpu("After_train_step_cpu") + memory_status(msg="After_train_step") - if args.fp16: - optimizer.update_master_grads() - optimizer.clip_master_grads(args.grad_clip) - optimizer.step() - overflow = optimizer.overflow - else: - optimizer.step() + if args.clean_cache > 0: + # empty the cache to avoid OOM + torch.cuda.empty_cache() - if not (args.fp16 and overflow): - lr_scheduler.step() - if args.enable_memory_profiling > 0: - memory_status(msg="After_opt_step") + if grad_accumulation_boundary(batch_idx): + if args.fp16: + optimizer.clip_master_grads(args.grad_clip) + + optimizer.step() + if not (args.fp16 and optimizer.overflow): + lr_scheduler.step() + + if args.enable_memory_profiling > 0: + memory_status(msg="After_opt_step") total_steps += 1 time_elapsed = time.time() - start step_time = time.time() - step_start sample_processed = input_ids.shape[0] * dp_size throughput = sample_processed / step_time + tokens_per_gpu = input_ids.shape[0] * input_ids.shape[1] + + # Based on the formula in https://developer.nvidia.com/blog/scaling-language-model-training-to-a-trillion-parameters-using-megatron/ + tflops_per_gpu = 8 * num_params * tokens_per_gpu / step_time / 1e12 if smp.rank() == 0 and not total_steps % args.logging_freq: print( - f"({int(time_elapsed)}s), Batch {total_steps - 1} Loss: {loss.item()}, Speed: {throughput} samples/sec" + f"({int(time_elapsed)}s), Batch {total_steps - 1} Loss: {loss.item()}, Speed: {throughput} samples/sec, TFLOPS/GPU: {tflops_per_gpu}" ) # evaluate on validation @@ -656,37 +387,26 @@ def should_record(): # checkpoint if not (total_steps % args.checkpoint_freq): - base_path = f"trained_gpt_nparams-{num_params}_steps-{total_steps}.pt" - out_path = os.path.join(args.checkpoint_dir, base_path) - total_ckpts = total_steps // args.checkpoint_freq - - delete_oldest_ckpt(args, delete_on_rank0_only=args.use_fsx > 0) - - # save_or_verify_ckptsum if this is the last checkpoint - if (args.save_or_verify_ckptsum and total_steps >= args.max_steps) or ( - (total_ckpts + 1) * args.checkpoint_freq - ) > args.max_steps: - # Save optimizer and model tensor sums and scalars before saving - save_ckptsum( - args, - model, - optimizer, - filename=os.path.join(args.model_dir, "saved_partial_sum"), - ) - - save( - out_path, - model, - optimizer, - lr_scheduler, - model_config, - num_params, - total_steps, - curr_train_path_index, - args, + user_content = { + "cli_args": args.__dict__, + "num_params": num_params, + "total_steps": total_steps, + "start_train_path_index": curr_train_path_index, + "model_config": model_config, + "start_batch_index": batch_idx+1, + } + # to reconstruct the full model + if args.sharded_data_parallel_degree > 1: + user_content["buffer_names"] = get_buffer_names(model) + user_content["param_shapes"] = get_param_shapes(model, optimizer) + user_content["lr_scheduler"] = lr_scheduler.state_dict() + smp.save_checkpoint(args.checkpoint_dir, + tag=f"total_steps{total_steps}", partial=True, - batch_idx=batch_idx + 1, - ) + model=model, + optimizer=optimizer, + user_content=user_content, + num_kept_partial_checkpoints=args.num_kept_checkpoints) if args.logits_output: to_save["loss"].append(loss.item()) @@ -747,9 +467,8 @@ def parse_args(): opt_grp.add_argument("--same_seed", type=int, default=0) opt_grp.add_argument("--n_gpus", type=str, default=os.environ["SM_NUM_GPUS"]) opt_grp.add_argument("--fp16", default=0, type=int, help="automatic mixed precision training") - opt_grp.add_argument( - "--fp32_grad_accumulation", default=0, type=int, help="Enable FP32 Grad accumulation" - ) + opt_grp.add_argument("--bf16", default=0, type=int, help="automatic mixed precision training") + opt_grp.add_argument("--sharded_data_parallel_degree", default=1, type=int) opt_grp.add_argument("--grad_clip", default=1.0, type=float, help="gradient clipping") opt_grp.add_argument("--weight_decay", default=0.01, type=float, help="weight decay") opt_grp.add_argument( @@ -802,12 +521,6 @@ def parse_args(): default=0, help="Enabling this will save a combined model only at the end", ) - io_grp.add_argument( - "--skip_full_optimizer", - type=int, - default=1, - help="Disabling this will also save the full optimizer state", - ) io_grp.add_argument("--load_partial", type=int, default=0, help="Load from partial checkpoints") io_grp.add_argument("--load_full", type=int, default=0, help="Load from full checkpoints") io_grp.add_argument( @@ -820,7 +533,7 @@ def parse_args(): title="model", description="arguments to describe model configuration" ) model_grp.add_argument("--max_context_width", type=int, default=1024) - model_grp.add_argument("--vocab_size", type=int, default=50257) + model_grp.add_argument("--vocab_size", type=int, default=50264) model_grp.add_argument("--hidden_width", type=int, default=768) model_grp.add_argument("--num_layers", type=int, default=12) model_grp.add_argument("--num_heads", type=int, default=12) @@ -829,9 +542,11 @@ def parse_args(): model_grp.add_argument("--attn_pdrop", type=float, default=0.1) model_grp.add_argument("--summary_first_pdrop", type=float, default=0.1) model_grp.add_argument("--use_adamw", type=int, default=0, help="Use adamw optimizer") + model_grp.add_argument("--use_distributed_transformer", type=int, default=1, help="Use distributed transformer") + model_grp.add_argument("--checkpoint_sublayers", type=int, default=0, help="Apply activation checkpointing to submodules of each transformer layer") smp_grp = parser.add_argument_group(title="smp", description="smp") - smp_grp.add_argument("--tensor_parallel_degree", type=int, default=8) + smp_grp.add_argument("--tensor_parallel_degree", type=int, default=1) smp_grp.add_argument("--pipeline_parallel_degree", type=int, default=1) smp_grp.add_argument("--microbatches", type=int, default=1) smp_grp.add_argument("--active_microbatches", type=int, default=None) @@ -849,7 +564,9 @@ def parse_args(): smp_grp.add_argument("--skip_tracing", type=int, default=0) smp_grp.add_argument("--query_key_layer_scaling", type=int, default=1) smp_grp.add_argument("--fused_softmax", type=int, default=1) + smp_grp.add_argument("--fused_dropout", type=int, default=0) smp_grp.add_argument("--fused_bias_gelu", type=int, default=1) + smp_grp.add_argument("--gradient_accumulation", type=int, default=1) parser.add_argument( "--num_kept_checkpoints", @@ -882,7 +599,10 @@ def parse_args(): help="evenly distribute layers across the partitions", ) parser.add_argument( - "--match_weights", type=int, default=0, help="Get weights from the original model" + "--partition_assignment", + type=str, + default="", + help="number of transformer layers assigned to each partition", ) parser.add_argument( "--preserve_np_state", @@ -956,21 +676,28 @@ def parse_args(): ci_grp.add_argument("--time_to_train", type=int, help="time to train threshold") ci_grp.add_argument("--throughput", type=float, help="throughput threshold") ci_grp.add_argument("--loss", type=float, help="loss threshold") - ci_grp.add_argument( - "--save_or_verify_ckptsum", default=False, action="store_true", help="Whether to save sum" - ) - args, _ = parser.parse_known_args() return args +def compute_num_params(model): + num_params = 0 + seen = set() + for p in model.parameters(): + if p not in seen: + seen.add(p) + if hasattr(p, "ds_shape"): + num_params += np.prod(p.ds_shape) + else: + num_params += np.prod(p.size()) + + return num_params def main(): args = parse_args() - if args.shard_optimizer_state > 0 and not args.skip_full_optimizer: - raise ValueError( - "If shard_optimizer_state is enabled, skip_full_optimizer must also be enabled. Full optimizer saving is currently not supported under optimizer state sharding." - ) + if args.partition_assignment != "" and args.manual_partition == 0: + print("[Warning] partition_assignment is set, enable manual_partition") + args.manual_partition = 1 # any value here is overriden by the config set in notebook when launching the sagemaker job smp_config = { @@ -978,22 +705,21 @@ def main(): "tensor_parallel_degree": args.tensor_parallel_degree, "pipeline_parallel_degree": args.pipeline_parallel_degree, "microbatches": args.microbatches, - # if activation_checkpointing true checkpoints transformer layers below - "checkpoint_attentions": False if args.activation_checkpointing else True, "shard_optimizer_state": args.shard_optimizer_state > 0, "prescaled_batch": args.prescaled_batch > 0, - "_match_weights": args.match_weights > 0, - "fp16_params": args.fp16 > 0, + "fp16": args.fp16 > 0, + "bf16": args.bf16 > 0, "offload_activations": args.offload_activations > 0, + "delayed_parameter_initialization": args.delayed_param > 0, "optimize": args.optimize, "placement_strategy": args.placement_strategy, "activation_loading_horizon": args.activation_loading_horizon, "skip_tracing": args.skip_tracing > 0, "auto_partition": False if args.manual_partition else True, "default_partition": 0, - "_fp32_grad_accumulation": args.fp32_grad_accumulation > 0, "static_mode": args.static_mode > 0, "fast_mode": args.fast_mode > 0, + "sharded_data_parallel_degree": args.sharded_data_parallel_degree, } if args.active_microbatches is not None: smp_config["active_microbatches"] = args.active_microbatches @@ -1011,10 +737,11 @@ def main(): f"[Warning] Note that save_final_full_model only saves the final model at the end of all steps. It does not save optimizer state. Optimizer state is only saved with partial models which are saved at checkpointing_freq during training. If you want to restart training you need partial checkpoints." ) - if smp.local_rank() == 0: - for path in [args.model_dir, args.checkpoint_dir]: - if not os.path.exists(path): - os.makedirs(path, exist_ok=True) + if args.partition_assignment != "": + partition_assignment = args.partition_assignment.split(",") + assert ( + len(partition_assignment) == smp.pp_size() + ), f"partition_assignment must have the same size as pipeline parallel degree, but getting {len(partition_assignment)} vs {smp.pp_size()}" model_config = GPT2Config( vocab_size=args.vocab_size, @@ -1043,38 +770,44 @@ def main(): # the following improves start-up time by skipping proper initialization # of weights in the original model. this is not a problem because DistributedModel - # will override those weights anyway when tensor_parallel_degree > 1. - if smp.tp_size() > 1 and args.match_weights < 1: + # will override those weights anyway when we use distributed transformer. + if args.use_distributed_transformer > 0: from transformers.modeling_utils import PreTrainedModel PreTrainedModel.init_weights = lambda x: None set_seed(args.seed) - if args.fp16: - torch.set_default_dtype(torch.float16) - with smp.tensor_parallelism( - enabled=smp.tp_size() > 1, + if args.enable_memory_profiling > 0: + memory_status_cpu(msg="before model creation") + + if args.fp16 and args.bf16: + raise ValueError("FP16 and BF16 cannot be simultaneously enabled.") + elif args.fp16: + dtype = torch.float16 + elif args.bf16: + dtype = torch.bfloat16 + else: + dtype = torch.get_default_dtype() + + with smp.model_creation( + tensor_parallelism=smp.tp_size() > 1 or args.use_distributed_transformer > 0, + dtype=dtype, attention_in_fp32=args.attention_in_fp32 > 0, - query_key_layer_scaling=args.query_key_layer_scaling > 0, + query_key_layer_scaling=args.query_key_layer_scaling > 0 and args.bf16 < 1, fused_softmax=args.fused_softmax > 0, + fused_dropout=args.fused_dropout > 0, fused_bias_gelu=args.fused_bias_gelu > 0, - ): - with smp.delay_param_initialization( - enabled=(smp.tp_size() > 1 and args.match_weights < 1 and args.delayed_param > 0) ): model = AutoModelForCausalLM.from_config(model_config) - - torch.set_default_dtype(torch.float32) - - if args.fp16: - model = FP16_Module(model) + if args.enable_memory_profiling > 0: + memory_status_cpu(msg="after model creation") - num_params = sum([np.prod(p.size()) for p in model.parameters()]) + num_params = compute_num_params(model) if smp.rank() == 0: print(f"# total parameters: {num_params}") - # SMP modification: Set the device to the GPU ID used by the current process. + # smdistributed: Set the device to the GPU ID used by the current process. # Input tensors should be transferred to this device. torch.cuda.set_device(smp.local_rank()) device = torch.device("cuda") @@ -1083,29 +816,34 @@ def main(): # Set seed by tp_rank to prevent weights from being the same on different tp_ranks set_seed(args.seed + smp.tp_rank()) - # SMP modification: Use the DistributedModel container to provide the model + # smdistributed: Use the DistributedModel container to provide the model # to be partitioned across different ranks. For the rest of the script, # the returned DistributedModel object should be used in place of # the model provided for DistributedModel class instantiation. - if args.fp16: - torch.set_default_dtype(torch.float16) - model = smp.DistributedModel(model, trace_device="gpu") - - if args.fp16: - m = model.module - else: - m = model + if args.enable_memory_profiling > 0: + memory_status_cpu(msg="before dist model creation") + model = smp.DistributedModel(model, trace_device="gpu", backward_passes_per_step=args.gradient_accumulation) + if args.enable_memory_profiling > 0: + memory_status_cpu(msg="after dist model creation") - if smp.tp_size() > 1: - transformer_layers = m.module.module.transformer.seq_layers + m = model.get_module() + if args.use_distributed_transformer > 0: + transformer_layers = m.transformer.seq_layers else: - transformer_layers = m.module.module.transformer.h + transformer_layers = m.transformer.h if args.manual_partition: print(f"Manual partition enabled") - # evenly distribute layers across all partitions - div, rem = divmod(args.num_layers, smp.pp_size()) - get_num_layers = lambda x: (div + 1 if x >= smp.pp_size() - rem else div) + if args.partition_assignment != "": + get_num_layers = lambda x: int(partition_assignment[x]) + total_layers = sum([get_num_layers(pp_rank) for pp_rank in range(smp.pp_size())]) + assert ( + total_layers == args.num_layers + ), f"partition_assignment must have the same total transformer layers as model, but getting {total_layers} vs {args.num_layers}" + else: + # evenly distribute layers across all partitions + div, rem = divmod(args.num_layers, smp.pp_size()) + get_num_layers = lambda x: (div + 1 if x >= smp.pp_size() - rem else div) assignments = [] for pp_rank in range(smp.pp_size()): nl = get_num_layers(pp_rank) @@ -1115,13 +853,7 @@ def main(): for i, c in enumerate(transformer_layers.children()): smp.set_partition(c, assignments[i]) - torch.set_default_dtype(torch.float32) - - iter_model = model - # Build parameter groups (weight decay and non-decay). - while isinstance(iter_model, (DistributedDataParallel, FP16_Module)): - iter_model = iter_model.module - param_groups = get_param_groups_by_weight_decay(iter_model) + param_groups = get_param_groups_by_weight_decay(m) if args.use_adamw > 0: optimizer = optim.AdamW( @@ -1133,33 +865,29 @@ def main(): ) if args.activation_checkpointing: - kwargs = {} - if isinstance(transformer_layers, nn.Sequential): - kwargs["pack_args_as_tuple"] = True - kwargs["strategy"] = args.activation_strategy - smp.set_activation_checkpointing(transformer_layers, **kwargs) + if args.use_distributed_transformer or smp.tp_size() > 1: + if args.checkpoint_sublayers: + for c in transformer_layers.children(): + smp.set_activation_checkpointing(c.attention) + smp.set_activation_checkpointing(c.output) + else: + smp.set_activation_checkpointing(transformer_layers, strategy=args.activation_strategy) else: for c in transformer_layers.children(): - smp.set_activation_checkpointing(c) - - if args.fp16: - optimizer = FP16_Optimizer( - model, - optimizer, - static_loss_scale=None, - dynamic_loss_scale=True, - use_smp=True, - dynamic_loss_args={"scale_window": 1000, "min_scale": 1, "delayed_shift": 2}, - params_have_main_grad=args.fp32_grad_accumulation > 0, - shard_optimizer_state=args.shard_optimizer_state > 0, - ) + if args.checkpoint_sublayers: + smp.set_activation_checkpointing(c.attn) + smp.set_activation_checkpointing(c.mlp) + else: + smp.set_activation_checkpointing(c) - optimizer = smp.DistributedOptimizer(optimizer) + optimizer = smp.DistributedOptimizer( + optimizer, + static_loss_scale=None, + dynamic_loss_scale=True, + dynamic_loss_args={"scale_window": 1000, "min_scale": 1, "delayed_shift": 2}, + ) lr_scheduler = get_learning_rate_scheduler(optimizer, args) - if args.fp16: - model.register_post_step_hook(lambda model, optimizer: optimizer.init_master_params()) - if args.enable_memory_profiling > 0: model.register_post_partition_hook( lambda model, optimizer: memory_status(msg="After_partition") @@ -1174,25 +902,13 @@ def main(): ) partial = not args.load_full path = args.checkpoint_dir if partial else args.model_dir - translate_from_hf = not partial - model, optimizer, total_steps, start_train_path_index, start_batch_index = load_model_and_optimizer( - path, - model, - optimizer, - lr_scheduler, - partial, - args, - translate_from_hf=translate_from_hf, - seq_length=args.max_context_width, - load_model=True, - load_optimizer=args.load_partial > 0, - num_params=num_params, - ) - if args.save_or_verify_ckptsum: - filename = "saved_sum" if args.load_full else "saved_partial_sum" - load_and_verify_ckptsum( - args, model, optimizer, filename=os.path.join(args.model_dir, filename) - ) + tag = None if partial else "fullmodel.pt" + user_content = smp.resume_from_checkpoint(path, tag=tag, partial=partial) + total_steps = user_content["total_steps"] if partial else 0 + start_train_path_index = user_content.get("start_train_path_index", 0) + start_batch_index = user_content.get("start_batch_index", 0) + if "lr_scheduler" in user_content: + lr_scheduler.load_state_dict(user_content["lr_scheduler"]) else: total_steps = 0 start_train_path_index = 0 @@ -1223,28 +939,30 @@ def main(): if args.save_final_full_model: # saves full model at the end - - base_path = f"trained_gpt_nparams-{num_params}_steps-{total_steps}.pt" - out_path = os.path.join(args.model_dir, base_path) - if args.save_or_verify_ckptsum: - # Save optimizer and model tensor sums and scalars before saving - save_ckptsum(args, model, optimizer, filename=os.path.join(args.model_dir, "saved_sum")) - - if smp.rdp_rank() == 0: - save( - out_path, - model, - optimizer, - lr_scheduler, - model_config, - num_params, - total_steps, - -1, - args, - partial=False, - translate_to_hf=smp.tp_size() > 1, - seq_length=args.max_context_width, - ) + user_content = { + "cli_args": args.__dict__, + "num_params": num_params, + "total_steps": total_steps, + "model_config": model_config, + } + if args.sharded_data_parallel_degree > 1: + # When sharded_data_parallel_degree > 1, saving full model is not supported, saving partial instead + # To get the full model, one can use the following API + # > from sharded_data_parallel_checkpoint import get_full_state_dict_from_sharded_data_parallel_checkpoint + # > full_model = get_full_state_dict_from_sharded_data_parallel_checkpoint(args.model_dir, tag=f"sharded_data_parallel_final_full_{num_params}", dtype=torch.float32) + # > if args.use_distributed_transformer > 0: # translate the state_dict to hf format if distributed transformer is used + # > full_model = smp.nn.huggingface.gpt2.translate_state_dict_to_hf_gpt2(full_model, max_seq_len=args.max_context_width) + # Note: the shared parameter will not be reflected so during loading you might need to load with strict=False + user_content["buffer_names"] = get_buffer_names(model) + user_content["param_shapes"] = get_param_shapes(model, optimizer) + smp.save_checkpoint(args.model_dir, + tag=f"sharded_data_parallel_final_full_{num_params}", + partial=True, + model=model, + optimizer=optimizer, + user_content=user_content) + else: + smp.save_checkpoint(args.model_dir, tag="fullmodel.pt", partial=False, model=model, user_content=user_content) smp.barrier() if smp.rank() == 0: