fix smddp mnist #3459

Merged · 8 commits · Jun 22, 2022
```diff
@@ -18,20 +18,22 @@
 import os
 import time
 
-import smdistributed.dataparallel.torch.distributed as dist
 import torch
+import torchvision
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim.lr_scheduler import StepLR
+from torchvision import datasets, transforms
 
 # Network definition
 from model_def import Net
 
 # Import SMDataParallel PyTorch Modules
-from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP
-from torch.optim.lr_scheduler import StepLR
-from torchvision import datasets, transforms
+import smdistributed.dataparallel.torch.torch_smddp
 
 
 # override dependency on mirrors provided by torch vision package
 # from torchvision 0.9.1, 2 candidate mirror website links will be added before "resources" items automatically
```
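The mirror-override comment above refers to pinning the dataset download source: from torchvision 0.9.1 on, two candidate mirror URLs are prepended to the MNIST resource list automatically, so the script points the dataset class at a single mirror instead. A minimal sketch of that pattern, with a placeholder mirror URL rather than the one used in the actual file:

```python
from torchvision import datasets

# Placeholder mirror URL for illustration; the real script pins its own mirror here.
datasets.MNIST.mirrors = [
    "https://example-bucket.s3.amazonaws.com/datasets/image/MNIST/",
]
```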
```diff
@@ -64,9 +66,6 @@ class CUDANotFoundException(Exception):
     pass
 
 
-dist.init_process_group()
-
-
 def train(args, model, device, train_loader, optimizer, epoch):
     model.train()
     for batch_idx, (data, target) in enumerate(train_loader):
```
```diff
@@ -169,10 +168,11 @@ def main():
         help="Path for downloading " "the MNIST dataset",
     )
 
+    dist.init_process_group(backend="smddp")
     args = parser.parse_args()
     args.world_size = dist.get_world_size()
     args.rank = rank = dist.get_rank()
-    args.local_rank = local_rank = dist.get_local_rank()
+    args.local_rank = local_rank = int(os.getenv("LOCAL_RANK", -1))
     args.lr = 1.0
     args.batch_size //= args.world_size // 8
     args.batch_size = max(args.batch_size, 1)
```
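Taken together, the script-side changes above replace the library's own `DistributedDataParallel` wrapper and `init_process_group()` call with native PyTorch DDP running on the `smddp` process-group backend. A condensed sketch of the resulting initialization flow (variable names are illustrative, not copied verbatim from the full script):

```python
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Importing this module registers "smddp" as a torch.distributed backend.
import smdistributed.dataparallel.torch.torch_smddp  # noqa: F401

from model_def import Net  # network definition shipped next to the training script

dist.init_process_group(backend="smddp")       # replaces the old smdistributed init call
local_rank = int(os.getenv("LOCAL_RANK", -1))  # set by the SageMaker launcher
world_size = dist.get_world_size()
rank = dist.get_rank()

torch.cuda.set_device(local_rank)
model = Net().to(local_rank)
model = DDP(model, device_ids=[local_rank])    # plain PyTorch DDP wrapper
```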
```diff
@@ -25,7 +25,7 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"**NOTE:** This example requires SageMaker Python SDK v2.X."
+"**NOTE:** This example requires SageMaker Python SDK v2.**."
 ]
 },
 {
```
```diff
@@ -34,7 +34,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"pip install sagemaker --upgrade"
+"!pip install sagemaker --upgrade"
 ]
 },
 {
```
```diff
@@ -45,7 +45,7 @@
 "\n",
 "The following code cell defines `role` which is the IAM role ARN used to create and run SageMaker training and hosting jobs. This is the same IAM role used to create this SageMaker Notebook instance. \n",
 "\n",
-"`role` must have permission to create a SageMaker training job and launch an endpoint to host a model. For granular policies you can use to grant these permissions, see [Amazon SageMaker Roles](https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html). If you do not require fine-tuned permissions for this demo, you can used the IAM managed policy AmazonSageMakerFullAccess to complete this demo. "
+"`role` must have permission to create a SageMaker training job and launch an endpoint to host a model. For granular policies you can use to grant these permissions, see [Amazon SageMaker Roles](https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html)."
 ]
 },
 {
```
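For reference, `role` and `sagemaker_session` are typically obtained with the SageMaker Python SDK as sketched below (not the notebook's exact cell):

```python
import sagemaker
from sagemaker import get_execution_role

sagemaker_session = sagemaker.Session()  # picks up the notebook instance's region and credentials
role = get_execution_role()              # IAM role ARN attached to this notebook instance
```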
```diff
@@ -86,7 +86,7 @@
 "\n",
 "The MNIST dataset is downloaded using the `torchvision.datasets` PyTorch module; you can see how this is implemented in the `train_pytorch_smdataparallel_mnist.py` training script that is printed out in the next cell.\n",
 "\n",
-"The training script provides the code you need for distributed data parallel (DDP) training using SageMaker's distributed data parallel library (`smdistributed.dataparallel`). The training script is very similar to a PyTorch training script you might run outside of SageMaker, but modified to run with the `smdistributed.dataparallel` library. This library's PyTorch client provides an alternative to PyTorch's native DDP. \n",
+"The training script provides the code you need for distributed data parallel (DDP) training using SageMaker's distributed data parallel library (`smdistributed.dataparallel`). The training script is very similar to a PyTorch training script you might run outside SageMaker, but modified to run with the `smdistributed.dataparallel` library. This library's PyTorch client provides an alternative to PyTorch's native DDP.\n",
 "\n",
 "For details about how to use `smdistributed.dataparallel`'s DDP in your native PyTorch script, see the [Modify a PyTorch Training Script Using SMD Data Parallel](https://docs.aws.amazon.com/sagemaker/latest/dg/data-parallel-modify-sdp.html#data-parallel-modify-sdp-pt)."
 ]
```
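A DDP training script along these lines usually shards MNIST across processes with a `DistributedSampler`, so each rank trains on a distinct subset. A minimal sketch of that piece (batch size, data path, and transform values are illustrative):

```python
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision import datasets, transforms

# Assumes the process group has already been initialized (e.g. with backend="smddp").
train_dataset = datasets.MNIST(
    "/tmp/data",
    train=True,
    download=True,
    transform=transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    ),
)

# Give each rank a unique shard of the dataset.
train_sampler = DistributedSampler(
    train_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank()
)
train_loader = DataLoader(
    train_dataset, batch_size=64, sampler=train_sampler, pin_memory=True
)
```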
```diff
@@ -106,7 +106,7 @@
 "source": [
 "### Estimator function options\n",
 "\n",
-"In the following code block, you can update the estimator function to use a different instance type, instance count, and distrubtion strategy. You're also passing in the training script you reviewed in the previous cell to this estimator.\n",
+"In the following code block, you can update the estimator function to use a different instance type, instance count, and distribution strategy. You're also passing in the training script you reviewed in the previous cell to this estimator.\n",
 "\n",
 "**Instance types**\n",
 "\n",
```
```diff
@@ -122,7 +122,7 @@
 "\n",
 "**Distribution strategy**\n",
 "\n",
-"Note that to use DDP mode, you update the the `distribution` strategy, and set it to use `smdistributed dataparallel`. "
+"Note that to use DDP mode, you update the `distribution` strategy, and set it to use `smdistributed dataparallel`."
 ]
 },
 {
```
```diff
@@ -138,12 +138,12 @@
 " source_dir=\"code\",\n",
 " entry_point=\"train_pytorch_smdataparallel_mnist.py\",\n",
 " role=role,\n",
-" framework_version=\"1.8.1\",\n",
-" py_version=\"py36\",\n",
+" framework_version=\"1.11.0\",\n",
+" py_version=\"py38\",\n",
 " # For training with multinode distributed training, set this count. Example: 2\n",
 " instance_count=1,\n",
 " # For training with p3dn instance use - ml.p3dn.24xlarge, with p4dn instance use - ml.p4d.24xlarge\n",
-" instance_type=\"ml.p3dn.24xlarge\",\n",
+" instance_type=\"ml.p4d.24xlarge\",\n",
 " sagemaker_session=sagemaker_session,\n",
 " # Training using SMDataParallel Distributed Training Framework\n",
 " distribution={\"smdistributed\": {\"dataparallel\": {\"enabled\": True}}},\n",
```
```diff
@@ -202,4 +202,4 @@
 },
 "nbformat": 4,
 "nbformat_minor": 4
-}
+}
```