From 0dd0a7dee2916e147230866219e2be10fb0112b3 Mon Sep 17 00:00:00 2001 From: Loki Date: Sat, 11 Jun 2022 22:00:03 +0000 Subject: [PATCH] Finishing up notebook and black checks --- .../single_gpu_single_node/scripts/vit.py | 417 +++++++++++ .../vision-transformer.ipynb | 677 +++++++++++++++--- 2 files changed, 994 insertions(+), 100 deletions(-) create mode 100644 sagemaker-training-compiler/tensorflow/single_gpu_single_node/scripts/vit.py diff --git a/sagemaker-training-compiler/tensorflow/single_gpu_single_node/scripts/vit.py b/sagemaker-training-compiler/tensorflow/single_gpu_single_node/scripts/vit.py new file mode 100644 index 0000000000..a5e0bb1919 --- /dev/null +++ b/sagemaker-training-compiler/tensorflow/single_gpu_single_node/scripts/vit.py @@ -0,0 +1,417 @@ +import math +import numpy as np +import tensorflow as tf +from tensorflow import keras +from tensorflow.keras import mixed_precision +import tensorflow_addons as tfa +import matplotlib.pyplot as plt +from tensorflow.keras import layers +import argparse, os +from PIL import Image +from tensorflow.keras.applications import ResNet152 +from tensorflow.keras.layers import Dense, Flatten +from tensorflow.keras.models import Sequential +import warnings + + +mixed_precision.set_global_policy("mixed_float16") + +# Setting seed for reproducibiltiy +SEED = 42 +keras.utils.set_random_seed(SEED) + + +def normalize(arr): + """ + Linear normalization + http://en.wikipedia.org/wiki/Normalization_%28image_processing%29 + """ + arr = arr.astype(np.float64) + for i in range(3): + minval = arr[..., i].min() + maxval = arr[..., i].max() + if minval != maxval: + arr[..., i] -= minval + arr[..., i] *= 255.0 / (maxval - minval) + return arr.astype(np.uint8) + + +def resize(INPUT_SHAPE, img): + """Resize image to specified size. + + Resize shorter side to specified shape while maintaining aspect ratio. + """ + aspect_ratio = img.size[0] / img.size[1] + _size = [0, 0] + if img.size[0] < img.size[1]: + _size[0] = INPUT_SHAPE[0] + _size[1] = int(np.ceil(_size[0] / aspect_ratio)) + else: + _size[1] = INPUT_SHAPE[1] + _size[0] = int(np.ceil(_size[1] * aspect_ratio)) + return img.resize(tuple(_size)) + + +def load_dataset(INPUT_SHAPE, NUM_CLASSES): + """Load the Caltech-256 dataset from SageMaker input directory. + + The images are expected to be .jpg format stored under directories + that indicate their object category. Images smaller than the specified + size are ignored. + + Qualifying images are then resized and center cropped to meet the + size criterion specificed. Labels are obtained from the directory structure. + """ + x_train, y_train = [], [] + for root, dirs, files in os.walk(os.environ["SM_INPUT_DIR"]): + for file in [f for f in files if f.endswith(".jpg")]: + fpath = os.path.join(root, file) + with Image.open(fpath) as img: + if img.size[0] < INPUT_SHAPE[0] or img.size[1] < INPUT_SHAPE[1]: + continue + else: + img = resize(INPUT_SHAPE, img) + array = np.asarray(img) + margin = [0, 0] + for dim in [0, 1]: + diff = array.shape[dim] - INPUT_SHAPE[dim] + margin[dim] = diff // 2 + array = array[ + margin[0] : margin[0] + INPUT_SHAPE[0], + margin[1] : margin[1] + INPUT_SHAPE[1], + ] + try: + assert array.shape[2] == 3 + x_train.append(array) + except: + continue + label = int(fpath.split("/")[-2].split(".")[0]) + y_train.append(label) + return np.array(x_train, dtype=np.uint8), np.array(y_train, dtype=np.uint8) + + +ConfigDict = { + "dropout": 0.1, + "mlp_dim": 3072, + "num_heads": 12, + "num_layers": 12, + "hidden_size": 768, +} + + +def interpret_image_size(image_size_arg): + """Process the image_size argument whether a tuple or int.""" + if isinstance(image_size_arg, int): + return (image_size_arg, image_size_arg) + if ( + isinstance(image_size_arg, tuple) + and len(image_size_arg) == 2 + and all(map(lambda v: isinstance(v, int), image_size_arg)) + ): + return image_size_arg + raise ValueError( + f"The image_size argument must be a tuple of 2 integers or a single integer. Received: {image_size_arg}" + ) + + +@tf.keras.utils.register_keras_serializable() +class ClassToken(tf.keras.layers.Layer): + """Append a class token to an input layer.""" + + def build(self, input_shape): + cls_init = tf.zeros_initializer() + self.hidden_size = input_shape[-1] + self.cls = tf.Variable( + name="cls", + initial_value=cls_init(shape=(1, 1, self.hidden_size), dtype="float32"), + trainable=True, + ) + + def call(self, inputs): + batch_size = tf.shape(inputs)[0] + cls_broadcasted = tf.cast( + tf.broadcast_to(self.cls, [batch_size, 1, self.hidden_size]), + dtype=inputs.dtype, + ) + return tf.concat([cls_broadcasted, inputs], 1) + + def get_config(self): + config = super().get_config() + return config + + @classmethod + def from_config(cls, config): + return cls(**config) + + +@tf.keras.utils.register_keras_serializable() +class AddPositionEmbs(tf.keras.layers.Layer): + """Adds (optionally learned) positional embeddings to the inputs.""" + + def build(self, input_shape): + assert len(input_shape) == 3, f"Number of dimensions should be 3, got {len(input_shape)}" + self.pe = tf.Variable( + name="pos_embedding", + initial_value=tf.random_normal_initializer(stddev=0.06)( + shape=(1, input_shape[1], input_shape[2]) + ), + dtype="float32", + trainable=True, + ) + + def call(self, inputs): + return inputs + tf.cast(self.pe, dtype=inputs.dtype) + + def get_config(self): + config = super().get_config() + return config + + @classmethod + def from_config(cls, config): + return cls(**config) + + +@tf.keras.utils.register_keras_serializable() +class MultiHeadSelfAttention(tf.keras.layers.Layer): + def __init__(self, *args, num_heads, **kwargs): + super().__init__(*args, **kwargs) + self.num_heads = num_heads + + def build(self, input_shape): + hidden_size = input_shape[-1] + num_heads = self.num_heads + if hidden_size % num_heads != 0: + raise ValueError( + f"embedding dimension = {hidden_size} should be divisible by number of heads = {num_heads}" + ) + self.hidden_size = hidden_size + self.projection_dim = hidden_size // num_heads + self.query_dense = tf.keras.layers.Dense(hidden_size, name="query") + self.key_dense = tf.keras.layers.Dense(hidden_size, name="key") + self.value_dense = tf.keras.layers.Dense(hidden_size, name="value") + self.combine_heads = tf.keras.layers.Dense(hidden_size, name="out") + + # pylint: disable=no-self-use + def attention(self, query, key, value): + score = tf.matmul(query, key, transpose_b=True) + dim_key = tf.cast(tf.shape(key)[-1], score.dtype) + scaled_score = score / tf.math.sqrt(dim_key) + weights = tf.nn.softmax(scaled_score, axis=-1) + output = tf.matmul(weights, value) + return output, weights + + def separate_heads(self, x, batch_size): + x = tf.reshape(x, (batch_size, -1, self.num_heads, self.projection_dim)) + return tf.transpose(x, perm=[0, 2, 1, 3]) + + def call(self, inputs): + batch_size = tf.shape(inputs)[0] + query = self.query_dense(inputs) + key = self.key_dense(inputs) + value = self.value_dense(inputs) + query = self.separate_heads(query, batch_size) + key = self.separate_heads(key, batch_size) + value = self.separate_heads(value, batch_size) + + attention, weights = self.attention(query, key, value) + attention = tf.transpose(attention, perm=[0, 2, 1, 3]) + concat_attention = tf.reshape(attention, (batch_size, -1, self.hidden_size)) + output = self.combine_heads(concat_attention) + return output, weights + + def get_config(self): + config = super().get_config() + config.update({"num_heads": self.num_heads}) + return config + + @classmethod + def from_config(cls, config): + return cls(**config) + + +# pylint: disable=too-many-instance-attributes +@tf.keras.utils.register_keras_serializable() +class TransformerBlock(tf.keras.layers.Layer): + """Implements a Transformer block.""" + + def __init__(self, *args, num_heads, mlp_dim, dropout, **kwargs): + super().__init__(*args, **kwargs) + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.dropout = dropout + + def build(self, input_shape): + self.att = MultiHeadSelfAttention( + num_heads=self.num_heads, + name="MultiHeadDotProductAttention_1", + ) + self.mlpblock = tf.keras.Sequential( + [ + tf.keras.layers.Dense( + self.mlp_dim, + activation="linear", + name=f"{self.name}/Dense_0", + ), + tf.keras.layers.Lambda(lambda x: tf.keras.activations.gelu(x, approximate=False)) + if hasattr(tf.keras.activations, "gelu") + else tf.keras.layers.Lambda(lambda x: tfa.activations.gelu(x, approximate=False)), + tf.keras.layers.Dropout(self.dropout), + tf.keras.layers.Dense(input_shape[-1], name=f"{self.name}/Dense_1"), + tf.keras.layers.Dropout(self.dropout), + ], + name="MlpBlock_3", + ) + self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6, name="LayerNorm_0") + self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6, name="LayerNorm_2") + self.dropout_layer = tf.keras.layers.Dropout(self.dropout) + + def call(self, inputs, training): + x = self.layernorm1(inputs) + x, weights = self.att(x) + x = self.dropout_layer(x, training=training) + x = x + inputs + y = self.layernorm2(x) + y = self.mlpblock(y) + return x + y, weights + + def get_config(self): + config = super().get_config() + config.update( + { + "num_heads": self.num_heads, + "mlp_dim": self.mlp_dim, + "dropout": self.dropout, + } + ) + return config + + @classmethod + def from_config(cls, config): + return cls(**config) + + +def build_model( + image_size, + patch_size, + num_layers, + hidden_size, + num_heads, + name, + mlp_dim, + classes, + dropout=0.1, + activation="linear", + include_top=True, + representation_size=None, +): + """Build a ViT model. + + Args: + image_size: The size of input images. + patch_size: The size of each patch (must fit evenly in image_size) + classes: optional number of classes to classify images + into, only to be specified if `include_top` is True, and + if no `weights` argument is specified. + num_layers: The number of transformer layers to use. + hidden_size: The number of filters to use + num_heads: The number of transformer heads + mlp_dim: The number of dimensions for the MLP output in the transformers. + dropout_rate: fraction of the units to drop for dense layers. + activation: The activation to use for the final layer. + include_top: Whether to include the final classification layer. If not, + the output will have dimensions (batch_size, hidden_size). + representation_size: The size of the representation prior to the + classification layer. If None, no Dense layer is inserted. + """ + image_size_tuple = interpret_image_size(image_size) + assert (image_size_tuple[0] % patch_size == 0) and ( + image_size_tuple[1] % patch_size == 0 + ), "image_size must be a multiple of patch_size" + x = tf.keras.layers.Input(shape=(image_size_tuple[0], image_size_tuple[1], 3)) + y = tf.keras.layers.Conv2D( + filters=hidden_size, + kernel_size=patch_size, + strides=patch_size, + padding="valid", + name="embedding", + )(x) + y = tf.keras.layers.Reshape((y.shape[1] * y.shape[2], hidden_size))(y) + y = ClassToken(name="class_token")(y) + y = AddPositionEmbs(name="Transformer/posembed_input")(y) + for n in range(num_layers): + y, _ = TransformerBlock( + num_heads=num_heads, + mlp_dim=mlp_dim, + dropout=dropout, + name=f"Transformer/encoderblock_{n}", + )(y) + y = tf.keras.layers.LayerNormalization(epsilon=1e-6, name="Transformer/encoder_norm")(y) + y = tf.keras.layers.Lambda(lambda v: v[:, 0], name="ExtractToken")(y) + if representation_size is not None: + y = tf.keras.layers.Dense(representation_size, name="pre_logits", activation="tanh")(y) + if include_top: + y = tf.keras.layers.Dense(classes, name="head", activation=activation)(y) + return tf.keras.models.Model(inputs=x, outputs=y, name=name) + + +def vit_b16( + image_size: (224, 224), + classes=1000, + activation="linear", + include_top=True, +): + """Build ViT-B16. All arguments passed to build_model.""" + model = build_model( + **ConfigDict, + name="vit-b16", + patch_size=16, + image_size=image_size, + classes=classes, + activation=activation, + include_top=include_top, + representation_size=768, + ) + return model + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Configure the VIT Training through Hyperparameters" + ) + parser.add_argument( + "--NUM_CLASSES", type=int, default=257, help="Number of classfication categories" + ) + parser.add_argument( + "--INPUT_SHAPE", type=int, nargs="+", default=[224, 224, 3], help="Shape of input to VIT" + ) + parser.add_argument("--BATCH_SIZE", type=int, help="Batch Size to use with the Hardware") + parser.add_argument( + "--LEARNING_RATE", type=float, default=0.001, help="Learning rate to use for the Optimizer" + ) + parser.add_argument( + "--WEIGHT_DECAY", type=float, default=0.0001, help="Weight decay to use for the Optimizer" + ) + parser.add_argument( + "--EPOCHS", type=int, default=1, help="Number of times to loop over the data" + ) + args, unused = parser.parse_known_args() + + args.INPUT_SHAPE = tuple(args.INPUT_SHAPE) + print(f"Training on Images of size {args.INPUT_SHAPE}") + + x_train, y_train = load_dataset(args.INPUT_SHAPE, args.NUM_CLASSES) + x_train = normalize(x_train) + print(f"Training on dataset size {x_train.shape}") + + model = vit_b16(image_size=tuple(args.INPUT_SHAPE[:2]), classes=args.NUM_CLASSES) + model.compile( + optimizer=tfa.optimizers.AdamW( + learning_rate=args.LEARNING_RATE, weight_decay=args.WEIGHT_DECAY + ), + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + metrics=[ + keras.metrics.SparseCategoricalAccuracy(name="accuracy"), + keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"), + ], + ) + model.fit(x_train, y_train, epochs=args.EPOCHS, batch_size=args.BATCH_SIZE, verbose=2) diff --git a/sagemaker-training-compiler/tensorflow/single_gpu_single_node/vision-transformer.ipynb b/sagemaker-training-compiler/tensorflow/single_gpu_single_node/vision-transformer.ipynb index 27480d2917..aca6691f4b 100644 --- a/sagemaker-training-compiler/tensorflow/single_gpu_single_node/vision-transformer.ipynb +++ b/sagemaker-training-compiler/tensorflow/single_gpu_single_node/vision-transformer.ipynb @@ -2,39 +2,40 @@ "cells": [ { "cell_type": "markdown", - "id": "cf50d633", + "id": "48d008c6", "metadata": {}, "source": [ - "# Compile and Train a Vision Transformer model on the ImageNet Dataset for Image Classification on a Single-Node Single-GPU" + "# Compile and Train a Vision Transformer Model on the Caltech 256 Dataset using a Single Node " ] }, { "cell_type": "markdown", - "id": "31e66f8f", + "id": "e3a3be53", "metadata": {}, "source": [ "1. [Introduction](#Introduction) \n", "2. [Development Environment and Permissions](#Development-Environment-and-Permissions)\n", " 1. [Installation](#Installation) \n", " 2. [SageMaker environment](#SageMaker-environment)\n", - "3. [Processing](#Preprocessing)\n", - " 1. [Tokenization](#Tokenization)\n", - " 2. [Uploading data to sagemaker_session_bucket](#Uploading-data-to-sagemaker_session_bucket)\n", - "4. [SageMaker Training Job](#SageMaker-Training-Job)\n", - " 1. [Training with Native TensorFlow](#Training-with-Native-TensorFlow) \n", - " 2. [Training with Optimized TensorFlow](#Training-with-Optimized-TensorFlow) \n", - " 3. [Analysis](#Analysis)\n", - "5. [Clean Up](#Clean-Up)\n" + "3. [Working with the Caltech-256 dataset](#Working-with-the-Caltech-256-dataset) \n", + "4. [SageMaker Training Job](#SageMaker-Training-Job) \n", + " 1. [Training Setup](#Training-Setup) \n", + " 2. [Training with Native TensorFlow](#Training-with-Native-TensorFlow) \n", + " 3. [Training with Optimized TensorFlow](#Training-with-Optimized-TensorFlow) \n", + "5. [Analysis](#Analysis)\n", + " 1. [Savings from Training Compiler](#Savings-from-Training-Compiler)\n", + " 2. [Convergence of Training](#Convergence-of-Training)\n", + "6. [Clean up](#Clean-up)\n" ] }, { "cell_type": "markdown", - "id": "5e3c714a", + "id": "420ad24e", "metadata": {}, "source": [ "## SageMaker Training Compiler Overview\n", "\n", - "SageMaker Training Compiler is a capability of SageMaker that makes these hard-to-implement optimizations to reduce training time on GPU instances. The compiler optimizes DL models to accelerate training by more efficiently using SageMaker machine learning (ML) GPU instances. SageMaker Training Compiler is available at no additional charge within SageMaker and can help reduce total billable time as it accelerates training. \n", + "SageMaker Training Compiler is a capability of SageMaker that makes hard-to-implement optimizations to reduce training time on GPU instances. The compiler optimizes DL models to accelerate training by more efficiently using SageMaker machine learning (ML) GPU instances. SageMaker Training Compiler is available at no additional charge within SageMaker and can help reduce total billable time as it accelerates training. \n", "\n", "SageMaker Training Compiler is integrated into the AWS Deep Learning Containers (DLCs). Using the SageMaker Training Compiler enabled AWS DLCs, you can compile and optimize training jobs on GPU instances with minimal changes to your code. Bring your deep learning models to SageMaker and enable SageMaker Training Compiler to accelerate the speed of your training job on SageMaker ML instances for accelerated computing. \n", "\n", @@ -42,110 +43,586 @@ "\n", "## Introduction\n", "\n", - "In this demo, you'll use Hugging Face's `transformers` and `datasets` libraries with Amazon SageMaker Training Compiler to train the `RoBERTa` model on the `Stanford Sentiment Treebank v2 (SST2)` dataset. To get started, we need to set up the environment with a few prerequisite steps, for permissions, configurations, and so on. \n", + "In this demo, you'll use Amazon SageMaker Training Compiler to train the `Vision Transformer` model on the `Caltech-256` dataset. To get started, we need to set up the environment with a few prerequisite steps, for permissions, configurations, and so on. \n", "\n", - "**NOTE:** You can run this demo in SageMaker Studio, SageMaker notebook instances, or your local machine with AWS CLI set up. If using SageMaker Studio or SageMaker notebook instances, make sure you choose one of the PyTorch-based kernels, `Python 3 (PyTorch x.y Python 3.x CPU Optimized)` or `conda_pytorch_p36` respectively.\n", + "**NOTE:** You can run this demo in SageMaker Studio, SageMaker notebook instances, or your local machine with AWS CLI set up. If using SageMaker Studio or SageMaker notebook instances, make sure you choose one of the TensorFlow-based kernels, `Python 3 (TensorFlow x.y Python 3.x CPU Optimized)` or `conda_tesorflow_p39` respectively.\n", "\n", - "**NOTE:** This notebook uses two `ml.p3.2xlarge` instances that have single GPU. If you don't have enough quota, see [Request a service quota increase for SageMaker resources](https://docs.aws.amazon.com/sagemaker/latest/dg/regions-quotas.html#service-limit-increase-request-procedure). " + "**NOTE:** This notebook uses a `ml.p3.2xlarge` instance with a single GPU. However, it can easily be extended to multiple GPUs on a single node. If you don't have enough quota, see [Request a service quota increase for SageMaker resources](https://docs.aws.amazon.com/sagemaker/latest/dg/regions-quotas.html#service-limit-increase-request-procedure). " + ] + }, + { + "cell_type": "markdown", + "id": "e86b8d5b", + "metadata": {}, + "source": [ + "## Development Environment \n" + ] + }, + { + "cell_type": "markdown", + "id": "8d6666d6", + "metadata": {}, + "source": [ + "### Installation\n", + "\n", + "This example notebook requires **SageMaker Python SDK v2.92.0**\n" ] }, { "cell_type": "code", "execution_count": null, - "id": "bc2ccaab", + "id": "8f1a4aca", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Cloning into '/var/folders/5r/j40pqpnd4lv66lxzrjmygjzr0000gs/T/tmpvf3uunxo'...\n", - "Note: switching to 'v2.9.2'.\n", - "\n", - "You are in 'detached HEAD' state. You can look around, make experimental\n", - "changes and commit them, and you can discard any commits you make in this\n", - "state without impacting any branches by switching back to a branch.\n", - "\n", - "If you want to create a new branch to retain commits you create, you may\n", - "do so (now or later) by using -c with the switch command. Example:\n", - "\n", - " git switch -c \n", - "\n", - "Or undo this operation with:\n", - "\n", - " git switch -\n", - "\n", - "Turn off this advice by setting config variable advice.detachedHead to false\n", - "\n", - "HEAD is now at 675d26469 Make preprocess_ops visible from tensorflow_models import.\n" - ] - } - ], + "outputs": [], "source": [ + "!pip install \"sagemaker>=2.92\" botocore boto3 awscli matplotlib --upgrade" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1a93e2a4", + "metadata": {}, + "outputs": [], + "source": [ + "import botocore\n", + "import boto3\n", + "import sagemaker\n", + "\n", + "print(f\"botocore: {botocore.__version__}\")\n", + "print(f\"boto3: {boto3.__version__}\")\n", + "print(f\"sagemaker: {sagemaker.__version__}\")" + ] + }, + { + "cell_type": "markdown", + "id": "2b30e484", + "metadata": {}, + "source": [ + "### SageMaker environment " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "946fe532", + "metadata": {}, + "outputs": [], + "source": [ + "import sagemaker\n", + "\n", + "sess = sagemaker.Session()\n", + "\n", + "# SageMaker session bucket -> used for uploading data, models and logs\n", + "# SageMaker will automatically create this bucket if it does not exist\n", + "sagemaker_session_bucket = None\n", + "if sagemaker_session_bucket is None and sess is not None:\n", + " # set to default bucket if a bucket name is not given\n", + " sagemaker_session_bucket = sess.default_bucket()\n", + "\n", + "role = sagemaker.get_execution_role()\n", + "sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)\n", + "\n", + "print(f\"sagemaker role arn: {role}\")\n", + "print(f\"sagemaker bucket: {sagemaker_session_bucket}\")\n", + "print(f\"sagemaker session region: {sess.boto_region_name}\")" + ] + }, + { + "cell_type": "markdown", + "id": "6774ae6f", + "metadata": {}, + "source": [ + "## Working with the Caltech-256 dataset\n", + "\n", + "We have hosted the [Caltech-256](https://authors.library.caltech.edu/7694/) dataset in S3 in us-west-2. We will transfer this dataset to your account and region for use with SageMaker Training.\n", + "\n", + "The dataset consists of JPEG images organized into directories with each directory representing an object cateogory." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "24aea98e", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "source = \"s3://sagemaker-sample-files/datasets/image/caltech-256/256_ObjectCategories\"\n", + "destn = f\"s3://{sagemaker_session_bucket}/caltech-256\"\n", + "\n", + "os.system(f\"aws s3 sync {source} {destn}\")" + ] + }, + { + "cell_type": "markdown", + "id": "0b4decc2", + "metadata": {}, + "source": [ + "## SageMaker Training Job\n", + "\n", + "To create a SageMaker training job, we use a `TensorFlow` estimator. Using the estimator, you can define which training script should SageMaker use through `entry_point`, which `instance_type` to use for training, which `hyperparameters` to pass, and so on.\n", + "\n", + "When a SageMaker training job starts, SageMaker takes care of starting and managing all the required machine learning instances, picks up the `TensorFlow` Deep Learning Container, uploads your training script, and downloads the data from `sagemaker_session_bucket` into the container at `/opt/ml/input/data`.\n", + "\n", + "In the following section, you learn how to set up two versions of the SageMaker `TensorFlow` estimator, a native one without the compiler and an optimized one with the compiler." + ] + }, + { + "cell_type": "markdown", + "id": "cf4f1bba", + "metadata": {}, + "source": [ + "### Training Setup\n", + "\n", + "Set up the basic configuration for training. Set `EPOCHS` to the number of times you would like to loop over the training data." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "80ed60f0", + "metadata": {}, + "outputs": [], + "source": [ + "TRCOMP_IMAGE_URI = \"763104351884.dkr.ecr.us-west-2.amazonaws.com/tensorflow-training:2.9.1-gpu-py39-cu112-ubuntu20.04-sagemaker\"\n", + "EPOCHS = 10" + ] + }, + { + "cell_type": "markdown", + "id": "973c2c67", + "metadata": {}, + "source": [ + "### Training with Native TensorFlow\n", + "\n", + "The `BATCH_SIZE` in the following code cell is the maximum batch that can fit into the memory of an `ml.p3.2xlarge` instance while giving the best training speed. If you change the model, instance type, and other parameters, you need to do some experiments to find the largest batch size that will fit into GPU memory." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4db276b2", + "metadata": {}, + "outputs": [], + "source": [ + "from sagemaker.tensorflow import TensorFlow\n", + "\n", + "BATCH_SIZE = 64\n", + "LEARNING_RATE = 1e-3\n", + "WEIGHT_DECAY = 1e-4\n", + "\n", + "kwargs = dict(\n", + " source_dir=\"scripts\",\n", + " entry_point=\"vit_b16_1.py\",\n", + " model_dir=False,\n", + " instance_type=\"ml.p3.2xlarge\",\n", + " instance_count=1,\n", + " image_uri=TRCOMP_IMAGE_URI,\n", + " debugger_hook_config=None,\n", + " disable_profiler=True,\n", + " max_run=60 * 60, # 60 minutes\n", + " role=role,\n", + " metric_definitions=[\n", + " {\"Name\": \"training_loss\", \"Regex\": \"loss: ([0-9.]*?) \"},\n", + " {\"Name\": \"training_accuracy\", \"Regex\": \"accuracy: ([0-9.]*?) \"},\n", + " {\"Name\": \"training_latency_per_epoch\", \"Regex\": \"- ([0-9.]*?)s/epoch\"},\n", + " {\"Name\": \"training_avg_latency_per_step\", \"Regex\": \"- ([0-9.]*?)ms/step\"},\n", + " ],\n", + ")\n", + "\n", + "# Configure the training job\n", + "native_estimator = TensorFlow(\n", + " hyperparameters={\n", + " \"EPOCHS\": EPOCHS,\n", + " \"BATCH_SIZE\": BATCH_SIZE,\n", + " \"LEARNING_RATE\": LEARNING_RATE,\n", + " \"WEIGHT_DECAY\": WEIGHT_DECAY,\n", + " },\n", + " base_job_name=\"native-tf29-vit\",\n", + " **kwargs,\n", + ")\n", + "\n", + "# Start training with our uploaded datasets as input\n", + "native_estimator.fit(inputs=destn, wait=False)\n", + "\n", + "# The name of the training job.\n", + "native_estimator.latest_training_job.name" + ] + }, + { + "cell_type": "markdown", + "id": "0a49b0ac", + "metadata": {}, + "source": [ + "### Training with Optimized TensorFlow\n", + "\n", + "Compilation through Training Compiler changes the memory footprint of the model. Most commonly, this manifests as a reduction in memory utilization and a consequent increase in the largest batch size that can fit on the GPU. But in some case the compiler intelligently promotes caching which leads to a decrease in largest batch size that can fit on the GPU. Note that if you want to change the batch size, you must adjust the learning rate appropriately.\n", + "\n", + "**Note:** We recommend you to turn the SageMaker Debugger's profiling and debugging tools off when you use compilation to avoid additional overheads." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "90d7a995", + "metadata": {}, + "outputs": [], + "source": [ + "# TODO: Change how TrainingCompilerConfig is used after SDK release\n", + "\n", "from sagemaker.tensorflow import TensorFlow\n", "from sagemaker.training_compiler.config import TrainingCompilerConfig\n", "\n", - "import boto3\n", + "OPTIMIZED_BATCH_SIZE = 48\n", + "LEARNING_RATE = LEARNING_RATE / BATCH_SIZE * OPTIMIZED_BATCH_SIZE\n", + "WEIGHT_DECAY = WEIGHT_DECAY * BATCH_SIZE / OPTIMIZED_BATCH_SIZE\n", + "\n", + "# Configure the training job\n", + "optimized_estimator = TensorFlow(\n", + " hyperparameters={\n", + " TrainingCompilerConfig.HP_ENABLE_COMPILER: True,\n", + " \"EPOCHS\": EPOCHS,\n", + " \"BATCH_SIZE\": OPTIMIZED_BATCH_SIZE,\n", + " \"LEARNING_RATE\": LEARNING_RATE,\n", + " \"WEIGHT_DECAY\": WEIGHT_DECAY,\n", + " },\n", + " base_job_name=\"optimized-tf29-vit\",\n", + " **kwargs,\n", + ")\n", + "\n", + "# Start training with our uploaded datasets as input\n", + "optimized_estimator.fit(inputs=destn, wait=False)\n", + "\n", + "# The name of the training job.\n", + "optimized_estimator.latest_training_job.name" + ] + }, + { + "cell_type": "markdown", + "id": "d382f1dc", + "metadata": {}, + "source": [ + "### Wait for training jobs to complete\n", + "\n", + "The training jobs described above typically take around 40 mins to complete" + ] + }, + { + "cell_type": "markdown", + "id": "406462ea", + "metadata": {}, + "source": [ + "**Note:** If the estimator object is no longer available due to a kernel break or refresh, you need to directly use the training job name and manually attach the training job to a new TensorFlow estimator. For example:\n", + "\n", + "```python\n", + "native_estimator = TensorFlow.attach(\"\")\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "50b80394", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "native_estimator = TensorFlow.attach(native_estimator.latest_training_job.name)\n", + "optimized_estimator = TensorFlow.attach(optimized_estimator.latest_training_job.name)" + ] + }, + { + "cell_type": "markdown", + "id": "ca086c25", + "metadata": {}, + "source": [ + "## Analysis\n", + "\n", + "Here we view the training metrics from the training jobs as a Pandas dataframe\n", + "\n", + "#### Native TensorFlow" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6d431855", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "\n", + "# Extract training metrics from the estimator\n", + "native_metrics = native_estimator.training_job_analytics.dataframe()\n", + "\n", + "# Restructure table for viewing\n", + "for metric in native_metrics[\"metric_name\"].unique():\n", + " native_metrics[metric] = native_metrics[native_metrics[\"metric_name\"] == metric][\"value\"]\n", + "native_metrics = native_metrics.drop(columns=[\"metric_name\", \"value\"])\n", + "native_metrics = native_metrics.groupby(\"timestamp\").max()\n", + "native_metrics[\"epochs\"] = range(1, 11)\n", + "native_metrics = native_metrics.set_index(\"epochs\")\n", + "\n", + "native_metrics" + ] + }, + { + "cell_type": "markdown", + "id": "a4b38ddb", + "metadata": {}, + "source": [ + "#### Optimized TensorFlow" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "332e2252", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "\n", + "# Extract training metrics from the estimator\n", + "optimized_metrics = optimized_estimator.training_job_analytics.dataframe()\n", + "\n", + "# Restructure table for viewing\n", + "for metric in optimized_metrics[\"metric_name\"].unique():\n", + " optimized_metrics[metric] = optimized_metrics[optimized_metrics[\"metric_name\"] == metric][\n", + " \"value\"\n", + " ]\n", + "optimized_metrics = optimized_metrics.drop(columns=[\"metric_name\", \"value\"])\n", + "optimized_metrics = optimized_metrics.groupby(\"timestamp\").max()\n", + "optimized_metrics[\"epochs\"] = range(1, 11)\n", + "optimized_metrics = optimized_metrics.set_index(\"epochs\")\n", + "\n", + "optimized_metrics" + ] + }, + { + "cell_type": "markdown", + "id": "7999562c", + "metadata": {}, + "source": [ + "### Savings from Training Compiler\n", + "\n", + "Let us calculate the actual savings on the training jobs above and the potential for savings for a longer training job.\n", + "\n", + "#### Actual Savings\n", + "\n", + "To get the actual savings, we use the describe_training_job API to get the billable seconds for each training job." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "69e0d877", + "metadata": {}, + "outputs": [], + "source": [ + "# Billable seconds for the Native TensorFlow Training job\n", + "\n", + "details = sess.describe_training_job(job_name=native_estimator.latest_training_job.name)\n", + "native_secs = details[\"BillableTimeInSeconds\"]\n", "\n", - "HOPPER_IMAGE_URI='669063966089.dkr.ecr.us-west-2.amazonaws.com/pr-tensorflow-training:2.9.0-gpu-py39-cu112-ubuntu20.04-sagemaker-pr-1839-2022-05-17-00-38-02'\n", - "epochs=1\n", - "batch = 56\n", - "train_steps = int(30000*epochs/batch)\n", - "steps_per_loop = train_steps//10\n", - "overrides=\\\n", - "f\"runtime.enable_xla=False,\"\\\n", - "f\"runtime.num_gpus=1,\"\\\n", - "f\"runtime.distribution_strategy=one_device,\"\\\n", - "f\"runtime.mixed_precision_dtype=float16,\"\\\n", - "f\"task.train_data.global_batch_size={batch},\"\\\n", - "f\"task.train_data.input_path=/opt/ml/input/data/training/caltech*,\"\\\n", - "f\"task.train_data.cache=False,\"\\\n", - "f\"trainer.train_steps={train_steps},\"\\\n", - "f\"trainer.steps_per_loop={steps_per_loop},\"\\\n", - "f\"trainer.summary_interval={steps_per_loop},\"\\\n", - "f\"trainer.checkpoint_interval={train_steps},\"\\\n", - "f\"task.model.backbone.type=vit,\"\n", - "estimator = TensorFlow(\n", - " git_config={\n", - " 'repo': 'https://github.com/tensorflow/models.git',\n", - " 'branch': 'v2.9.2',\n", - " },\n", - " source_dir='.',\n", - " entry_point='official/projects/vit/train.py',\n", - " model_dir=False,\n", - " instance_type='ml.p3.2xlarge',\n", - " instance_count=1,\n", - " image_uri=HOPPER_IMAGE_URI,\n", - " hyperparameters={\n", - " TrainingCompilerConfig.HP_ENABLE_COMPILER : False,\n", - " 'experiment': 'vit_imagenet_pretrain',\n", - " 'mode' : 'train',\n", - " 'model_dir': '/opt/ml/model',\n", - " 'params_override' : overrides,\n", - " },\n", - " debugger_hook_config=None,\n", - " disable_profiler=True,\n", - " max_run=60*60*12, #12 hours\n", - " base_job_name='native-tf29-vit',\n", - " role=boto3.client('iam').get_role(RoleName='SageMaker-Execution-Role-For-PyTest')['Role']['Arn'],\n", - " )\n", - "estimator.fit(inputs='s3://collection-of-ml-datasets/Caltech-256-tfrecords')\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2a6209ce", - "metadata": {}, - "outputs": [], - "source": [] + "native_secs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8bdd3386", + "metadata": {}, + "outputs": [], + "source": [ + "# Billable seconds for the Optimized TensorFlow Training job\n", + "\n", + "details = sess.describe_training_job(job_name=optimized_estimator.latest_training_job.name)\n", + "optimized_secs = details[\"BillableTimeInSeconds\"]\n", + "\n", + "optimized_secs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "88a9a736", + "metadata": {}, + "outputs": [], + "source": [ + "# Calculating percentage Savings from Training Compiler\n", + "\n", + "percentage = (native_secs - optimized_secs) * 100 / native_secs\n", + "\n", + "f\"Training Compiler yielded {percentage:.2f}% savings in training cost.\"" + ] + }, + { + "cell_type": "markdown", + "id": "f5661dfe", + "metadata": {}, + "source": [ + "#### Potential savings\n", + "\n", + "The Training Compiler works by compiling the model graph once per input shape and reusing the cached graph for subsequent steps. As a result the first few steps of training incur an increased latency owing to compilation which we refer to as the compilation overhead. This overhead is amortized over time thanks to the subsequent steps being much faster. We will demonstrate this below." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0fa7e90d", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "plt.plot(native_metrics[\"training_latency_per_epoch\"], label=\"native_epoch_latency\")\n", + "plt.plot(optimized_metrics[\"training_latency_per_epoch\"], label=\"optimized_epoch_latency\")\n", + "plt.legend()" + ] + }, + { + "cell_type": "markdown", + "id": "8c4e9288", + "metadata": {}, + "source": [ + "We calculate the potential savings below from the difference in steady state epoch latency between native TensorFlow and optimized TensorFlow" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fc3a33c1", + "metadata": {}, + "outputs": [], + "source": [ + "native_steady_state_latency = native_metrics[\"training_latency_per_epoch\"].iloc[-1]\n", + "\n", + "native_steady_state_latency" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7185dcfd", + "metadata": {}, + "outputs": [], + "source": [ + "optimized_steady_state_latency = optimized_metrics[\"training_latency_per_epoch\"].iloc[-1]\n", + "\n", + "optimized_steady_state_latency" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5459d0cb", + "metadata": {}, + "outputs": [], + "source": [ + "# Calculating potential percentage Savings from Training Compiler\n", + "\n", + "percentage = (\n", + " (native_steady_state_latency - optimized_steady_state_latency)\n", + " * 100\n", + " / native_steady_state_latency\n", + ")\n", + "\n", + "f\"Training Compiler can potentially yield {percentage:.2f}% savings in training cost for a longer training job.\"" + ] + }, + { + "cell_type": "markdown", + "id": "e7b06fc2", + "metadata": {}, + "source": [ + "### Convergence of Training\n", + "\n", + "Training Compiler brings down total training time by intelligently choosing between memory utilization and core utilization in the GPU. This does not have any effect on the model arithmetic and consequently convergence of the model.\n", + "\n", + "However, since we are working with a new batch size, hyperparameters like - learning rate, learning rate schedule and weight decay might have to be scaled and tuned for the new batch size" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "94cfeb5e", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "plt.plot(native_metrics[\"training_loss\"], label=\"native_loss\")\n", + "plt.plot(optimized_metrics[\"training_loss\"], label=\"optimized_loss\")\n", + "plt.legend()" + ] + }, + { + "cell_type": "markdown", + "id": "2bbd8835", + "metadata": {}, + "source": [ + "We can see that the model's convergence behavior is similar with and without Training Compiler. Here we have tuned the batch size specific hyperparameters - Learning Rate and Weight Decay using a linear scaling.\n", + "\n", + "Learning rate is directly proportional to the batch size:\n", + "```python\n", + "new_learning_rate = old_learning_rate * new_batch_size/old_batch_size\n", + "```\n", + "\n", + "Weight decay is inversely proportional to the batch size:\n", + "```python\n", + "new_weight_decay = old_weight_decay * old_batch_size/new_batch_size\n", + "```\n", + "\n", + "Better results can be achieved with further tuning. Check out [Automatic Model Tuning](https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning.html) for tuning." + ] + }, + { + "cell_type": "markdown", + "id": "bc219a22", + "metadata": {}, + "source": [ + "## Clean up\n", + "\n", + "Stop all training jobs launched if the jobs are still running." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a6f55df5", + "metadata": {}, + "outputs": [], + "source": [ + "def stop_training_job(name):\n", + " status = sess.describe_training_job(name)[\"TrainingJobStatus\"]\n", + " if status == \"InProgress\":\n", + " sm.stop_training_job(TrainingJobName=name)\n", + "\n", + "\n", + "stop_training_job(native_estimator.latest_training_job.name)\n", + "stop_training_job(optimized_estimator.latest_training_job.name)" + ] + }, + { + "cell_type": "markdown", + "id": "685e6c99", + "metadata": {}, + "source": [ + "Also, to find instructions on cleaning up resources, see [Clean Up](https://docs.aws.amazon.com/sagemaker/latest/dg/ex1-cleanup.html) in the *Amazon SageMaker Developer Guide*." + ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "conda_tensorflow2_p38", "language": "python", - "name": "python3" + "name": "conda_tensorflow2_p38" }, "language_info": { "codemirror_mode": { @@ -157,7 +634,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.12" + "version": "3.8.12" } }, "nbformat": 4,