SageMaker PyTorch Distributed Model Parallel GPT2 Updates (aws#3311)
* Update changes pertaining to gpt2 example

* Remove transformers dependency, as the official SageMaker DLC is used

* Indentation fix

* Update training/distributed_training/pytorch/model_parallel/gpt2/smp-train-gpt-simple.ipynb

Co-authored-by: Miyoung <[email protected]>

* Update training/distributed_training/pytorch/model_parallel/gpt2/train_gpt_simple.py

Co-authored-by: Miyoung <[email protected]>

* Update training/distributed_training/pytorch/model_parallel/gpt2/train_gpt_simple.py

Co-authored-by: Miyoung <[email protected]>

* Disambiguate SMP parameters from training parameters

* Updated comments for SMP-related modifications

* Remove SM Experiments instance

* Remove unnecessary condition

* PyTorch -> Hugging Face comment

* Update branding

* Comment formatting

Co-authored-by: Suhit Kodgule <[email protected]>
Co-authored-by: Miyoung <[email protected]>
3 people authored May 25, 2022
1 parent 4ccb050 commit ccc102a
Showing 5 changed files with 166 additions and 111 deletions.
7 changes: 3 additions & 4 deletions training/distributed_training/pytorch/model_parallel/gpt2/fp16/fp16.py
100644 → 100755
@@ -119,8 +119,8 @@ def save_fp16_optimizer(args, model, optimizer, partial=True):
     if optimizer.master_params_created:
         register_optimizer_hooks(model)
     if partial:
-        optimizer_state_dict["optimizer_state_dict"] = optimizer.local_state_dict()
-        if args.shard_optimizer_state:
+        optimizer_state_dict["optimizer_state_dict"] = optimizer.local_state_dict(gather_if_shard=args.gather_if_shard > 0)
+        if args.shard_optimizer_state and args.gather_if_shard > 0:
             if smp.rdp_rank() == 0:
                 print("With shard_optimizer_state=True, gather full fp32_from_fp16_groups for the rdp_group on rdp rank 0")
                 gathered_cpu_fp32_from_fp16_groups = [cpu_fp32_from_fp16_groups]
@@ -164,8 +164,7 @@ def load_fp16_optimizer(args, model, optimizer, state_dict, partial=True):
     def hook_fn(model, optimizer):
         optimizer.load_state_dict(opt_state_dict["optimizer_state_dict"])
         if partial:
-            if args.shard_optimizer_state:
-                assert isinstance(opt_state_dict["fp32_from_fp16"], list), "Loading with shard_optimizer_state=True must use the checkpoint that was trained with shard_optimizer_state=True!"
+            if args.shard_optimizer_state and args.gather_if_shard > 0:
                 optimizer.fp32_from_fp16 = opt_state_dict["fp32_from_fp16"][smp.rdp_rank()]
             else:
                 optimizer.fp32_from_fp16 = opt_state_dict["fp32_from_fp16"]
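The two hunks above change how partial (sharded) optimizer checkpoints are handled when `shard_optimizer_state` is enabled: with `gather_if_shard > 0`, the fp32 master weights are gathered across the reduced-data-parallel (rdp) group and written once by rdp rank 0 as a list indexed by rdp rank, so each rank restores only its own slice at load time; with `gather_if_shard == 0`, every rank saves and reloads just its own shard. A minimal sketch of the load-side branch, assuming the SageMaker model parallelism library is available at runtime and that `args`, `optimizer`, and `opt_state_dict` come from the surrounding training script:

```python
import smdistributed.modelparallel.torch as smp


def restore_fp32_master_weights(args, optimizer, opt_state_dict):
    # Sketch only: mirrors the branch introduced in the diff above.
    if args.shard_optimizer_state and args.gather_if_shard > 0:
        # The checkpoint was written by rdp rank 0 and holds one entry per
        # rdp rank; every rank restores just its own slice.
        optimizer.fp32_from_fp16 = opt_state_dict["fp32_from_fp16"][smp.rdp_rank()]
    else:
        # Optimizer state is either unsharded, or each rank wrote its own
        # shard (gather_if_shard == 0) and loads it back directly.
        optimizer.fp32_from_fp16 = opt_state_dict["fp32_from_fp16"]
```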
@@ -0,0 +1,37 @@
import smdistributed.modelparallel.torch as smp
import torch

def memory_status(msg="", reset_max=True, sync=True):

    rank = smp.rank()
    tp_rank = smp.tp_rank()
    pp_rank = smp.pp_rank()
    rdp_rank = smp.rdp_rank()
    local_rank = smp.local_rank()

    if sync:
        torch.cuda.synchronize()

    if rdp_rank != 0:
        return

    alloced = torch.cuda.memory_allocated(device=local_rank)
    max_alloced = torch.cuda.max_memory_allocated(device=local_rank)
    cached = torch.cuda.memory_reserved(device=local_rank)
    max_cached = torch.cuda.max_memory_reserved(device=local_rank)

    # convert to GB for printing
    alloced /= 1024**3
    cached /= 1024**3
    max_alloced /= 1024**3
    max_cached /= 1024**3

    print(
        f'[{msg}] rank {rank} tp_rank {tp_rank} pp_rank {pp_rank} TORCH {torch.__version__}',
        f'device={local_rank} '
        f'alloc {alloced:0.4f} max_alloced {max_alloced:0.4f} '
        f'cache {cached:0.4f} max_cached {max_cached:0.4f}'
    )
    if reset_max:
        torch.cuda.reset_max_memory_cached()
        torch.cuda.reset_max_memory_allocated()
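The hunk above adds a small GPU-memory reporting helper built on `torch.cuda` statistics; only rdp rank 0 prints, and the max counters are reset after each report. A hedged usage sketch (the module name `memory_tracker`, the data loader, and the `train_step` helper below are assumptions for illustration, not part of the diff):

```python
# Hypothetical usage; assumes the new file is importable as `memory_tracker`
# and that an SMP-initialized training loop with a data loader already exists.
from memory_tracker import memory_status

for step, batch in enumerate(train_dataloader):   # train_dataloader: assumed
    memory_status(msg=f"before step {step}")
    loss = train_step(model, optimizer, batch)    # train_step: assumed helper
    memory_status(msg=f"after step {step}")       # also resets the max counters
```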
@@ -4,6 +4,5 @@ sagemaker
 sagemaker-experiments
 scipy
 torchnet
-transformers==4.4.2
 smdebug
 humanize
@@ -17,7 +17,7 @@
"\n",
"This notebook depends on the following files and folders:\n",
"\n",
"- `train_gpt_simple.py`: This is an entrypoint script that is passed to the Pytorch estimator in the notebook instructions. This script is responsible for end to end training of the GPT-2 model with SMP. The script has additional comments at places where the SMP API is used.\n",
"- `train_gpt_simple.py`: This is an entrypoint script that is passed to the Hugging Face estimator in the notebook instructions. This script is responsible for end to end training of the GPT-2 model with SMP. The script has additional comments at places where the SMP API is used.\n",
"- `fp16`: This folder is used for 16-bit float training, which contains a fp16 optimizer and various fp16 utilities.\n",
"- `data_pipeline.py`: This contains the datapipeline function to prepare the training data.\n",
"- `learining_rate.py`: This contains the functions for learning rate schedule.\n",
@@ -29,7 +29,7 @@
"\n",
"- To learn more about the SageMaker model parallelism library, see [Model Parallel Distributed Training with SageMaker Distributed](https://docs.aws.amazon.com/sagemaker/latest/dg/model-parallel.html).\n",
"\n",
"- To learn more about using the SageMaker Python SDK with Pytorch, see [Using PyTorch with the SageMaker Python SDK](https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html).\n",
"- To learn more about using the SageMaker Python SDK with PyTorch, see [Using PyTorch with the SageMaker Python SDK](https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html).\n",
"\n",
"- To learn more about launching a training job in Amazon SageMaker with your own training image, see [Use Your Own Training Algorithms](https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo.html).\n",
"\n",
@@ -50,7 +50,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Upgrade Sagemaker SDK to the latest version. \n",
"Upgrade SageMaker SDK to the latest version. \n",
"\n",
"**NOTE:** This step might require a kernel restart."
]
@@ -74,12 +74,10 @@
"%%time\n",
"import os\n",
"\n",
"import boto3\n",
"import sagemaker\n",
"from sagemaker import get_execution_role\n",
"from sagemaker.pytorch import PyTorch\n",
"from smexperiments.experiment import Experiment\n",
"from smexperiments.trial import Trial\n",
"import boto3\n",
"from sagemaker.huggingface import HuggingFace\n",
"\n",
"role = (\n",
" get_execution_role()\n",
@@ -307,18 +305,27 @@
" \"num_kept_checkpoints\": 5,\n",
" \"checkpoint_freq\": 200,\n",
" \"logging_freq\": 1,\n",
" \"save_final_full_model\": 0,\n",
" \"manual_partition\": 1,\n",
" \"skip_full_optimizer\": 1,\n",
" \"shard_optimizer_state\": 1,\n",
" \"activation_checkpointing\": 1,\n",
" \"activation_strategy\": \"each\",\n",
" \"optimize\": \"speed\",\n",
" \"use_wiki_data\": 1,\n",
" # below flag loads model and optimizer state from checkpoint_s3_uri\n",
" # 'load_partial': 1,\n",
"}\n",
"\n",
"# Add parameters required by SMP config.\n",
"# Refer https://sagemaker.readthedocs.io/en/stable/api/training/smd_model_parallel_general.html\n",
"# for details.\n",
"hyperparameters.update(\n",
" {\n",
" \"save_final_full_model\": 0,\n",
" \"manual_partition\": 1,\n",
" \"skip_full_optimizer\": 1,\n",
" \"shard_optimizer_state\": 1,\n",
" \"activation_checkpointing\": 1,\n",
" \"activation_strategy\": \"each\",\n",
" \"optimize\": \"speed\",\n",
"\n",
" }\n",
")\n",
"\n",
"if not running_ci:\n",
" # those flags are used when training with the openwebtext dataset\n",
" hyperparameters[\"zipped_data\"] = 0\n",
@@ -411,47 +418,6 @@
" hyperparameters[k] = v"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Set Up SageMaker Studio Experiment\n",
"Create or load [SageMaker Experiment](https://docs.aws.amazon.com/sagemaker/latest/dg/experiments.html) for the example training job. This will create an experiment trial object in SageMaker Studio."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from time import gmtime, strftime\n",
"\n",
"if not running_ci:\n",
" # Specify your experiment name\n",
" experiment_name = \"smp-gpt2\"\n",
" # Specify your trial name\n",
" trial_name = f\"{experiment_name}-trial1\"\n",
"\n",
" all_experiment_names = [exp.experiment_name for exp in Experiment.list()]\n",
" # Load the experiment if it exists, otherwise create\n",
" if experiment_name not in all_experiment_names:\n",
" experiment = Experiment.create(\n",
" experiment_name=experiment_name, sagemaker_boto_client=sm_boto_client\n",
" )\n",
" else:\n",
" experiment = Experiment.load(\n",
" experiment_name=experiment_name, sagemaker_boto_client=sm_boto_client\n",
" )\n",
"\n",
" # Create the trial\n",
" trial = Trial.create(\n",
" trial_name=\"smp-{}-{}\".format(trial_name, strftime(\"%Y-%m-%d-%H-%M-%S\", gmtime())),\n",
" experiment_name=experiment.experiment_name,\n",
" sagemaker_boto_client=sm_boto_client,\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -561,9 +527,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create a SageMaker PyTorch Estimator\n",
"### Create a SageMaker HuggingFace Estimator\n",
"\n",
"The following cell constructs a PyTorch estimator using the parameters defined above. To see how the SageMaker tensor parallelism modules and functions are applied to the script, see the `train_gpt_simple.py` file."
"The following cell constructs a HuggingFace estimator using the parameters defined above. To see how the SageMaker tensor parallelism modules and functions are applied to the script, see the `train_gpt_simple.py` file."
]
},
{
@@ -578,7 +544,7 @@
" kwargs[\"security_group_ids\"] = [fsx_security_group_id]\n",
" kwargs[\"subnets\"] = [fsx_subnet]\n",
"\n",
"smp_estimator = PyTorch(\n",
"smp_estimator = HuggingFace(\n",
" entry_point=\"train_gpt_simple.py\",\n",
" source_dir=os.getcwd(),\n",
" role=role,\n",
@@ -613,8 +579,9 @@
" }\n",
" },\n",
" },\n",
" framework_version=\"1.8.1\",\n",
" py_version=\"py36\",\n",
" pytorch_version=\"1.10\",\n",
" transformers_version=\"4.17\",\n",
" py_version=\"py38\",\n",
" output_path=s3_output_location,\n",
" checkpoint_s3_uri=checkpoint_s3_uri if not use_fsx else None,\n",
" checkpoint_local_path=hyperparameters[\"checkpoint-dir\"] if use_fsx else None,\n",
@@ -643,18 +610,7 @@
 },
 "outputs": [],
 "source": [
-"if not running_ci:\n",
-" smp_estimator.fit(\n",
-" inputs=data_channels,\n",
-" experiment_config={\n",
-" \"ExperimentName\": experiment.experiment_name,\n",
-" \"TrialName\": trial.trial_name,\n",
-" \"TrialComponentDisplayName\": \"Training\",\n",
-" },\n",
-" logs=True,\n",
-" )\n",
-"else:\n",
-" smp_estimator.fit(inputs=data_channels, logs=True)"
+"smp_estimator.fit(inputs=data_channels, logs=True)"
 ]
 },
 {
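To make the switch from the PyTorch estimator to the Hugging Face estimator concrete, here is a hedged sketch of the pattern the notebook now follows. It is not the notebook's exact code: the instance settings, hyperparameter values, SMP parameter values, and `processes_per_host` below are illustrative placeholders.

```python
# Illustrative sketch only; values marked as placeholders are not from the notebook.
import os

import sagemaker
from sagemaker import get_execution_role
from sagemaker.huggingface import HuggingFace

role = get_execution_role()

smp_estimator = HuggingFace(
    entry_point="train_gpt_simple.py",
    source_dir=os.getcwd(),
    role=role,
    instance_type="ml.p4d.24xlarge",     # placeholder
    instance_count=2,                    # placeholder
    transformers_version="4.17",
    pytorch_version="1.10",
    py_version="py38",
    hyperparameters={"max_steps": 100},  # placeholder training-script arguments
    distribution={
        "smdistributed": {
            "modelparallel": {
                "enabled": True,
                # Placeholder SMP config; see the SMP configuration docs
                # referenced in the notebook for the full parameter list.
                "parameters": {
                    "partitions": 1,
                    "tensor_parallel_degree": 4,
                    "ddp": True,
                    "shard_optimizer_state": True,
                    "prescaled_batch": True,
                },
            }
        },
        "mpi": {"enabled": True, "processes_per_host": 8},  # placeholder
    },
)

# smp_estimator.fit(inputs=data_channels, logs=True)  # data_channels is defined earlier in the notebook
```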
