From bcf330e629d261995378d049a2995dc898dac601 Mon Sep 17 00:00:00 2001 From: Suraj Kota Date: Fri, 19 Aug 2022 01:28:59 +0000 Subject: [PATCH] formatting --- .../pytorch_mnist/pytorch_mnist.ipynb | 53 ++++++++++--------- 1 file changed, 28 insertions(+), 25 deletions(-) diff --git a/sagemaker-python-sdk/pytorch_mnist/pytorch_mnist.ipynb b/sagemaker-python-sdk/pytorch_mnist/pytorch_mnist.ipynb index 8adb1993b1..39cbc5eda9 100644 --- a/sagemaker-python-sdk/pytorch_mnist/pytorch_mnist.ipynb +++ b/sagemaker-python-sdk/pytorch_mnist/pytorch_mnist.ipynb @@ -69,7 +69,7 @@ "sagemaker_session = sagemaker.Session()\n", "\n", "bucket = sagemaker_session.default_bucket()\n", - "prefix = 'sagemaker/DEMO-pytorch-mnist'\n", + "prefix = \"sagemaker/DEMO-pytorch-mnist\"\n", "\n", "role = sagemaker.get_execution_role()" ] @@ -111,14 +111,16 @@ "from torchvision.datasets import MNIST\n", "from torchvision import transforms\n", "\n", - "MNIST.mirrors = [\"https://sagemaker-sample-files.s3.amazonaws.com/datasets/image/MNIST/\"]\n", + "MNIST.mirrors = [\n", + " \"https://sagemaker-sample-files.s3.amazonaws.com/datasets/image/MNIST/\"\n", + "]\n", "\n", "MNIST(\n", - " 'data',\n", + " \"data\",\n", " download=True,\n", " transform=transforms.Compose(\n", " [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]\n", - " )\n", + " ),\n", ")" ] }, @@ -144,8 +146,8 @@ } ], "source": [ - "inputs = sagemaker_session.upload_data(path='data', bucket=bucket, key_prefix=prefix)\n", - "print('input spec (in this case, just an S3 path): {}'.format(inputs))" + "inputs = sagemaker_session.upload_data(path=\"data\", bucket=bucket, key_prefix=prefix)\n", + "print(\"input spec (in this case, just an S3 path): {}\".format(inputs))" ] }, { @@ -202,16 +204,15 @@ "source": [ "from sagemaker.pytorch import PyTorch\n", "\n", - "estimator = PyTorch(entry_point='mnist.py',\n", - " role=role,\n", - " py_version='py38',\n", - " framework_version='1.11.0',\n", - " instance_count=2,\n", - " instance_type='ml.c5.2xlarge',\n", - " hyperparameters={\n", - " 'epochs': 1,\n", - " 'backend': 'gloo'\n", - " })" + "estimator = PyTorch(\n", + " entry_point=\"mnist.py\",\n", + " role=role,\n", + " py_version=\"py38\",\n", + " framework_version=\"1.11.0\",\n", + " instance_count=2,\n", + " instance_type=\"ml.c5.2xlarge\",\n", + " hyperparameters={\"epochs\": 1, \"backend\": \"gloo\"},\n", + ")" ] }, { @@ -532,7 +533,7 @@ } ], "source": [ - "estimator.fit({'training': inputs})" + "estimator.fit({\"training\": inputs})" ] }, { @@ -562,7 +563,7 @@ } ], "source": [ - "predictor = estimator.deploy(initial_instance_count=1, instance_type='ml.m4.xlarge')" + "predictor = estimator.deploy(initial_instance_count=1, instance_type=\"ml.m4.xlarge\")" ] }, { @@ -600,16 +601,20 @@ "metadata": {}, "outputs": [], "source": [ - "import gzip \n", + "import gzip\n", "import numpy as np\n", "import random\n", "import os\n", "\n", - "data_dir = 'data/MNIST/raw'\n", + "data_dir = \"data/MNIST/raw\"\n", "with gzip.open(os.path.join(data_dir, \"t10k-images-idx3-ubyte.gz\"), \"rb\") as f:\n", - " images = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1, 28, 28).astype(np.float32)\n", + " images = (\n", + " np.frombuffer(f.read(), np.uint8, offset=16)\n", + " .reshape(-1, 28, 28)\n", + " .astype(np.float32)\n", + " )\n", "\n", - "mask = random.sample(range(len(images)), 16) # randomly select some of the test images\n", + "mask = random.sample(range(len(images)), 16) # randomly select some of the test images\n", "mask = np.array(mask, dtype=np.int)\n", "data = images[mask]" ] @@ -710,9 +715,7 @@ "metadata": {}, "outputs": [], "source": [ - "sagemaker_session.delete_endpoint(\n", - " endpoint_name = predictor.endpoint_name\n", - ")" + "sagemaker_session.delete_endpoint(endpoint_name=predictor.endpoint_name)" ] } ],