diff --git a/sagemaker-python-sdk/pytorch_mnist/pytorch_mnist.ipynb b/sagemaker-python-sdk/pytorch_mnist/pytorch_mnist.ipynb index 4c78865757..7c91e84339 100644 --- a/sagemaker-python-sdk/pytorch_mnist/pytorch_mnist.ipynb +++ b/sagemaker-python-sdk/pytorch_mnist/pytorch_mnist.ipynb @@ -69,7 +69,7 @@ "sagemaker_session = sagemaker.Session()\n", "\n", "bucket = sagemaker_session.default_bucket()\n", - "prefix = 'sagemaker/DEMO-pytorch-mnist'\n", + "prefix = \"sagemaker/DEMO-pytorch-mnist\"\n", "\n", "role = sagemaker.get_execution_role()" ] @@ -114,11 +114,11 @@ "MNIST.mirrors = [\"https://sagemaker-sample-files.s3.amazonaws.com/datasets/image/MNIST/\"]\n", "\n", "MNIST(\n", - " 'data',\n", + " \"data\",\n", " download=True,\n", " transform=transforms.Compose(\n", " [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]\n", - " )\n", + " ),\n", ")" ] }, @@ -144,8 +144,8 @@ } ], "source": [ - "inputs = sagemaker_session.upload_data(path='data', bucket=bucket, key_prefix=prefix)\n", - "print('input spec (in this case, just an S3 path): {}'.format(inputs))" + "inputs = sagemaker_session.upload_data(path=\"data\", bucket=bucket, key_prefix=prefix)\n", + "print(\"input spec (in this case, just an S3 path): {}\".format(inputs))" ] }, { @@ -202,16 +202,15 @@ "source": [ "from sagemaker.pytorch import PyTorch\n", "\n", - "estimator = PyTorch(entry_point='mnist.py',\n", - " role=role,\n", - " py_version='py3',\n", - " framework_version='1.8.0',\n", - " instance_count=2,\n", - " instance_type='ml.c5.2xlarge',\n", - " hyperparameters={\n", - " 'epochs': 1,\n", - " 'backend': 'gloo'\n", - " })" + "estimator = PyTorch(\n", + " entry_point=\"mnist.py\",\n", + " role=role,\n", + " py_version=\"py38\",\n", + " framework_version=\"1.11.0\",\n", + " instance_count=2,\n", + " instance_type=\"ml.c5.2xlarge\",\n", + " hyperparameters={\"epochs\": 1, \"backend\": \"gloo\"},\n", + ")" ] }, { @@ -532,7 +531,7 @@ } ], "source": [ - "estimator.fit({'training': inputs})" + "estimator.fit({\"training\": inputs})" ] }, { @@ -562,7 +561,7 @@ } ], "source": [ - "predictor = estimator.deploy(initial_instance_count=1, instance_type='ml.m4.xlarge')" + "predictor = estimator.deploy(initial_instance_count=1, instance_type=\"ml.m4.xlarge\")" ] }, { @@ -600,16 +599,16 @@ "metadata": {}, "outputs": [], "source": [ - "import gzip \n", + "import gzip\n", "import numpy as np\n", "import random\n", "import os\n", "\n", - "data_dir = 'data/MNIST/raw'\n", + "data_dir = \"data/MNIST/raw\"\n", "with gzip.open(os.path.join(data_dir, \"t10k-images-idx3-ubyte.gz\"), \"rb\") as f:\n", " images = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1, 28, 28).astype(np.float32)\n", "\n", - "mask = random.sample(range(len(images)), 16) # randomly select some of the test images\n", + "mask = random.sample(range(len(images)), 16) # randomly select some of the test images\n", "mask = np.array(mask, dtype=np.int)\n", "data = images[mask]" ] @@ -710,9 +709,7 @@ "metadata": {}, "outputs": [], "source": [ - "sagemaker_session.delete_endpoint(\n", - " endpoint_name = predictor.endpoint_name\n", - ")" + "sagemaker_session.delete_endpoint(endpoint_name=predictor.endpoint_name)" ] } ],