Skip to content

Commit

Permalink
Merge branch 'churn_new' of github.com:Neo9061/amazon-sagemaker-examp…
Browse files Browse the repository at this point in the history
…les into churn_new
  • Loading branch information
Neo9061 committed Aug 19, 2022
2 parents b3edb20 + 2d1ac70 commit f28eddf
Showing 1 changed file with 20 additions and 23 deletions.
43 changes: 20 additions & 23 deletions sagemaker-python-sdk/pytorch_mnist/pytorch_mnist.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
"sagemaker_session = sagemaker.Session()\n",
"\n",
"bucket = sagemaker_session.default_bucket()\n",
"prefix = 'sagemaker/DEMO-pytorch-mnist'\n",
"prefix = \"sagemaker/DEMO-pytorch-mnist\"\n",
"\n",
"role = sagemaker.get_execution_role()"
]
Expand Down Expand Up @@ -114,11 +114,11 @@
"MNIST.mirrors = [\"https://sagemaker-sample-files.s3.amazonaws.com/datasets/image/MNIST/\"]\n",
"\n",
"MNIST(\n",
" 'data',\n",
" \"data\",\n",
" download=True,\n",
" transform=transforms.Compose(\n",
" [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]\n",
" )\n",
" ),\n",
")"
]
},
Expand All @@ -144,8 +144,8 @@
}
],
"source": [
"inputs = sagemaker_session.upload_data(path='data', bucket=bucket, key_prefix=prefix)\n",
"print('input spec (in this case, just an S3 path): {}'.format(inputs))"
"inputs = sagemaker_session.upload_data(path=\"data\", bucket=bucket, key_prefix=prefix)\n",
"print(\"input spec (in this case, just an S3 path): {}\".format(inputs))"
]
},
{
Expand Down Expand Up @@ -202,16 +202,15 @@
"source": [
"from sagemaker.pytorch import PyTorch\n",
"\n",
"estimator = PyTorch(entry_point='mnist.py',\n",
" role=role,\n",
" py_version='py3',\n",
" framework_version='1.8.0',\n",
" instance_count=2,\n",
" instance_type='ml.c5.2xlarge',\n",
" hyperparameters={\n",
" 'epochs': 1,\n",
" 'backend': 'gloo'\n",
" })"
"estimator = PyTorch(\n",
" entry_point=\"mnist.py\",\n",
" role=role,\n",
" py_version=\"py38\",\n",
" framework_version=\"1.11.0\",\n",
" instance_count=2,\n",
" instance_type=\"ml.c5.2xlarge\",\n",
" hyperparameters={\"epochs\": 1, \"backend\": \"gloo\"},\n",
")"
]
},
{
Expand Down Expand Up @@ -532,7 +531,7 @@
}
],
"source": [
"estimator.fit({'training': inputs})"
"estimator.fit({\"training\": inputs})"
]
},
{
Expand Down Expand Up @@ -562,7 +561,7 @@
}
],
"source": [
"predictor = estimator.deploy(initial_instance_count=1, instance_type='ml.m4.xlarge')"
"predictor = estimator.deploy(initial_instance_count=1, instance_type=\"ml.m4.xlarge\")"
]
},
{
Expand Down Expand Up @@ -600,16 +599,16 @@
"metadata": {},
"outputs": [],
"source": [
"import gzip \n",
"import gzip\n",
"import numpy as np\n",
"import random\n",
"import os\n",
"\n",
"data_dir = 'data/MNIST/raw'\n",
"data_dir = \"data/MNIST/raw\"\n",
"with gzip.open(os.path.join(data_dir, \"t10k-images-idx3-ubyte.gz\"), \"rb\") as f:\n",
" images = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1, 28, 28).astype(np.float32)\n",
"\n",
"mask = random.sample(range(len(images)), 16) # randomly select some of the test images\n",
"mask = random.sample(range(len(images)), 16) # randomly select some of the test images\n",
"mask = np.array(mask, dtype=np.int)\n",
"data = images[mask]"
]
Expand Down Expand Up @@ -710,9 +709,7 @@
"metadata": {},
"outputs": [],
"source": [
"sagemaker_session.delete_endpoint(\n",
" endpoint_name = predictor.endpoint_name\n",
")"
"sagemaker_session.delete_endpoint(endpoint_name=predictor.endpoint_name)"
]
}
],
Expand Down

0 comments on commit f28eddf

Please sign in to comment.