diff --git a/training/distributed_training/pytorch/data_parallel/mnist/code/train_pytorch_smdataparallel_mnist.py b/training/distributed_training/pytorch/data_parallel/mnist/code/train_pytorch_smdataparallel_mnist.py
index b142c298ed..a63679feb4 100644
--- a/training/distributed_training/pytorch/data_parallel/mnist/code/train_pytorch_smdataparallel_mnist.py
+++ b/training/distributed_training/pytorch/data_parallel/mnist/code/train_pytorch_smdataparallel_mnist.py
@@ -18,20 +18,22 @@
 import os
 import time
 
-import smdistributed.dataparallel.torch.distributed as dist
 import torch
 import torchvision
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim.lr_scheduler import StepLR
+from torchvision import datasets, transforms
 
 # Network definition
 from model_def import Net
 
 # Import SMDataParallel PyTorch Modules
-from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP
-from torch.optim.lr_scheduler import StepLR
-from torchvision import datasets, transforms
+import smdistributed.dataparallel.torch.torch_smddp
+
 
 # override dependency on mirrors provided by torch vision package
 # from torchvision 0.9.1, 2 candidate mirror website links will be added before "resources" items automatically
@@ -64,9 +66,6 @@ class CUDANotFoundException(Exception):
     pass
 
 
-dist.init_process_group()
-
-
 def train(args, model, device, train_loader, optimizer, epoch):
     model.train()
     for batch_idx, (data, target) in enumerate(train_loader):
@@ -169,10 +168,11 @@ def main():
         help="Path for downloading " "the MNIST dataset",
     )
 
+    dist.init_process_group(backend="smddp")
     args = parser.parse_args()
     args.world_size = dist.get_world_size()
     args.rank = rank = dist.get_rank()
-    args.local_rank = local_rank = dist.get_local_rank()
+    args.local_rank = local_rank = int(os.getenv("LOCAL_RANK", -1))
     args.lr = 1.0
     args.batch_size //= args.world_size // 8
     args.batch_size = max(args.batch_size, 1)
diff --git a/training/distributed_training/pytorch/data_parallel/mnist/pytorch_smdataparallel_mnist_demo.ipynb b/training/distributed_training/pytorch/data_parallel/mnist/pytorch_smdataparallel_mnist_demo.ipynb
index a93dfbb7aa..9b17bde2f5 100644
--- a/training/distributed_training/pytorch/data_parallel/mnist/pytorch_smdataparallel_mnist_demo.ipynb
+++ b/training/distributed_training/pytorch/data_parallel/mnist/pytorch_smdataparallel_mnist_demo.ipynb
@@ -25,7 +25,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "**NOTE:** This example requires SageMaker Python SDK v2.X."
+    "**NOTE:** This example requires SageMaker Python SDK v2.**."
    ]
   },
   {
@@ -34,7 +34,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "pip install sagemaker --upgrade"
+    "!pip install sagemaker --upgrade"
    ]
   },
   {
@@ -45,7 +45,7 @@
     "\n",
     "The following code cell defines `role` which is the IAM role ARN used to create and run SageMaker training and hosting jobs. This is the same IAM role used to create this SageMaker Notebook instance. \n",
     "\n",
-    "`role` must have permission to create a SageMaker training job and launch an endpoint to host a model. For granular policies you can use to grant these permissions, see [Amazon SageMaker Roles](https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html). If you do not require fine-tuned permissions for this demo, you can used the IAM managed policy AmazonSageMakerFullAccess to complete this demo. "
+    "`role` must have permission to create a SageMaker training job and launch an endpoint to host a model. For granular policies you can use to grant these permissions, see [Amazon SageMaker Roles](https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html)."
    ]
   },
   {
@@ -86,7 +86,7 @@
     "\n",
     "The MNIST dataset is downloaded using the `torchvision.datasets` PyTorch module; you can see how this is implemented in the `train_pytorch_smdataparallel_mnist.py` training script that is printed out in the next cell.\n",
     "\n",
-    "The training script provides the code you need for distributed data parallel (DDP) training using SageMaker's distributed data parallel library (`smdistributed.dataparallel`). The training script is very similar to a PyTorch training script you might run outside of SageMaker, but modified to run with the `smdistributed.dataparallel` library. This library's PyTorch client provides an alternative to PyTorch's native DDP. \n",
+    "The training script provides the code you need for distributed data parallel (DDP) training using SageMaker's distributed data parallel library (`smdistributed.dataparallel`). The training script is very similar to a PyTorch training script you might run outside SageMaker, but modified to run with the `smdistributed.dataparallel` library. This library's PyTorch client provides an alternative to PyTorch's native DDP.\n",
     "\n",
     "For details about how to use `smdistributed.dataparallel`'s DDP in your native PyTorch script, see the [Modify a PyTorch Training Script Using SMD Data Parallel](https://docs.aws.amazon.com/sagemaker/latest/dg/data-parallel-modify-sdp.html#data-parallel-modify-sdp-pt)."
    ]
@@ -106,7 +106,7 @@
    "source": [
     "### Estimator function options\n",
     "\n",
-    "In the following code block, you can update the estimator function to use a different instance type, instance count, and distrubtion strategy. You're also passing in the training script you reviewed in the previous cell to this estimator.\n",
+    "In the following code block, you can update the estimator function to use a different instance type, instance count, and distribution strategy. You're also passing in the training script you reviewed in the previous cell to this estimator.\n",
     "\n",
     "**Instance types**\n",
     "\n",
@@ -122,7 +122,7 @@
     "\n",
     "**Distribution strategy**\n",
     "\n",
-    "Note that to use DDP mode, you update the the `distribution` strategy, and set it to use `smdistributed dataparallel`. "
+    "Note that to use DDP mode, you update the `distribution` strategy, and set it to use `smdistributed dataparallel`."
    ]
   },
   {
@@ -138,12 +138,12 @@
     "    source_dir=\"code\",\n",
     "    entry_point=\"train_pytorch_smdataparallel_mnist.py\",\n",
     "    role=role,\n",
-    "    framework_version=\"1.8.1\",\n",
-    "    py_version=\"py36\",\n",
+    "    framework_version=\"1.11.0\",\n",
+    "    py_version=\"py38\",\n",
     "    # For training with multinode distributed training, set this count. Example: 2\n",
     "    instance_count=1,\n",
     "    # For training with p3dn instance use - ml.p3dn.24xlarge, with p4dn instance use - ml.p4d.24xlarge\n",
-    "    instance_type=\"ml.p3dn.24xlarge\",\n",
+    "    instance_type=\"ml.p4d.24xlarge\",\n",
     "    sagemaker_session=sagemaker_session,\n",
     "    # Training using SMDataParallel Distributed Training Framework\n",
     "    distribution={\"smdistributed\": {\"dataparallel\": {\"enabled\": True}}},\n",
@@ -202,4 +202,4 @@
  },
  "nbformat": 4,
  "nbformat_minor": 4
-}
+}
\ No newline at end of file