diff --git a/training/distributed_training/pytorch/model_parallel/gpt2/fp16/fp16.py b/training/distributed_training/pytorch/model_parallel/gpt2/fp16/fp16.py
old mode 100644
new mode 100755
index 68d5d371f5..f5b7c3461f
--- a/training/distributed_training/pytorch/model_parallel/gpt2/fp16/fp16.py
+++ b/training/distributed_training/pytorch/model_parallel/gpt2/fp16/fp16.py
@@ -119,8 +119,8 @@ def save_fp16_optimizer(args, model, optimizer, partial=True):
     if optimizer.master_params_created:
         register_optimizer_hooks(model)
     if partial:
-        optimizer_state_dict["optimizer_state_dict"] = optimizer.local_state_dict()
-        if args.shard_optimizer_state:
+        optimizer_state_dict["optimizer_state_dict"] = optimizer.local_state_dict(gather_if_shard=args.gather_if_shard > 0)
+        if args.shard_optimizer_state and args.gather_if_shard > 0:
             if smp.rdp_rank() == 0:
                 print("With shard_optimizer_state=True, gather full fp32_from_fp16_groups for the rdp_group on rdp rank 0")
                 gathered_cpu_fp32_from_fp16_groups = [cpu_fp32_from_fp16_groups]
@@ -164,8 +164,7 @@ def load_fp16_optimizer(args, model, optimizer, state_dict, partial=True):
     def hook_fn(model, optimizer):
         optimizer.load_state_dict(opt_state_dict["optimizer_state_dict"])
         if partial:
-            if args.shard_optimizer_state:
-                assert isinstance(opt_state_dict["fp32_from_fp16"], list), "Loading with shard_optimizer_state=True must use the checkpoint that was trained with shard_optimizer_state=True!"
+            if args.shard_optimizer_state and args.gather_if_shard > 0:
                 optimizer.fp32_from_fp16 = opt_state_dict["fp32_from_fp16"][smp.rdp_rank()]
             else:
                 optimizer.fp32_from_fp16 = opt_state_dict["fp32_from_fp16"]
diff --git a/training/distributed_training/pytorch/model_parallel/gpt2/memory_tracker.py b/training/distributed_training/pytorch/model_parallel/gpt2/memory_tracker.py
new file mode 100644
index 0000000000..39a3550ab0
--- /dev/null
+++ b/training/distributed_training/pytorch/model_parallel/gpt2/memory_tracker.py
@@ -0,0 +1,37 @@
+import smdistributed.modelparallel.torch as smp
+import torch
+
+def memory_status(msg="", reset_max=True, sync=True):
+
+    rank = smp.rank()
+    tp_rank = smp.tp_rank()
+    pp_rank = smp.pp_rank()
+    rdp_rank = smp.rdp_rank()
+    local_rank = smp.local_rank()
+
+    if sync:
+        torch.cuda.synchronize()
+
+    if rdp_rank != 0:
+        return
+
+    alloced = torch.cuda.memory_allocated(device=local_rank)
+    max_alloced = torch.cuda.max_memory_allocated(device=local_rank)
+    cached = torch.cuda.memory_reserved(device=local_rank)
+    max_cached = torch.cuda.max_memory_reserved(device=local_rank)
+
+    # convert to GB for printing
+    alloced /= 1024**3
+    cached /= 1024**3
+    max_alloced /= 1024**3
+    max_cached /= 1024**3
+
+    print(
+        f'[{msg}] rank {rank} tp_rank {tp_rank} pp_rank {pp_rank} TORCH {torch.__version__}',
+        f'device={local_rank} '
+        f'alloc {alloced:0.4f} max_alloced {max_alloced:0.4f} '
+        f'cache {cached:0.4f} max_cached {max_cached:0.4f}'
+    )
+    if reset_max:
+        torch.cuda.reset_max_memory_cached()
+        torch.cuda.reset_max_memory_allocated()
\ No newline at end of file
diff --git a/training/distributed_training/pytorch/model_parallel/gpt2/requirements.txt b/training/distributed_training/pytorch/model_parallel/gpt2/requirements.txt
index 97ac71eb9c..2a1fd1a21e 100644
--- a/training/distributed_training/pytorch/model_parallel/gpt2/requirements.txt
+++ b/training/distributed_training/pytorch/model_parallel/gpt2/requirements.txt
@@ -4,6 +4,5 @@ sagemaker
 sagemaker-experiments
 scipy
 torchnet
-transformers==4.4.2
 smdebug
 humanize
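The new `memory_tracker.py` helper above prints allocated/reserved CUDA memory per rank and optionally resets the peak counters. Below is a minimal single-GPU sketch of the same reporting logic using only public `torch.cuda` APIs, so it can be tried outside an SMP job; `memory_status_local` is a hypothetical name used here for illustration, and inside the actual training script the real `memory_status()` from `memory_tracker.py` is what gets called.

```python
import torch

def memory_status_local(msg: str = "", reset_max: bool = True) -> None:
    """Print current/peak CUDA memory in GiB, mirroring memory_tracker.memory_status."""
    if not torch.cuda.is_available():
        print(f"[{msg}] CUDA not available, skipping memory report")
        return
    torch.cuda.synchronize()
    device = torch.cuda.current_device()
    gib = 1024 ** 3
    alloced = torch.cuda.memory_allocated(device) / gib
    max_alloced = torch.cuda.max_memory_allocated(device) / gib
    cached = torch.cuda.memory_reserved(device) / gib
    max_cached = torch.cuda.max_memory_reserved(device) / gib
    print(
        f"[{msg}] device={device} "
        f"alloc {alloced:0.4f} max_alloced {max_alloced:0.4f} "
        f"cache {cached:0.4f} max_cached {max_cached:0.4f}"
    )
    if reset_max:
        # one call covers both the allocated and the reserved peak counters
        torch.cuda.reset_peak_memory_stats(device)

# e.g. memory_status_local(msg="After_opt_step") right after optimizer.step()
```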
diff --git a/training/distributed_training/pytorch/model_parallel/gpt2/smp-train-gpt-simple.ipynb b/training/distributed_training/pytorch/model_parallel/gpt2/smp-train-gpt-simple.ipynb
index 2aaf88de75..0668dda4dd 100644
--- a/training/distributed_training/pytorch/model_parallel/gpt2/smp-train-gpt-simple.ipynb
+++ b/training/distributed_training/pytorch/model_parallel/gpt2/smp-train-gpt-simple.ipynb
@@ -17,7 +17,7 @@
     "\n",
     "This notebook depends on the following files and folders:\n",
     "\n",
-    "- `train_gpt_simple.py`: This is an entrypoint script that is passed to the Pytorch estimator in the notebook instructions. This script is responsible for end to end training of the GPT-2 model with SMP. The script has additional comments at places where the SMP API is used.\n",
+    "- `train_gpt_simple.py`: This is an entrypoint script that is passed to the Hugging Face estimator in the notebook instructions. This script is responsible for end to end training of the GPT-2 model with SMP. The script has additional comments at places where the SMP API is used.\n",
     "- `fp16`: This folder is used for 16-bit float training, which contains a fp16 optimizer and various fp16 utilities.\n",
     "- `data_pipeline.py`: This contains the datapipeline function to prepare the training data.\n",
     "- `learining_rate.py`: This contains the functions for learning rate schedule.\n",
@@ -29,7 +29,7 @@
     "\n",
     "- To learn more about the SageMaker model parallelism library, see [Model Parallel Distributed Training with SageMaker Distributed](https://docs.aws.amazon.com/sagemaker/latest/dg/model-parallel.html).\n",
     "\n",
-    "- To learn more about using the SageMaker Python SDK with Pytorch, see [Using PyTorch with the SageMaker Python SDK](https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html).\n",
+    "- To learn more about using the SageMaker Python SDK with PyTorch, see [Using PyTorch with the SageMaker Python SDK](https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html).\n",
     "\n",
     "- To learn more about launching a training job in Amazon SageMaker with your own training image, see [Use Your Own Training Algorithms](https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo.html).\n",
     "\n",
@@ -50,7 +50,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "Upgrade Sagemaker SDK to the latest version. \n",
+    "Upgrade SageMaker SDK to the latest version. \n",
     "\n",
     "**NOTE:** This step might require a kernel restart."
    ]
@@ -74,12 +74,10 @@
     "%%time\n",
     "import os\n",
     "\n",
+    "import boto3\n",
     "import sagemaker\n",
     "from sagemaker import get_execution_role\n",
-    "from sagemaker.pytorch import PyTorch\n",
-    "from smexperiments.experiment import Experiment\n",
-    "from smexperiments.trial import Trial\n",
-    "import boto3\n",
+    "from sagemaker.huggingface import HuggingFace\n",
     "\n",
     "role = (\n",
     "    get_execution_role()\n",
@@ -307,18 +305,27 @@
     "    \"num_kept_checkpoints\": 5,\n",
     "    \"checkpoint_freq\": 200,\n",
     "    \"logging_freq\": 1,\n",
-    "    \"save_final_full_model\": 0,\n",
-    "    \"manual_partition\": 1,\n",
-    "    \"skip_full_optimizer\": 1,\n",
-    "    \"shard_optimizer_state\": 1,\n",
-    "    \"activation_checkpointing\": 1,\n",
-    "    \"activation_strategy\": \"each\",\n",
-    "    \"optimize\": \"speed\",\n",
     "    \"use_wiki_data\": 1,\n",
     "    # below flag loads model and optimizer state from checkpoint_s3_uri\n",
     "    # 'load_partial': 1,\n",
     "}\n",
     "\n",
+    "# Add parameters required by SMP config.\n",
+    "# Refer https://sagemaker.readthedocs.io/en/stable/api/training/smd_model_parallel_general.html\n",
+    "# for details.\n",
+    "hyperparameters.update(\n",
+    "    {\n",
+    "        \"save_final_full_model\": 0,\n",
+    "        \"manual_partition\": 1,\n",
+    "        \"skip_full_optimizer\": 1,\n",
+    "        \"shard_optimizer_state\": 1,\n",
+    "        \"activation_checkpointing\": 1,\n",
+    "        \"activation_strategy\": \"each\",\n",
+    "        \"optimize\": \"speed\",\n",
+    "\n",
+    "    }\n",
+    ")\n",
+    "\n",
     "if not running_ci:\n",
     "    # those flags are used when training with the openwebtext dataset\n",
     "    hyperparameters[\"zipped_data\"] = 0\n",
@@ -411,47 +418,6 @@ "    hyperparameters[k] = v"
    ]
   },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## Set Up SageMaker Studio Experiment\n",
-    "Create or load [SageMaker Experiment](https://docs.aws.amazon.com/sagemaker/latest/dg/experiments.html) for the example training job. This will create an experiment trial object in SageMaker Studio."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "from time import gmtime, strftime\n",
-    "\n",
-    "if not running_ci:\n",
-    "    # Specify your experiment name\n",
-    "    experiment_name = \"smp-gpt2\"\n",
-    "    # Specify your trial name\n",
-    "    trial_name = f\"{experiment_name}-trial1\"\n",
-    "\n",
-    "    all_experiment_names = [exp.experiment_name for exp in Experiment.list()]\n",
-    "    # Load the experiment if it exists, otherwise create\n",
-    "    if experiment_name not in all_experiment_names:\n",
-    "        experiment = Experiment.create(\n",
-    "            experiment_name=experiment_name, sagemaker_boto_client=sm_boto_client\n",
-    "        )\n",
-    "    else:\n",
-    "        experiment = Experiment.load(\n",
-    "            experiment_name=experiment_name, sagemaker_boto_client=sm_boto_client\n",
-    "        )\n",
-    "\n",
-    "    # Create the trial\n",
-    "    trial = Trial.create(\n",
-    "        trial_name=\"smp-{}-{}\".format(trial_name, strftime(\"%Y-%m-%d-%H-%M-%S\", gmtime())),\n",
-    "        experiment_name=experiment.experiment_name,\n",
-    "        sagemaker_boto_client=sm_boto_client,\n",
-    "    )"
-   ]
-  },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -561,9 +527,9 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "### Create a SageMaker PyTorch Estimator\n",
+    "### Create a SageMaker HuggingFace Estimator\n",
     "\n",
-    "The following cell constructs a PyTorch estimator using the parameters defined above. To see how the SageMaker tensor parallelism modules and functions are applied to the script, see the `train_gpt_simple.py` file."
+ "The following cell constructs a HuggingFace estimator using the parameters defined above. To see how the SageMaker tensor parallelism modules and functions are applied to the script, see the `train_gpt_simple.py` file." ] }, { @@ -578,7 +544,7 @@ " kwargs[\"security_group_ids\"] = [fsx_security_group_id]\n", " kwargs[\"subnets\"] = [fsx_subnet]\n", "\n", - "smp_estimator = PyTorch(\n", + "smp_estimator = HuggingFace(\n", " entry_point=\"train_gpt_simple.py\",\n", " source_dir=os.getcwd(),\n", " role=role,\n", @@ -613,8 +579,9 @@ " }\n", " },\n", " },\n", - " framework_version=\"1.8.1\",\n", - " py_version=\"py36\",\n", + " pytorch_version=\"1.10\",\n", + " transformers_version=\"4.17\",\n", + " py_version=\"py38\",\n", " output_path=s3_output_location,\n", " checkpoint_s3_uri=checkpoint_s3_uri if not use_fsx else None,\n", " checkpoint_local_path=hyperparameters[\"checkpoint-dir\"] if use_fsx else None,\n", @@ -643,18 +610,7 @@ }, "outputs": [], "source": [ - "if not running_ci:\n", - " smp_estimator.fit(\n", - " inputs=data_channels,\n", - " experiment_config={\n", - " \"ExperimentName\": experiment.experiment_name,\n", - " \"TrialName\": trial.trial_name,\n", - " \"TrialComponentDisplayName\": \"Training\",\n", - " },\n", - " logs=True,\n", - " )\n", - "else:\n", - " smp_estimator.fit(inputs=data_channels, logs=True)" + "smp_estimator.fit(inputs=data_channels, logs=True)" ] }, { diff --git a/training/distributed_training/pytorch/model_parallel/gpt2/train_gpt_simple.py b/training/distributed_training/pytorch/model_parallel/gpt2/train_gpt_simple.py index 16e7cd75df..5c44c90ed0 100644 --- a/training/distributed_training/pytorch/model_parallel/gpt2/train_gpt_simple.py +++ b/training/distributed_training/pytorch/model_parallel/gpt2/train_gpt_simple.py @@ -88,7 +88,7 @@ def get_param_groups_by_weight_decay(module): return weight_decay_params, no_weight_decay_params -# smdistributed: Define smp.step. Return any tensors needed outside. +# SMP modification: Define smp.step. Return any tensors needed outside. @smp.step def train_step(model, optimizer, input_ids, attention_mask, args): if args.logits_output: @@ -105,7 +105,7 @@ def train_step(model, optimizer, input_ids, attention_mask, args): return loss -# smdistributed: Define smp.step. Return any tensors needed outside. +# SMP modification: Define smp.step. Return any tensors needed outside. 
diff --git a/training/distributed_training/pytorch/model_parallel/gpt2/train_gpt_simple.py b/training/distributed_training/pytorch/model_parallel/gpt2/train_gpt_simple.py
index 16e7cd75df..5c44c90ed0 100644
--- a/training/distributed_training/pytorch/model_parallel/gpt2/train_gpt_simple.py
+++ b/training/distributed_training/pytorch/model_parallel/gpt2/train_gpt_simple.py
@@ -88,7 +88,7 @@ def get_param_groups_by_weight_decay(module):
     return weight_decay_params, no_weight_decay_params
 
 
-# smdistributed: Define smp.step. Return any tensors needed outside.
+# SMP modification: Define smp.step. Return any tensors needed outside.
 @smp.step
 def train_step(model, optimizer, input_ids, attention_mask, args):
     if args.logits_output:
@@ -105,7 +105,7 @@
     return loss
 
 
-# smdistributed: Define smp.step. Return any tensors needed outside.
+# SMP modification: Define smp.step. Return any tensors needed outside.
 @smp.step
 def test_step(model, input_ids, attention_mask):
     loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)["loss"]
@@ -223,7 +223,10 @@ def save(
     if lr_scheduler is not None:
         save_dict["lr_scheduler"] = lr_scheduler.state_dict()
     if partial:
-        save_dict["model"] = model.local_state_dict()
+        # SMP modification: check if using optimizer state sharding or tensor parallelism
+        if args.gather_if_shard > 0 or smp.rdp_rank() == 0:
+            # if not gather the opt checkpoint, only save the model for rdp rank 0
+            save_dict["model"] = model.local_state_dict()
     else:
         model_state_dict = model.state_dict(gather_to_rank0=True)
         if smp.rank() == 0:
@@ -253,8 +256,12 @@ def save(
         else:
             print("Skipping saving of full optimizer state")
 
-    if (smp.rdp_rank() == 0 and partial) or smp.rank() == 0:
-        smp.save(save_dict, output_save_file, partial=partial)
+    # SMP modification: criteria for checkpointing the zeroth rank for
+    # pipeline parallelism, checkpointing the zeroth reduced data parallel
+    # rank for tensor parallelism, and preventing checkpointing if optimizer
+    # state sharding is enabled
+    if not args.gather_if_shard or (smp.rdp_rank() == 0 and partial) or smp.rank() == 0:
+        smp.save(save_dict, output_save_file, partial=partial, v3=not args.gather_if_shard)
 
     print(f"Finished checkpointing after {total_steps} steps: {output_save_file}")
 
@@ -298,7 +305,16 @@ def load_model_and_optimizer(
     # need to pass prefix without ranks to smp
     local_ckpt_path = local_ckpt_path.split(".pt")[0] + ".pt"
 
-    checkpoint = smp.load(local_ckpt_path, partial=partial)
+    if args.gather_if_shard > 0:
+        # Should expect v2 checkpoint here
+        checkpoint = smp.load(local_ckpt_path, partial=partial)
+    else:
+        # Loading separately for model and opt
+        checkpoint = torch.load(f"{local_ckpt_path}_{smp.pp_rank()}_{smp.tp_rank()}_0")
+        if smp.rdp_rank() != 0:
+            opt_checkpoint = torch.load(
+                f"{local_ckpt_path}_{smp.pp_rank()}_{smp.tp_rank()}_{smp.rdp_rank()}"
+            )
 
     if load_model:
         checkpointed_model = (
@@ -353,28 +369,30 @@ def opt_load_hook(mod, opt):
     )
 
 
-def delete_oldest_ckpt(args):
-    re_pattern = "trained_gpt_nparams-(?P<num_params>\d+)_steps-(?P<total_steps>\d+)\.pt"
+def delete_oldest_ckpt(args, delete_on_rank0_only=False):
+    to_delete = smp.rank() == 0 if delete_on_rank0_only else smp.local_rank() == 0
+    if to_delete:
+        re_pattern = "trained_gpt_nparams-(?P<num_params>\d+)_steps-(?P<total_steps>\d+)\.pt"
 
-    # partial
-    re_pattern += "_(?P<pp_rank>\d+)_(?P<tp_rank>\d+)"
+        # partial
+        re_pattern += "_(?P<pp_rank>\d+)_(?P<tp_rank>\d+)"
 
-    paths_per_step = collections.defaultdict(list)
+        paths_per_step = collections.defaultdict(list)
 
-    for p in os.listdir(args.checkpoint_dir):
-        if re.match(re_pattern, p):
-            step = int(re.match(re_pattern, p).group("total_steps"))
-            path = os.path.join(args.checkpoint_dir, p)
-            paths_per_step[step].append(path)
+        for p in os.listdir(args.checkpoint_dir):
+            if re.match(re_pattern, p):
+                step = int(re.match(re_pattern, p).group("total_steps"))
+                path = os.path.join(args.checkpoint_dir, p)
+                paths_per_step[step].append(path)
 
-    if paths_per_step:
-        oldest_step = sorted(paths_per_step.keys())[0]
-        num_parts = len(paths_per_step[oldest_step])
+        if paths_per_step:
+            oldest_step = sorted(paths_per_step.keys())[0]
+            num_parts = len(paths_per_step[oldest_step])
         if len(paths_per_step) > args.num_kept_checkpoints:
             # delete oldest step
-            for p in paths_per_step[oldest_step]:
-                os.remove(p)
-    # else We still haven't reached maximum number of checkpoints -- no need to delete older ones
+                for p in paths_per_step[oldest_step]:
+                    os.remove(p)
+        # else We still haven't reached maximum number of checkpoints -- no need to delete older ones
     return None
 
 
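The new `gather_if_shard` flag changes which ranks write checkpoint files in `save()` above. The snippet below restates that selection rule as a pure function (a hypothetical helper, not part of the script) with the `smp` rank queries passed in as plain integers, so the rule can be exercised without `smdistributed`.

```python
def should_write_checkpoint(gather_if_shard: int, rdp_rank: int, rank: int, partial: bool) -> bool:
    """Mirror of: if not args.gather_if_shard or (smp.rdp_rank() == 0 and partial) or smp.rank() == 0."""
    if not gather_if_shard:
        # optimizer shards are not gathered, so every rank saves its own (v3) file
        return True
    if partial and rdp_rank == 0:
        # gathered partial checkpoint: only reduced-data-parallel rank 0 writes
        return True
    # full checkpoint: only global rank 0 writes
    return rank == 0

assert should_write_checkpoint(gather_if_shard=0, rdp_rank=3, rank=12, partial=True)
assert not should_write_checkpoint(gather_if_shard=1, rdp_rank=3, rank=12, partial=True)
assert should_write_checkpoint(gather_if_shard=1, rdp_rank=0, rank=4, partial=True)
```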
@@ -551,6 +569,10 @@ def should_record():
                    print(
                        f"Resuming from saved batch index {start_batch_index}, skipping batch {batch_idx}..."
                    )
+                if start_batch_index == len(train_dataloader):
+                    # If saving at the last batch of the file, read from the next file
+                    start_batch_index = 0
+                    break
                continue
            else:
                start_batch_index = 0
@@ -582,7 +604,7 @@
            # Return value, loss_mb is a StepOutput object
            loss_mb = train_step(model, optimizer, input_ids, attention_mask, args)
 
-            # smdistributed: Average the loss across microbatches.
+            # SMP modification: Average the loss across microbatches.
            loss = loss_mb.reduce_mean()
            if not args.validation_freq:
                loss_metric = loss.item()
@@ -594,9 +616,13 @@
                overflow = optimizer.overflow
            else:
                optimizer.step()
+
            if not (args.fp16 and overflow):
                lr_scheduler.step()
 
+            if args.enable_memory_profiling > 0:
+                memory_status(msg="After_opt_step")
+
            total_steps += 1
            time_elapsed = time.time() - start
            step_time = time.time() - step_start
@@ -634,6 +660,8 @@
                out_path = os.path.join(args.checkpoint_dir, base_path)
                total_ckpts = total_steps // args.checkpoint_freq
 
+                delete_oldest_ckpt(args, delete_on_rank0_only=args.use_fsx > 0)
+
                # save_or_verify_ckptsum if this is the last checkpoint
                if (args.save_or_verify_ckptsum and total_steps >= args.max_steps) or (
                    (total_ckpts + 1) * args.checkpoint_freq
@@ -660,9 +688,6 @@
                    batch_idx=batch_idx + 1,
                )
 
-            if smp.local_rank() == 0:
-                delete_oldest_ckpt(args)
-
            if args.logits_output:
                to_save["loss"].append(loss.item())
 
@@ -681,6 +706,7 @@
                train_dataloader = dataset_future.result(timeout=None)
                wait_time = time.time() - s
                if wait_time > 1:
+                    # TODO if this happens, we should try num_workers>1 in dataloader
                    print(
                        f"[{smp.rank()}] Waited {wait_time} for data loader to be ready. Please check if dataloader performance can be improved to avoid these waits."
                    )
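`delete_oldest_ckpt` is now invoked just before each save, and only on rank 0 when FSx is used. The following is a self-contained sketch of the same rotation idea -- group partial checkpoint part-files by training step (parsed from the file name) and remove the oldest group once more than `num_kept` steps exist. The real function additionally gates on `smp.rank()`/`smp.local_rank()` as shown above; `rotate_checkpoints` is a made-up name for illustration.

```python
import collections
import os
import re

def rotate_checkpoints(checkpoint_dir: str, num_kept: int) -> None:
    """Drop the oldest checkpoint step once more than `num_kept` steps are on disk."""
    pattern = re.compile(
        r"trained_gpt_nparams-(?P<num_params>\d+)_steps-(?P<total_steps>\d+)\.pt"
        r"_(?P<pp_rank>\d+)_(?P<tp_rank>\d+)"
    )
    paths_per_step = collections.defaultdict(list)
    for name in os.listdir(checkpoint_dir):
        match = pattern.match(name)
        if match:
            step = int(match.group("total_steps"))
            paths_per_step[step].append(os.path.join(checkpoint_dir, name))
    if len(paths_per_step) > num_kept:
        oldest_step = min(paths_per_step)
        for path in paths_per_step[oldest_step]:
            os.remove(path)
```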
@@ -794,6 +820,7 @@
         title="model", description="arguments to describe model configuration"
     )
     model_grp.add_argument("--max_context_width", type=int, default=1024)
+    model_grp.add_argument("--vocab_size", type=int, default=50257)
     model_grp.add_argument("--hidden_width", type=int, default=768)
     model_grp.add_argument("--num_layers", type=int, default=12)
     model_grp.add_argument("--num_heads", type=int, default=12)
@@ -816,7 +843,13 @@
     smp_grp.add_argument("--static_mode", type=int, default=0)
     smp_grp.add_argument("--delayed_param", type=int, default=0)
     smp_grp.add_argument("--same_partition_load", type=int, default=0)
-    smp_grp.add_argument("--attention_in_fp32", type=int, default=1)
+    smp_grp.add_argument("--attention_in_fp32", type=int, default=0)
+    smp_grp.add_argument("--placement_strategy", type=str, default="cluster")
+    smp_grp.add_argument("--activation_loading_horizon", type=int, default=4)
+    smp_grp.add_argument("--skip_tracing", type=int, default=0)
+    smp_grp.add_argument("--query_key_layer_scaling", type=int, default=1)
+    smp_grp.add_argument("--fused_softmax", type=int, default=1)
+    smp_grp.add_argument("--fused_bias_gelu", type=int, default=1)
 
     parser.add_argument(
         "--num_kept_checkpoints",
@@ -863,6 +896,22 @@
         default=1,
         help="Running validation only with the last data file for faster speed",
     )
+    parser.add_argument(
+        "--gather_if_shard",
+        type=int,
+        default=1,
+        help="When sharding opt states is enabled, gather the opt checkpoint to rdp rank 0 during saving",
+    )
+    parser.add_argument(
+        "--clean_cache",
+        type=int,
+        default=0,
+        help="Clean torch reserved memory at the end of every step",
+    )
+    parser.add_argument("--use_fsx", type=int, default=0, help="Using FSx for checkpointing")
+    parser.add_argument(
+        "--enable_memory_profiling", type=int, default=0, help="Enable memory profile"
+    )
 
     # learning rate
     lr_grp = parser.add_argument_group(
@@ -937,6 +986,9 @@
         "fp16_params": args.fp16 > 0,
         "offload_activations": args.offload_activations > 0,
         "optimize": args.optimize,
+        "placement_strategy": args.placement_strategy,
+        "activation_loading_horizon": args.activation_loading_horizon,
+        "skip_tracing": args.skip_tracing > 0,
         "auto_partition": False if args.manual_partition else True,
         "default_partition": 0,
         "_fp32_grad_accumulation": args.fp32_grad_accumulation > 0,
@@ -965,9 +1017,8 @@
             os.makedirs(path, exist_ok=True)
 
     model_config = GPT2Config(
-        vocab_size=50257,
+        vocab_size=args.vocab_size,
         n_positions=args.max_context_width,
-        n_ctx=args.max_context_width,
         n_embd=args.hidden_width,
         n_layer=args.num_layers,
         n_head=args.num_heads,
@@ -984,7 +1035,7 @@
         summary_proj_to_labels=True,
         summary_first_dropout=args.summary_first_pdrop,
         # gradient_checkpointing=args.gradient_checkpointing > 0,
-        use_cache=True,
+        use_cache=False,
         bos_token_id=50256,
         eos_token_id=50256,
         return_dict=True,
@@ -1003,12 +1054,17 @@
     if args.fp16:
         torch.set_default_dtype(torch.float16)
     with smp.tensor_parallelism(
-        enabled=smp.tp_size() > 1, attention_in_fp32=args.attention_in_fp32 > 0
+        enabled=smp.tp_size() > 1,
+        attention_in_fp32=args.attention_in_fp32 > 0,
+        query_key_layer_scaling=args.query_key_layer_scaling > 0,
+        fused_softmax=args.fused_softmax > 0,
+        fused_bias_gelu=args.fused_bias_gelu > 0,
     ):
         with smp.delay_param_initialization(
             enabled=(smp.tp_size() > 1 and args.match_weights < 1 and args.delayed_param > 0)
         ):
             model = AutoModelForCausalLM.from_config(model_config)
+
     torch.set_default_dtype(torch.float32)
 
     if args.fp16:
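For reference, the `GPT2Config` changes above amount to the construction sketched below: `vocab_size` now comes from the new `--vocab_size` argument, the deprecated `n_ctx` field is dropped, and `use_cache` is turned off because cached key/values are not useful during training. The values shown are the script's argument defaults; every other field is left at the `transformers` default for brevity.

```python
from transformers import GPT2Config

model_config = GPT2Config(
    vocab_size=50257,   # args.vocab_size (new CLI flag, default 50257)
    n_positions=1024,   # args.max_context_width
    n_embd=768,         # args.hidden_width
    n_layer=12,         # args.num_layers
    n_head=12,          # args.num_heads
    use_cache=False,    # KV cache is unnecessary for training and costs memory
)
```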
@@ -1018,7 +1074,7 @@
     if smp.rank() == 0:
         print(f"# total parameters: {num_params}")
 
-    # smdistributed: Set the device to the GPU ID used by the current process.
+    # SMP modification: Set the device to the GPU ID used by the current process.
     # Input tensors should be transferred to this device.
     torch.cuda.set_device(smp.local_rank())
     device = torch.device("cuda")
@@ -1027,7 +1083,7 @@
     # Set seed by tp_rank to prevent weights from being the same on different tp_ranks
     set_seed(args.seed + smp.tp_rank())
 
-    # smdistributed: Use the DistributedModel container to provide the model
+    # SMP modification: Use the DistributedModel container to provide the model
     # to be partitioned across different ranks. For the rest of the script,
     # the returned DistributedModel object should be used in place of
     # the model provided for DistributedModel class instantiation.
@@ -1081,7 +1137,10 @@
         if isinstance(transformer_layers, nn.Sequential):
             kwargs["pack_args_as_tuple"] = True
             kwargs["strategy"] = args.activation_strategy
-        smp.set_activation_checkpointing(transformer_layers, **kwargs)
+            smp.set_activation_checkpointing(transformer_layers, **kwargs)
+        else:
+            for c in transformer_layers.children():
+                smp.set_activation_checkpointing(c)
 
     if args.fp16:
         optimizer = FP16_Optimizer(
@@ -1101,6 +1160,11 @@
     if args.fp16:
         model.register_post_step_hook(lambda model, optimizer: optimizer.init_master_params())
 
+    if args.enable_memory_profiling > 0:
+        model.register_post_partition_hook(
+            lambda model, optimizer: memory_status(msg="After_partition")
+        )
+
     # load after wrapping model and optimizer with smp Distributed...
     if args.load_full or args.load_partial:
         if args.load_partial and args.load_full:
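Taken together, the new command-line flags introduced across this diff split into two groups: ones that feed the SMP configuration dict and ones consumed directly by the checkpointing and profiling code paths. The self-contained sketch below parses just the new flags with the defaults defined above and assembles the corresponding config additions; it is an illustration, not part of the training script.

```python
import argparse

parser = argparse.ArgumentParser()
# flags introduced in this diff, with the defaults defined above
parser.add_argument("--placement_strategy", type=str, default="cluster")
parser.add_argument("--activation_loading_horizon", type=int, default=4)
parser.add_argument("--skip_tracing", type=int, default=0)
parser.add_argument("--gather_if_shard", type=int, default=1)
parser.add_argument("--use_fsx", type=int, default=0)
parser.add_argument("--enable_memory_profiling", type=int, default=0)
args = parser.parse_args([])  # empty list -> use the defaults

# entries these flags contribute to the smp.init() config dict
smp_config_additions = {
    "placement_strategy": args.placement_strategy,
    "activation_loading_horizon": args.activation_loading_horizon,
    "skip_tracing": args.skip_tracing > 0,
}
print(smp_config_additions)
# gather_if_shard, use_fsx, and enable_memory_profiling are consumed by the
# checkpointing and memory-profiling code paths rather than by the SMP config itself.
```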