Skip to content

Commit

Permalink
Add batch transform job to the notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
Aloha106 committed Sep 11, 2018
1 parent d496ecc commit 326f713
Showing 1 changed file with 150 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,14 @@
"3. [Fine-tuning The Image Classification Model](#Fine-tuning-the-Image-classification-model)\n",
" 1. [Training parameters](#Training-parameters)\n",
" 2. [Training](#Training)\n",
"4. [Set Up Hosting For The Model](#Set-up-hosting-for-the-model)\n",
"4. [Deploy The Model](#Deploy-the-model)\n",
" 1. [Create model](#Create-model)\n",
" 2. [Create endpoint configuration](#Create-endpoint-configuration)\n",
" 3. [Create endpoint](#Create-endpoint)\n",
" 4. [Perform inference](#Perform-inference)"
" 2. [Batch Transform](#Batch-transform)\n",
" 3. [Realtime inference](#Realtime-inference)\n",
" 1. [Create endpoint configuration](#Create-endpoint-configuration) \n",
" 2. [Create endpoint](#Create-endpoint) \n",
" 3. [Perform inference](#Perform-inference) \n",
" 4. [Clean up](#Clean-up)"
]
},
{
Expand Down Expand Up @@ -163,10 +166,10 @@
"outputs": [],
"source": [
"# Four channels: train, validation, train_lst, and validation_lst\n",
"s3train = 's3://{}/train/'.format(bucket)\n",
"s3validation = 's3://{}/validation/'.format(bucket)\n",
"s3train_lst = 's3://{}/train_lst/'.format(bucket)\n",
"s3validation_lst = 's3://{}/validation_lst/'.format(bucket)\n",
"s3train = 's3://{}/image-classification/train/'.format(bucket)\n",
"s3validation = 's3://{}/image-classification/validation/'.format(bucket)\n",
"s3train_lst = 's3://{}/image-classification/train_lst/'.format(bucket)\n",
"s3validation_lst = 's3://{}/image-classification/validation_lst/'.format(bucket)\n",
"\n",
"# upload the image files to train and validation channels\n",
"!aws s3 cp caltech_256_train_60 $s3train --recursive --quiet\n",
Expand Down Expand Up @@ -343,7 +346,7 @@
" \"DataSource\": {\n",
" \"S3DataSource\": {\n",
" \"S3DataType\": \"S3Prefix\",\n",
" \"S3Uri\": 's3://{}/train/'.format(bucket),\n",
" \"S3Uri\": s3train\n",
" \"S3DataDistributionType\": \"FullyReplicated\"\n",
" }\n",
" },\n",
Expand All @@ -355,7 +358,7 @@
" \"DataSource\": {\n",
" \"S3DataSource\": {\n",
" \"S3DataType\": \"S3Prefix\",\n",
" \"S3Uri\": 's3://{}/validation/'.format(bucket),\n",
" \"S3Uri\": s3validation,\n",
" \"S3DataDistributionType\": \"FullyReplicated\"\n",
" }\n",
" },\n",
Expand All @@ -367,7 +370,7 @@
" \"DataSource\": {\n",
" \"S3DataSource\": {\n",
" \"S3DataType\": \"S3Prefix\",\n",
" \"S3Uri\": 's3://{}/train_lst/'.format(bucket),\n",
" \"S3Uri\": s3train_lst,\n",
" \"S3DataDistributionType\": \"FullyReplicated\"\n",
" }\n",
" },\n",
Expand All @@ -379,7 +382,7 @@
" \"DataSource\": {\n",
" \"S3DataSource\": {\n",
" \"S3DataType\": \"S3Prefix\",\n",
" \"S3Uri\": 's3://{}/validation_lst/'.format(bucket),\n",
" \"S3Uri\": s3validation_lst,\n",
" \"S3DataDistributionType\": \"FullyReplicated\"\n",
" }\n",
" },\n",
Expand Down Expand Up @@ -513,7 +516,140 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create endpoint configuration\n",
"### Batch transform\n",
"\n",
"We now create a SageMaker Batch Transform job using the model created above to perform batch prediction."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"timestamp = time.strftime('-%Y-%m-%d-%H-%M-%S', time.gmtime())\n",
"batch_job_name=\"image-classification-model\" + timestamp\n",
"batch_input = s3validation + \"001.ak47/\"\n",
"request = \\\n",
"{\n",
" \"TransformJobName\": batch_job_name,\n",
" \"ModelName\": model_name,\n",
" \"MaxConcurrentTransforms\": 10,\n",
" \"MaxPayloadInMB\": 6,\n",
" \"BatchStrategy\": \"MultiRecord\",\n",
" \"TransformOutput\": {\n",
" \"S3OutputPath\": 's3://{}/{}/output'.format(bucket, batch_job_name)\n",
" },\n",
" \"TransformInput\": {\n",
" \"DataSource\": {\n",
" \"S3DataSource\": {\n",
" \"S3DataType\": \"S3Prefix\",\n",
" \"S3Uri\": batch_input\n",
" }\n",
" },\n",
" \"ContentType\": \"application/x-image\",\n",
" \"SplitType\": \"None\",\n",
" \"CompressionType\": \"None\"\n",
" },\n",
" \"TransformResources\": {\n",
" \"InstanceType\": \"ml.p2.xlarge\",\n",
" \"InstanceCount\": 1\n",
" }\n",
"}\n",
"\n",
"print('Transform job name: {}'.format(batch_job_name))\n",
"print('\\nInput Data Location: {}'.format(batch_input))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sagemaker = boto3.client('sagemaker')\n",
"sagemaker.create_transform_job(**request)\n",
"\n",
"print(\"Created Transform job with name: \", batch_job_name)\n",
"\n",
"while(True):\n",
" response = sagemaker.describe_transform_job(TransformJobName=batch_job_name)\n",
" status = response['TransformJobStatus']\n",
" if status == 'Completed':\n",
" print(\"Transform job ended with status: \" + status)\n",
" break\n",
" if status == 'Failed':\n",
" message = response['FailureReason']\n",
" print('Transform failed with the following error: {}'.format(message))\n",
" raise Exception('Transform job failed') \n",
" time.sleep(30) "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"After the job completes, let's check the prediction results."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from urllib.parse import urlparse\n",
"import json\n",
"import numpy as np\n",
"\n",
"s3_client = boto3.client('s3')\n",
"object_categories = ['ak47', 'american-flag', 'backpack', 'baseball-bat', 'baseball-glove', 'basketball-hoop', 'bat', 'bathtub', 'bear', 'beer-mug', 'billiards', 'binoculars', 'birdbath', 'blimp', 'bonsai-101', 'boom-box', 'bowling-ball', 'bowling-pin', 'boxing-glove', 'brain-101', 'breadmaker', 'buddha-101', 'bulldozer', 'butterfly', 'cactus', 'cake', 'calculator', 'camel', 'cannon', 'canoe', 'car-tire', 'cartman', 'cd', 'centipede', 'cereal-box', 'chandelier-101', 'chess-board', 'chimp', 'chopsticks', 'cockroach', 'coffee-mug', 'coffin', 'coin', 'comet', 'computer-keyboard', 'computer-monitor', 'computer-mouse', 'conch', 'cormorant', 'covered-wagon', 'cowboy-hat', 'crab-101', 'desk-globe', 'diamond-ring', 'dice', 'dog', 'dolphin-101', 'doorknob', 'drinking-straw', 'duck', 'dumb-bell', 'eiffel-tower', 'electric-guitar-101', 'elephant-101', 'elk', 'ewer-101', 'eyeglasses', 'fern', 'fighter-jet', 'fire-extinguisher', 'fire-hydrant', 'fire-truck', 'fireworks', 'flashlight', 'floppy-disk', 'football-helmet', 'french-horn', 'fried-egg', 'frisbee', 'frog', 'frying-pan', 'galaxy', 'gas-pump', 'giraffe', 'goat', 'golden-gate-bridge', 'goldfish', 'golf-ball', 'goose', 'gorilla', 'grand-piano-101', 'grapes', 'grasshopper', 'guitar-pick', 'hamburger', 'hammock', 'harmonica', 'harp', 'harpsichord', 'hawksbill-101', 'head-phones', 'helicopter-101', 'hibiscus', 'homer-simpson', 'horse', 'horseshoe-crab', 'hot-air-balloon', 'hot-dog', 'hot-tub', 'hourglass', 'house-fly', 'human-skeleton', 'hummingbird', 'ibis-101', 'ice-cream-cone', 'iguana', 'ipod', 'iris', 'jesus-christ', 'joy-stick', 'kangaroo-101', 'kayak', 'ketch-101', 'killer-whale', 'knife', 'ladder', 'laptop-101', 'lathe', 'leopards-101', 'license-plate', 'lightbulb', 'light-house', 'lightning', 'llama-101', 'mailbox', 'mandolin', 'mars', 'mattress', 'megaphone', 'menorah-101', 'microscope', 'microwave', 'minaret', 'minotaur', 'motorbikes-101', 'mountain-bike', 'mushroom', 'mussels', 'necktie', 'octopus', 'ostrich', 'owl', 'palm-pilot', 'palm-tree', 'paperclip', 'paper-shredder', 'pci-card', 'penguin', 'people', 'pez-dispenser', 'photocopier', 'picnic-table', 'playing-card', 'porcupine', 'pram', 'praying-mantis', 'pyramid', 'raccoon', 'radio-telescope', 'rainbow', 'refrigerator', 'revolver-101', 'rifle', 'rotary-phone', 'roulette-wheel', 'saddle', 'saturn', 'school-bus', 'scorpion-101', 'screwdriver', 'segway', 'self-propelled-lawn-mower', 'sextant', 'sheet-music', 'skateboard', 'skunk', 'skyscraper', 'smokestack', 'snail', 'snake', 'sneaker', 'snowmobile', 'soccer-ball', 'socks', 'soda-can', 'spaghetti', 'speed-boat', 'spider', 'spoon', 'stained-glass', 'starfish-101', 'steering-wheel', 'stirrups', 'sunflower-101', 'superman', 'sushi', 'swan', 'swiss-army-knife', 'sword', 'syringe', 'tambourine', 'teapot', 'teddy-bear', 'teepee', 'telephone-box', 'tennis-ball', 'tennis-court', 'tennis-racket', 'theodolite', 'toaster', 'tomato', 'tombstone', 'top-hat', 'touring-bike', 'tower-pisa', 'traffic-light', 'treadmill', 'triceratops', 'tricycle', 'trilobite-101', 'tripod', 't-shirt', 'tuning-fork', 'tweezer', 'umbrella-101', 'unicorn', 'vcr', 'video-projector', 'washing-machine', 'watch-101', 'waterfall', 'watermelon', 'welding-mask', 'wheelbarrow', 'windmill', 'wine-bottle', 'xylophone', 'yarmulke', 'yo-yo', 'zebra', 'airplanes-101', 'car-side-101', 'faces-easy-101', 'greyhound', 'tennis-shoes', 'toad', 'clutter']\n",
"\n",
"def list_objects(s3_client, bucket, prefix):\n",
" response = s3_client.list_objects(Bucket=bucket, Prefix=prefix)\n",
" objects = [content['Key'] for content in response['Contents']]\n",
" return objects\n",
"\n",
"def get_label(s3_client, bucket, prefix):\n",
" filename = prefix.split('/')[-1]\n",
" s3_client.download_file(bucket, prefix, filename)\n",
" with open(filename) as f:\n",
" data = json.load(f)\n",
" index = np.argmax(data['prediction'])\n",
" probability = data['prediction'][index]\n",
" print(\"Result: label - \" + object_categories[index] + \", probability - \" + str(probability))\n",
" return object_categories[index], probability\n",
"\n",
"inpute_images = list_objects(s3_client, bucket, urlparse(batch_input).path.lstrip('/'))\n",
"print(\"Sample inputs: \" + str(input_images[:2]))\n",
"\n",
"outputs = list_objects(s3_client, bucket, batch_job_name + \"/output\")\n",
"print(\"Sample output: \" + str(outputs[:2]))\n",
"\n",
"# Check prediction result of the first 3 images\n",
"[get_label(s3_client, bucket, prefix) for prefix in outputs[0:3]]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Realtime inference\n",
"\n",
"We now host the model with an endpoint and perform realtime inference.\n",
"\n",
"This section involves several steps,\n",
"1. [Create endpoint configuration](#CreateEndpointConfiguration) - Create a configuration defining an endpoint.\n",
"1. [Create endpoint](#CreateEndpoint) - Use the configuration to create an inference endpoint.\n",
"1. [Perform inference](#PerformInference) - Perform inference on some input data using the endpoint.\n",
"1. [Clean up](#CleanUp) - Delete the endpoint and model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Create endpoint configuration\n",
"At launch, we will support configuring REST endpoints in hosting with multiple models, e.g. for A/B testing purposes. In order to support this, customers create an endpoint configuration, that describes the distribution of traffic across the models, whether split, shadowed, or sampled in some way.\n",
"\n",
"In addition, the endpoint configuration describes the instance type required for model deployment, and at launch will describe the autoscaling configuration."
Expand Down Expand Up @@ -547,7 +683,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create endpoint\n",
"#### Create endpoint\n",
"Lastly, the customer creates the endpoint that serves up the model, through specifying the name and configuration defined above. The end result is an endpoint that can be validated and incorporated into production applications. This takes 9-11 minutes to complete."
]
},
Expand Down

0 comments on commit 326f713

Please sign in to comment.