Update SMDDP PT RNN-T example notebook (aws#3422)
* update rnnt notebook

* update docker tag

* formatting

* grammar

* grammar

* update pt version in markdown

Co-authored-by: Carolyn Wang <[email protected]>
Co-authored-by: atqy <[email protected]>
3 people authored Jun 7, 2022
1 parent 4e181db commit 2199c4b
Showing 2 changed files with 76 additions and 53 deletions.
@@ -2,34 +2,34 @@
"cells": [
{
"cell_type": "markdown",
"id": "55659189",
"id": "da30579c",
"metadata": {},
"source": [
"# Distributed data parallel RNN-T training with PyTorch and SageMaker distributed\n",
"\n",
"[Amazon SageMaker's distributed library](https://docs.aws.amazon.com/sagemaker/latest/dg/distributed-training.html) can be used to train deep learning models faster and cheaper. The [data parallel](https://docs.aws.amazon.com/sagemaker/latest/dg/data-parallel.html) feature in this library (`smdistributed.dataparallel`) is a distributed data parallel training framework for PyTorch, TensorFlow, and MXNet.\n",
"\n",
"This notebook demonstrates how to use `smdistributed.dataparallel` with PyTorch(version 1.8.1) on [Amazon SageMaker](https://aws.amazon.com/sagemaker/) to train an RNN-T model on [LibriSpeech](http://www.openslr.org/12) (License: [CC BY 4.0](https://creativecommons.org/licenses/by/4.0/)) using [Amazon FSx for Lustre file-system](https://aws.amazon.com/fsx/lustre/) as data source.\n",
"This notebook demonstrates how to use `smdistributed.dataparallel` with PyTorch(version 1.10.2) on [Amazon SageMaker](https://aws.amazon.com/sagemaker/) to train an RNN-T model on [`LibriSpeech`](http://www.openslr.org/12) (License: [CC BY 4.0](https://creativecommons.org/licenses/by/4.0/)) using [Amazon FSx for Lustre file-system](https://aws.amazon.com/fsx/lustre/) as data source.\n",
"\n",
"The outline of steps is as follows:\n",
"\n",
"1. Stage the LibriSpeech dataset in [Amazon S3](https://aws.amazon.com/s3/)\n",
"1. Stage the `LibriSpeech` dataset on [Amazon S3](https://aws.amazon.com/s3/)\n",
"2. Create Amazon FSx Lustre file-system and import data into the file-system from S3\n",
"3. Build Docker training image and push it to [Amazon ECR](https://aws.amazon.com/ecr/)\n",
"4. Configure data input channels for SageMaker\n",
"5. Configure hyper-prarameters\n",
"6. Define training metrics\n",
"7. Define training job, set distribution strategy to SMDataParallel and start training\n",
"7. Define training job, set distribution strategy to `SMDataParallel` and start training\n",
"\n",
"**NOTE:** With large training dataset, we recommend using [Amazon FSx](https://aws.amazon.com/fsx/) as the input file system for the SageMaker training job. FSx file input to SageMaker significantly cuts down training start up time on SageMaker because it avoids downloading the training data each time you start the training job (as done with S3 input for SageMaker training job) and provides good data read throughput.\n",
"\n",
"\n",
"**NOTE:** This example requires SageMaker Python SDK v2.X."
"**NOTE:** This example requires `SageMaker Python SDK v2.X`."
]
},
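For readers new to the library, here is a minimal sketch (not part of this commit) of how a PyTorch 1.10-era training script typically activates `smdistributed.dataparallel`: importing the `torch_smddp` module registers an `"smddp"` backend for `torch.distributed`, after which the model is wrapped in standard PyTorch DDP. The `LOCAL_RANK` lookup and the `torch.nn.Linear` stand-in for the RNN-T network are illustrative assumptions.

```python
import os

import torch
import torch.distributed as dist
import smdistributed.dataparallel.torch.torch_smddp  # noqa: F401 -- registers the "smddp" backend

# Initialize the process group with the SMDDP collective backend.
dist.init_process_group(backend="smddp")

# LOCAL_RANK is assumed to be set by the launcher inside the training container.
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# Stand-in for the real RNN-T network used by this notebook.
model = torch.nn.Linear(1024, 1024).cuda(local_rank)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
```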
{
"cell_type": "markdown",
"id": "1901d71a",
"id": "62f99a96",
"metadata": {},
"source": [
"## Amazon SageMaker Initialization\n",
@@ -40,15 +40,15 @@
"\n",
"The following code cell defines `role` which is the IAM role ARN used to create and run SageMaker training and hosting jobs. This is the same IAM role used to create this SageMaker Notebook instance. \n",
"\n",
"`role` must have permission to create a SageMaker training job and host a model. For granular policies you can use to grant these permissions, see [Amazon SageMaker Roles](https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html). If you do not require fine-tuned permissions for this demo, you can use the IAM managed policy AmazonSageMakerFullAccess to complete this demo. \n",
"`role` must have permission to create a SageMaker training job and host a model. For granular policies you can use to grant these permissions, see [Amazon SageMaker Roles](https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html). If you do not require fine-tuned permissions for this demo, you can use the IAM managed policy `AmazonSageMakerFullAccess` to complete this demo. \n",
"\n",
"As described above, since we will be using FSx, please make sure to attach `FSx Access` permission to this IAM role."
]
},
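The code cell below is collapsed in this view; it presumably performs the standard SageMaker Python SDK initialization. A hedged sketch of what such a cell usually contains (`role`, `sagemaker_session`, and `region` are the names referenced later in the notebook):

```python
import sagemaker
from sagemaker import get_execution_role

sagemaker_session = sagemaker.Session()
role = get_execution_role()  # IAM role ARN of this notebook instance
region = sagemaker_session.boto_region_name  # used for ECR login and image URIs
```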
{
"cell_type": "code",
"execution_count": null,
"id": "a6fb79f1",
"id": "458ee12e",
"metadata": {},
"outputs": [],
"source": [
@@ -80,7 +80,7 @@
},
{
"cell_type": "markdown",
"id": "099b092e",
"id": "df1ac155",
"metadata": {},
"source": [
"To verify that the role above has required permissions:\n",
@@ -94,60 +94,74 @@
},
{
"cell_type": "markdown",
"id": "ede42360",
"id": "40310447",
"metadata": {},
"source": [
"## Prepare SageMaker Training Images\n",
"\n",
"1. SageMaker by default uses the latest [Amazon Deep Learning Container Images (DLC)](https://github.com/aws/deep-learning-containers/blob/master/available_images.md) PyTorch training image. In this step, we use it as a base image and install additional dependencies required for training the RNN-T model.\n",
"2. In the Github repository https://github.com/HerringForks/SMDDP-Examples/tree/main/pytorch/rnnt we have forked an RNN-T example from [mlcommons/\n",
"training_results_v1.0](https://github.com/mlcommons/training_results_v1.0/tree/master/NVIDIA/benchmarks/rnnt/implementations/pytorch) and adapted the training script to work with `smdistributed.dataparallel`. We will use the `Dockerfile` provided there."
"2. In the GitHub repository https://github.com/HerringForks/SMDDP-Examples/tree/main/pytorch/rnnt we have forked an RNN-T example from [ml commons/training_results_v1.0](https://github.com/mlcommons/training_results_v1.0/tree/master/NVIDIA/benchmarks/rnnt/implementations/pytorch) and adapted the training script to work with `smdistributed.dataparallel`. We will use the `Dockerfile` provided there."
]
},
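The `Dockerfile` in that repository builds on an AWS DLC base image. As a hedged aside (the exact base tag pinned in the repo's `Dockerfile` may differ), the matching PyTorch 1.10.2 training DLC URI can be resolved with the SDK:

```python
import sagemaker

# Resolve the PyTorch 1.10.2 training DLC for this region; the tag actually
# pinned in the repo's Dockerfile may differ.
base_image = sagemaker.image_uris.retrieve(
    framework="pytorch",
    region=region,
    version="1.10.2",
    py_version="py38",
    image_scope="training",
    instance_type="ml.p4d.24xlarge",
)
print(base_image)
```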
{
"cell_type": "code",
"execution_count": null,
"id": "4bfbccfc",
"id": "c2470ba3",
"metadata": {},
"outputs": [],
"source": [
"# clone the repo and build the docker image\n",
"! pwd && rm -rf SMDDP-Examples && \\\n",
" aws ecr get-login-password --region {region} | docker login \\\n",
" --username AWS --password-stdin 763104351884.dkr.ecr.{region}.amazonaws.com && \\\n",
" git clone https://github.com/HerringForks/SMDDP-Examples.git && \\\n",
" cd SMDDP-Examples/pytorch/rnnt && \\\n",
" bash scripts/docker/build.sh\n",
" "
"# login to ecr\n",
"! aws ecr get-login-password --region {region} | docker login --username AWS --password-stdin 763104351884.dkr.ecr.{region}.amazonaws.com"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dda17648",
"id": "47b97f90",
"metadata": {},
"outputs": [],
"source": [
"# clone the SMDDP-example repo\n",
"! rm -rf SMDDP-Examples && git clone https://github.com/HerringForks/SMDDP-Examples && cd SMDDP-Examples/pytorch/rnnt"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3eb2d3e5",
"metadata": {},
"outputs": [],
"source": [
"# build the image\n",
"! cd SMDDP-Examples/pytorch/rnnt && bash scripts/docker/build.sh {region}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ed65e9b7",
"metadata": {},
"outputs": [],
"source": [
"# name the image\n",
"image = \"zhaoqi-dev\" # Example: mask-rcnn-smdataparallel-sagemaker\n",
"tag = \"rnnt_dlc_pt1.8.1_smddp\" # Example: pt1.8"
"image = \"rnnt-smdataparallel-sagemaker\" # Example: rnnt-smdataparallel-sagemaker\n",
"tag = \"pt1.10.2\" # Example: pt1.10.2"
]
},
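For reference, `tag_and_push.sh` (used in the next cell) presumably tags the local image with a fully qualified ECR URI of the following form. A sketch only; the script itself is the source of truth:

```python
import boto3

# The ECR repository lives in the caller's own account.
account = boto3.client("sts").get_caller_identity()["Account"]
ecr_image = f"{account}.dkr.ecr.{region}.amazonaws.com/{image}:{tag}"
print(ecr_image)
```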
{
"cell_type": "code",
"execution_count": null,
"id": "97bd6481",
"id": "ba3f3865",
"metadata": {},
"outputs": [],
"source": [
"# tag the image we just built and push it to ecr\n",
"%%time\n",
"! chmod +x tag_and_push.sh; bash tag_and_push.sh {region} {image} {tag}"
]
},
{
"cell_type": "markdown",
"id": "c04fca1f",
"id": "3c4f0f3d",
"metadata": {},
"source": [
"## Preparing FSx Input for SageMaker\n",
@@ -165,7 +179,7 @@
},
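The collapsed steps above describe creating the FSx for Lustre file system and importing the staged dataset from S3. A hedged boto3 sketch of that setup (the console or CLI works equally well; subnet, security group, and bucket values are placeholders):

```python
import boto3

fsx = boto3.client("fsx")
response = fsx.create_file_system(
    FileSystemType="LUSTRE",
    StorageCapacity=1200,  # GiB; minimum for a SCRATCH_2 deployment
    SubnetIds=["subnet-0123456789abcdef0"],     # placeholder
    SecurityGroupIds=["sg-0123456789abcdef0"],  # placeholder
    LustreConfiguration={
        "DeploymentType": "SCRATCH_2",
        "ImportPath": "s3://<your-bucket>/librispeech/",  # placeholder
    },
)
print(response["FileSystem"]["FileSystemId"])
```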
{
"cell_type": "markdown",
"id": "8873826a",
"id": "76f9e422",
"metadata": {},
"source": [
"## SageMaker PyTorch Estimator function options\n",
@@ -174,15 +188,15 @@
"\n",
"**Instance types**\n",
"\n",
"SMDataParallel supports model training on SageMaker with the following instance types only. For best performance, it is recommended you use an instance type that supports Amazon Elastic Fabric Adapter (ml.p3dn.24xlarge and ml.p4d.24xlarge).\n",
"`SMDataParallel` supports model training on SageMaker with the following instance types only. For best performance, it is recommended you use an instance type that supports Amazon Elastic Fabric Adapter (ml.p3dn.24xlarge and ml.p4d.24xlarge).\n",
"\n",
"1. ml.p3.16xlarge\n",
"1. ml.p3dn.24xlarge [Recommended]\n",
"1. ml.p4d.24xlarge [Recommended]\n",
"\n",
"**Instance count**\n",
"\n",
"To get the best performance and the most out of SMDataParallel, you should use at least 2 instances, but you can also use 1 for testing this example.\n",
"To get the best performance and the most out of `SMDataParallel`, you should use at least 2 instances, but you can also use 1 for testing this example.\n",
"\n",
"**Distribution strategy**\n",
"\n",
@@ -192,7 +206,7 @@
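The collapsed cell that follows presumably pins the instance configuration consumed by the estimator; a minimal sketch under that assumption:

```python
# SMDDP-supported instance types are listed above; values are illustrative.
instance_type = "ml.p4d.24xlarge"
instance_count = 2  # at least 2 for meaningful scaling; 1 works for a smoke test
```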
{
"cell_type": "code",
"execution_count": null,
"id": "e948387d",
"id": "7f29e5f9",
"metadata": {},
"outputs": [],
"source": [
@@ -203,7 +217,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "3dcf432f",
"id": "06c048ef",
"metadata": {},
"outputs": [],
"source": [
@@ -222,7 +236,7 @@
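One of the collapsed cells in this region presumably builds the hyperparameter dictionary handed to the estimator. SageMaker forwards such a dict to the entry script as command-line flags, so the keys must match `train.py`'s argparse options; the names and values below are purely illustrative:

```python
hyperparameters = {
    "epochs": 80,      # illustrative
    "batch_size": 32,  # illustrative, per GPU
    "lr": 4e-3,        # illustrative
}
```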
{
"cell_type": "code",
"execution_count": null,
"id": "c4dcc47e",
"id": "c5cb8d24",
"metadata": {},
"outputs": [],
"source": [
@@ -280,7 +294,7 @@
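Outline step 6 ("Define training metrics") is also collapsed here. SageMaker training metrics are declared as name/regex pairs scraped from the job's logs; a hedged sketch (the real patterns depend on `train.py`'s log format, so these regexes are illustrative):

```python
metric_definitions = [
    {"Name": "train:loss", "Regex": r"train_loss\s*:\s*([0-9.]+)"},  # illustrative regex
    {"Name": "dev:wer", "Regex": r"dev_wer\s*:\s*([0-9.]+)"},        # illustrative regex
]
# Would be passed to the estimator as: PyTorch(..., metric_definitions=metric_definitions)
```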
{
"cell_type": "code",
"execution_count": null,
"id": "d2e5811e",
"id": "eb851f8d",
"metadata": {},
"outputs": [],
"source": [
@@ -291,32 +305,29 @@
" source_dir=\".\",\n",
" instance_count=instance_count,\n",
" instance_type=instance_type,\n",
" framework_version=\"1.8.1\",\n",
" py_version=\"py36\",\n",
" py_version=\"py38\",\n",
" sagemaker_session=sagemaker_session,\n",
" hyperparameters=hyperparameters,\n",
" subnets=subnets,\n",
" security_group_ids=security_group_ids,\n",
" debugger_hook_config=False,\n",
" # Training using SMDataParallel Distributed Training Framework\n",
" # Training using SageMaker distributed dataparallel Distributed Training Framework\n",
" distribution={\"smdistributed\": {\"dataparallel\": {\"enabled\": True}}},\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "712890a9",
"id": "65793fb7",
"metadata": {},
"outputs": [],
"source": [
"# Configure FSx Input for your SageMaker Training job\n",
"\n",
"from sagemaker.inputs import FileSystemInput\n",
"\n",
"file_system_directory_path = (\n",
" \"/<mount_name>\" # NOTE: '/fsx/' will be the root mount path. Example: '/fsx/mask_rcnn/PyTorch'\n",
")\n",
"file_system_directory_path = \"/<mount_name>/<path_to_dataset>\" # NOTE: '/fsx/' will be the root mount path. Example: '/fsx/rnnt/PyTorch'\n",
"file_system_access_mode = \"ro\"\n",
"file_system_type = \"FSxLustre\"\n",
"train_fs = FileSystemInput(\n",
@@ -331,7 +342,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "82c4ef3e",
"id": "aa66fdcb",
"metadata": {},
"outputs": [],
"source": [
@@ -342,7 +353,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "e8fa5b62",
"id": "fca0ffd2",
"metadata": {},
"outputs": [],
"source": []
@@ -13,19 +13,31 @@
import sys
import os

- exe = 'python'
+ exe = "python"

- trainer = '/workspace/rnnt/train.py'
+ trainer = "/workspace/rnnt/train.py"

cmd_list = [exe] + [trainer] + sys.argv[1:]
- cmd = ' '.join(cmd_list)
+ cmd = " ".join(cmd_list)

- cmd += ' '
- cmd += '--dataset_dir ' + os.environ['SM_CHANNEL_TRAIN'] + '/datasets/LibriSpeech/ '
- cmd += '--output_dir ' + os.environ['SM_OUTPUT_DIR'] + ' '
- cmd += '--val_manifests ' + os.environ['SM_CHANNEL_TRAIN'] + '/tokenized/librispeech-dev-clean-wav-tokenized.pkl '
- cmd += '--train_manifests ' + os.environ['SM_CHANNEL_TRAIN'] + '/tokenized/librispeech-train-clean-100-wav-tokenized.pkl ' + os.environ['SM_CHANNEL_TRAIN'] + '/tokenized/librispeech-train-clean-360-wav-tokenized.pkl ' + os.environ['SM_CHANNEL_TRAIN'] + '/tokenized/librispeech-train-other-500-wav-tokenized.pkl '
+ cmd += " "
+ cmd += "--dataset_dir " + os.environ["SM_CHANNEL_TRAIN"] + "/datasets/LibriSpeech/ "
+ cmd += "--output_dir " + os.environ["SM_OUTPUT_DIR"] + " "
+ cmd += (
+     "--val_manifests "
+     + os.environ["SM_CHANNEL_TRAIN"]
+     + "/tokenized/librispeech-dev-clean-wav-tokenized.pkl "
+ )
+ cmd += (
+     "--train_manifests "
+     + os.environ["SM_CHANNEL_TRAIN"]
+     + "/tokenized/librispeech-train-clean-100-wav-tokenized.pkl "
+     + os.environ["SM_CHANNEL_TRAIN"]
+     + "/tokenized/librispeech-train-clean-360-wav-tokenized.pkl "
+     + os.environ["SM_CHANNEL_TRAIN"]
+     + "/tokenized/librispeech-train-other-500-wav-tokenized.pkl "
+ )

- print('Final command is: ', cmd)
+ print("Final command is: ", cmd)

subprocess.run(cmd, shell=True)
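The launcher above works because SageMaker exposes each input channel and output location to the container through environment variables. A small illustration of the contract it relies on (paths are the documented container defaults; the channel is named `train` in this notebook):

```python
import os

# Inside a SageMaker training container these are set by the platform;
# setdefault only fills them in when running this snippet outside SageMaker.
os.environ.setdefault("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train")
os.environ.setdefault("SM_OUTPUT_DIR", "/opt/ml/output")

print(os.environ["SM_CHANNEL_TRAIN"])  # FSx channel mount seen by train.py
print(os.environ["SM_OUTPUT_DIR"])     # where train.py writes its outputs
```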
