diff --git a/sagemaker-training-compiler/huggingface/pytorch_single_gpu_single_node/bert-base-cased/bert-base-cased-single-node-single-gpu.ipynb b/sagemaker-training-compiler/huggingface/pytorch_single_gpu_single_node/bert-base-cased/bert-base-cased-single-node-single-gpu.ipynb index 3a4497c9d5..6c6656560b 100644 --- a/sagemaker-training-compiler/huggingface/pytorch_single_gpu_single_node/bert-base-cased/bert-base-cased-single-node-single-gpu.ipynb +++ b/sagemaker-training-compiler/huggingface/pytorch_single_gpu_single_node/bert-base-cased/bert-base-cased-single-node-single-gpu.ipynb @@ -50,7 +50,7 @@ "source": [ "## Introduction\n", "\n", - "This notebooks is an end-to-end binary text classification example. In this demo, we use the Hugging Face's `transformers` and `datasets` libraries with SageMaker Training Compiler to compile and fine-tune a pre-trained transformer for binary text classification. In particular, the pre-trained model will be fine-tuned using the Stanford Sentiment Treebank (SST) dataset. To get started, you need to set up the environment with a few prerequisite steps, for permissions, configurations, and so on. \n", + "This notebook is an end-to-end binary text classification example. In this demo, we use the Hugging Face's `transformers` and `datasets` libraries with SageMaker Training Compiler to compile and fine-tune a pre-trained transformer for binary text classification. In particular, the pre-trained model will be fine-tuned using the `Stanford Sentiment Treebank (SST)` dataset. To get started, you need to set up the environment with a few prerequisite steps, for permissions, configurations, and so on. \n", "\n", "![image.png](attachment:image.png)\n", "\n", @@ -81,7 +81,7 @@ "metadata": {}, "outputs": [], "source": [ - "!pip install \"sagemaker>=2.108.0\" botocore boto3 awscli s3fs typing-extensions --upgrade" + "!pip install \"sagemaker>=2.108.0\" botocore boto3 awscli s3fs typing-extensions \"torch==1.11.0\" --upgrade" ] }, { @@ -112,7 +112,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Copy and run the following code if you need to upgrade ipywidgets for `datasets` library and restart kernel. This is only needed when preprocessing is done in the notebook.\n", + "Copy and run the following code if you need to upgrade \"ipywidgets\" for `datasets` library and restart kernel. This is only needed when preprocessing is done in the notebook.\n", "\n", "```python\n", "%%capture\n", @@ -134,7 +134,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "**Note:** If you are going to use Sagemaker in a local environment. You need access to an IAM Role with the required permissions for SageMaker. To learn more, see [SageMaker Roles](https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html)." + "**Note:** If you are going to use SageMaker in a local environment. You need access to an IAM Role with the required permissions for SageMaker. To learn more, see [SageMaker Roles](https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html)." ] }, { @@ -176,7 +176,7 @@ "\n", "If you'd like to try other training datasets later, you can simply use this method.\n", "\n", - "For this example notebook, we prepared the [SST2 dataset](https://www.tensorflow.org/datasets/catalog/glue#gluesst2) in the public SageMaker sample S3 bucket. The following code cells show how you can directly load the dataset and convert to a HuggingFace DatasetDict." + "For this example notebook, we prepared the [SST2 dataset](https://www.tensorflow.org/datasets/catalog/glue#gluesst2) in the public SageMaker sample S3 bucket. The following code cells show how you can directly load the dataset and convert to a `HuggingFace DatasetDict`." ] }, { @@ -406,7 +406,7 @@ "source": [ "from sagemaker.pytorch import PyTorch\n", "\n", - "hyperparameters = {\"epochs\": 5, \"train_batch_size\": 14, \"model_name\": \"bert-base-cased\"}\n", + "hyperparameters = {\"epochs\": 5, \"train_batch_size\": 16, \"model_name\": \"bert-base-cased\"}\n", "\n", "# Scale the learning rate by batch size, as original LR was using batch size of 32\n", "hyperparameters[\"learning_rate\"] = float(\"5e-5\") / 32 * hyperparameters[\"train_batch_size\"]\n", @@ -712,7 +712,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Plot and compare throughputs of compiled training and native training" + "### Plot and compare throughput of compiled training and native training" ] }, { @@ -765,7 +765,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "#### Example output for SageMaker Training Compiler traing job\n", + "#### Example output for SageMaker Training Compiler training job\n", "\n", "{'train_runtime': 3742.9028,\n", " 'train_samples_per_second': 89.969,\n", @@ -801,27 +801,6 @@ "plt.xticks(ticks=[1, 1.5], labels=[\"Baseline PT\", \"SM-Training-Compiler-enhanced PT\"])" ] }, - { - "attachments": { - "throughput.png": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYgAAAElCAYAAAD+wXUWAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO3de7xd853/8ddbQiOCiITGJaJiqrQoKUGrTFNKFTNl6FBpS7VTner82qpWR3Xqml4U0ypliLq2qbp1Ki6Dxl1CSAilLhFS4hLi0gr5/P74fresbN9zzt4nZ5+dnLyfj8d5nLXX+q61PmvfPvu7Lp+liMDMzKzeCu0OwMzMlk5OEGZmVuQEYWZmRU4QZmZW5ARhZmZFThBmZlbkBGFLHUn9JL0iaURPtu1JkkZJWirOEZfUX1JIGtnuWKxvcYKwJZa/oGt/CyW9Xnl8QLPLi4i3ImJQRMzqybbNkHRWZRvekLSg8vjKnlzX0mxpSoTW++QL5awnSXocOCQiruukTf+IeLP3oloyko4F1ouIz1XGjQIejgg1uIyWbbOk/sACYMOIeLyHl93Udlrf4h6EtZykYyVdIukiSfOBAyVtJ+l2SfMkzZF0qqQVc/vFdplIOj9P/6Ok+ZJuk7Rhs23z9N0k/VnSS5JOk3SLpM8twbYdJGm2pLmSjuximwfk2OZIekrSTyWtlNsfIunGyvz12zVM0h8kvSzpTknHV9tnu0p6RNKLkk6tLOsQSX+S9Iu83TMl7VyZPlvSTnWxn5sf/imPq/WePtTd58qWPU4Q1lv+CbgQWB24BHgTOBwYCuwAfAL4Uifz/yvwn8AQYBbww2bbSloL+A3wrbzex4BturtB2fbAKGBX4AeSNq5Mq9/mo4HRwObAB0nb/Z0G13M6MA9YG/gCMK7QZndg67zsAyWNrYvzQdJ2/xD4vaTBDax3R4C8G29QRNzVYLzWBzhBWG+5OSKujIiFEfF6RNwVEXdExJsR8ShwJvDRTuafGBFTImIBcAGwZTfa7gFMi4jL87STgeeWcLuOiYi/RcTdwP3AFpVpi20zcEBuPzcingX+C/hsVyvIPau9gaPzczcD+HWh6QkR8VLezXQjiz9Hc4DTImJBRFwIPArs1vTW2nLFCcJ6y5PVB5I2ybtM/irpZdKX5dBO5v9rZfg1YFA32q5TjSPSAbjZDcTeoYjoLK4n65oPB56oPH4CWLeB1awN9KtbXv2yofPnaHYsfsDxCdLzYdYhJwjrLfVnQ5wBzABGRcRqpN0vrT4QOgdYr/ZAkmjsC7q76rd5DrBB5fEI4Kk8/CowsDLt3ZXhZ4CFVGIH1m8ylvXqHo8Anm5g3T6LZTnmBGHtsirwEvCqpPfR+fGHnnIVsJWkT+Uzfw4HhvXCemsuAo6WNFTSMNJxkvPztHuBzSV9QNLKwPdrM+XdYZeRjnGsLGkz4MAm1z1c0lfzwe/9gY2Aq/O0acD+edo2wD9X5nsWCEnvaXJ91gc4QVi7fIN0oHU+qTdxSatXGBHPAPsBPwWeJ31J3gP8vdXrzn5ASgTTgfuAO4ATcmwPAMeTjh08RD57qOLfgDVJvYlzSMmmmbhvBTYDXgCOAT4dES/maUcBm5AOgv8n6cA6Oa75OcY78hlno5tYpy3jfB2ELbck9SPtZtknIia3O55mSPoJMDgiDm6g7SHAgRGxU8sDsz7FPQhbrkj6hKTVJb2L9Gv5TeDONofVJUmb5t1PkjQG+Dzw+3bHZX1b/3YHYNbLPkw69XUl0mmpe0dEb+1iWhKrkeIeTtrNdGJEXNXekKyv8y4mMzMr8i4mMzMrcoKwtsq1ivZuov0Bkq5psO3nJN3c/ei6XP53JZ3VjfkeryuD0VaS3iXpwVyKxOxtThDWNpI2J5WmuDw/Ln6hV79QI+KCiNildyMti4jjI+KQ3lxnfXKRNDIX9ev28cR8DOZ/gG/3RIzWdzhBWDt9CbggloEDYUvyBbw0q2zXhcC4fHaXGeAEYe21G3BTMzPU9zIk7SLpoVzG+heSbsrn/Vfn+XEugf2YpN0q41eXdHal/Pax+dqI2npukXSypNrFZfWxHCPp/Dw8QKnU+PP5grK7JK3dyaZ8SNIDOa5zJA2oLHcPSdPycm7NPS0k/ZpUIuPKXHr7CBZdUDcvj9sut/1CLuv9oqRJkjaoLD8kHSbpYeBhgIiYDbwIjOn6VbDlhROEtYWkVYANSVcNd3cZQ4GJpJLZa+ZlbV/XbNs8figwHjg712ACmEC6DmIUqUT2LsAhdfM+CqwFHNdFOONIZb3Xz7F8GXi9k/YHkEqEbwT8A/C9vE1bkXb3fCkv5wzgCknviojPksqXfyqX3h5PLsdNumhuUETclo/pfJdUMmMYMJl05XXV3nn7Nq2Mm8ni1WhtOecEYe1SuxfB/LrxY/Iv57f/SL+aS3YH7o+IS/Pd2k5l8YqmAE9ExK8i4i1SQhgOrJ1/3e8GfD0iXs3lt08G9q/M+3REnJZLknf2ZQ/pjm5rkooPvhURUyPi5U7a/3dEPBkRL5CSz2fy+C8CZ+RS6G9FxARSSY1mftl/iVT6e2Z+Xo4Htqz2IvL0F+q2az6LXhczJwhrm3n5/6p142+PiMHVP9Kv5pJGynf/tTL9tTw4iFRVdUVgTiURnUHqLdSUSmp35NfAJOBiSU9LGq98h7wOVJddLb29AfCNugS5Ps2V5t4AOKUy/wukSrnVyrWlbVuVRa+LmROEtUdEvAr8hbR7pbtK5bvry1p35EnSL/OhlWS0WkRsVg2z0UDyjXh+EBGbknZz7QEc1Mks1XLd1dLbTwLH1SXJgRFR20VUH1MpxieBL9UtY+WIuLWL+d5HKiZoBjhBWHv9L53fRa4rfwA+IGnvfDbOYSx+L4MORcQc4BrgJ5JWk7SCpI0kdSseSTvnWkn9gJdJu5ze6mSWwyStJ2kI6XhBrZrtr4AvS9o2111aRdInJdV6Ws8A1dLbc0n3iqiO+yXwHaWy4LWD8ft2Ef+6pFu03t7QBttywQnC2ulM4IDKQeOmRMRzwL6kg8/Pkw64TqHxMtgHkWoyPUA6g2ci6RhFd7w7z/8y6WDvTSy610PJhaQE9Wj+OxYgIqaQjkP8d47pEeBzlflOAL6Xdx99M+82Ow64JY8bExG/B04i7e56mXRjpq5uL/qvwIRlpC6V9RLXYrK2knQh8JuIuKwHlrUC6RjEARFxwxIHt5zI1z7cC+yYD9abAU4QtoyTtCvpxjuvA98i7WZ6TwNnHZlZF7yLyZZ125EOdj8HfIpUvtvJwawHuAdhZmZF7kGYmVnRMlGAbOjQoTFy5Mh2h2FmtkyZOnXqcxExrLvzLxMJYuTIkUyZMqXdYZiZLVMkPbEk83sXk5mZFTlBmJlZkROEmZkVOUGYmVmRE4SZmRU5QZiZWZEThJmZFTlBmJlZkROEmZkVLRNXUpsZcMzq7Y7AetsxL7V19e5BmJlZkROEmZkVOUGYmVmRE4SZmRU5QZiZWZEThJmZFTlBmJlZkROEmZkVOUGYmVmRE4SZmRU5QZiZWZEThJmZFTlBmJlZkROEmZkVOUGYmVmRE4SZmRU5QZiZWZEThJmZFTlBmJlZUUsThKT/kHS/pBmSLpI0QNKGku6Q9LCkSySt1MoYzMyse1qWICStC3wNGB0R7wf6AfsDJwEnR8TGwIvAwa2KwczMuq/Vu5j6AytL6g8MBOYA/whMzNMnAHu3OAYzM+uGliWIiHgK+DEwi5QYXgKmAvMi4s3cbDawbml+SYdKmiJpyty5c1sVppmZdaCVu5jWAPYCNgTWAVYBdis0jdL8EXFmRIyOiNHDhg1rVZhmZtaBVu5iGgs8FhFzI2IBcCmwPTA473ICWA94uoUxmJlZN7UyQcwCxkgaKEnAx4AHgBuAfXKbccDlLYzBzMy6qZXHIO4gHYy+G5ie13Um8G3g/0l6BFgTOLtVMZiZWff177pJ90XE94Hv141+FNimles1M7Ml5yupzcysyAnCzMyKnCDMzKzICcLMzIqcIMzMrMgJwszMipwgzMysyAnCzMyKnCDMzKzICcLMzIqcIMzMrMgJwszMipwgzMysyAnCzMyKnCDMzKzICcLMzIqcIMzMrMgJwszMipwgzMysyAnCzMyKnCDMzKzICcLMzIqcIMzMrMgJwszMipwgzMysyAnCzMyKnCDMzKzICcLMzIqcIMzMrMgJwszMivq3O4BWG3nkH9odgvWyx0/8ZLtDMOsT3IMwM7MiJwgzMytygjAzsyInCDMzK3KCMDOzIicIMzMrammCkDRY0kRJD0qaKWk7SUMkXSvp4fx/jVbGYGZm3dPqHsQpwNURsQmwBTATOBK4PiI2Bq7Pj83MbCnTsgQhaTVgR+BsgIh4IyLmAXsBE3KzCcDerYrBzMy6r6EEIWkVSStUHq8gaWAXs70HmAucI+keSWdJWgVYOyLmAOT/a3WwzkMlTZE0Ze7cuQ1tjJmZ9ZxGexDXA9WEMBC4rot5+gNbAadHxAeBV2lid1JEnBkRoyNi9LBhwxqdzczMekijCWJARLxSe5CHu+pBzAZmR8Qd+fFEUsJ4RtJwgPz/2eZCNjOz3tBognhV0la1B5K2Bl7vbIaI+CvwpKT35lEfAx4ArgDG5XHjgMubitjMzHpFo9Vcvw78VtLT+fFwYL8G5vt34AJJKwGPAp8nJaXfSDoYmAXs21zIZmbWGxpKEBFxl6RNgPcCAh6MiAUNzDcNGF2Y9LGmojQzs17X6FlMA4FvA4dHxHRgpKQ9WhqZmZm1VaPHIM4B3gC2y49nA8e2JCIzM1sqNJogNoqI8cACgIh4nbSryczM+qhGE8QbklYGAkDSRsDfWxaVmZm1XaNnMX0fuBpYX9IFwA7A51oVlJmZtV+jZzFdK+luYAxp19LhEfFcSyMzM7O2avQsph2Av0XEH4DBwHclbdDSyMzMrK0aPQZxOvCapC2AbwFPAOe1LCozM2u7RhPEmxERpFLdp0bEKcCqrQvLzMzardGD1PMlfQc4ENhRUj9gxdaFZWZm7dZoD2I/0mmtB+cifOsCP2pZVGZm1nad9iAkTSKd3vrHiPhpbXxEzMLHIMzM+rSuehDjgBeBYyTdLel0SXtJGtQLsZmZWRt12oPIu5POBc7NtxzdFtgNOELS68A1uQSHmZn1MY0epCYiFgK35b+jJQ0Fdm1VYGZm1l6NXig3XtJqklaUdL2k54BPRMQFLY7PzMzapNGzmHaJiJeBPUilvv+BdMGcmZn1UY0miNo1D7sDF0XECy2Kx8zMlhKNHoO4UtKDwOvAVyQNA/7WurDMzKzdGupBRMSRpLvJjc73on6NVHbDzMz6qGbuSX0YqWgfwDrA6FYFZWZm7dfsPam3z499T2ozsz7O96Q2M7Mi35PazMyKfE9qMzMr8j2pzcysqKty31vVjZqT/4+QNCIi7m5NWGZm1m5d9SB+0sm0AP6xB2MxM7OlSFflvnfurUDMzGzp0tAxCEkDgK8AHyb1HCYDv4wIl9swM+ujGj2L6TxgPnBafvwZ4NfAvq0IyszM2q/RBPHeiNii8vgGSfe2IiAzM1s6NHqh3D2SxtQeSNoWuKU1IZmZ2dKg0R7EtsBBkmblxyOAmZKmAxERm7ckOjMza5tGE8QnWhqFmZktdRq9kvoJSWsA61fn8YVyZmZ9V6Onuf6QVHvpL+SCffhCOTOzPq3RXUz/Qir5/UazK5DUD5gCPBURe0jaELgYGALcDXy2O8s1M7PWavQsphnA4G6u43BgZuXxScDJEbEx8CJwcDeXa2ZmLdRogjiBdKrrJElX1P66mknSesAngbPyY5F2S03MTSYAezcftpmZtVqju5gmkH75TwcWNrH8nwFHAKvmx2sC8yLizfx4NrBuaUZJhwKHAowYMaKJVZqZWU9oNEE8FxGnNrNgSXsAz0bEVEk71UYXmkZhHBFxJnAmwOjRo4ttzMysdRpNEFMlnQBcQeVWo12c5roDsKek3YEBwGqkHsVgSf1zL2I94OluRW5mZi3VaIL4YP4/pjKu09NcI+I7wHcAcg/imxFxgKTfAvuQzmQaB1zeZMxmZtYLGr1QrifvC/Ft4GJJxwL3AGf34LLNzKyHNNqDQNIngc1Iu4sAiIj/amTeiLgRuDEPPwps00yQZmbW+xo6zVXSL4H9gH8nHWjeF9ighXGZmVmbNXodxPYRcRDwYkT8ANiOVJfJzMz6qEYTxOv5/2uS1gHeBDZsTUhmZrY0aPQYxFWSBgPjgal53FmtCcnMzJYGnSYISR8CnoyIH+bHg0hXUz8InNz68MzMrF262sV0BvAGgKQdgRPzuJfIVzmbmVnf1NUupn4R8UIe3g84MyJ+B/xO0rTWhmZmZu3UVQ+in6RaEvkY8H+VaQ1fQ2FmZsuerr7kLwJukvQc6UymyQCSRpF2M5mZWR/VaYKIiOMkXQ8MB66JiFpV1RVIF82ZmVkf1eVuooi4vTDuz60Jx8zMlhaNXihnZmbLGScIMzMrcoIwM7MiJwgzMytygjAzsyInCDMzK3KCMDOzIicIMzMrcoIwM7MiJwgzMytygjAzsyInCDMzK3KCMDOzIicIMzMrcoIwM7MiJwgzMytygjAzsyInCDMzK3KCMDOzIicIMzMrcoIwM7MiJwgzMytygjAzsyInCDMzK3KCMDOzIicIMzMralmCkLS+pBskzZR0v6TD8/ghkq6V9HD+v0arYjAzs+5rZQ/iTeAbEfE+YAxwmKRNgSOB6yNiY+D6/NjMzJYyLUsQETEnIu7Ow/OBmcC6wF7AhNxsArB3q2IwM7Pu65VjEJJGAh8E7gDWjog5kJIIsFYH8xwqaYqkKXPnzu2NMM3MrKLlCULSIOB3wNcj4uVG54uIMyNidESMHjZsWOsCNDOzopYmCEkrkpLDBRFxaR79jKThefpw4NlWxmBmZt3TyrOYBJwNzIyIn1YmXQGMy8PjgMtbFYOZmXVf/xYuewfgs8B0SdPyuO8CJwK/kXQwMAvYt4UxmJlZN7UsQUTEzYA6mPyxVq3XzMx6hq+kNjOzIicIMzMrcoIwM7MiJwgzMytygjAzsyInCDMzK3KCMDOzIicIMzMrcoIwM7MiJwgzMytygjAzsyInCDMzK3KCMDOzIicIMzMrcoIwM7MiJwgzMytygjAzsyInCDMzK3KCMDOzIicIMzMrcoIwM7MiJwgzMytygjAzsyInCDMzK3KCMDOzIicIMzMrcoIwM7MiJwgzMytygjAzsyInCDMzK3KCMDOzIicIMzMrcoIwM7MiJwgzMytygjAzsyInCDMzK3KCMDOzorYkCEmfkPSQpEckHdmOGMzMrHO9niAk9QN+DuwGbAp8RtKmvR2HmZl1rh09iG2ARyLi0Yh4A7gY2KsNcZiZWSf6t2Gd6wJPVh7PBratbyTpUODQ/PAVSQ/1Qmx9zVDguXYH0dt0UrsjsBZYLt/L/EBLuoQNlmTmdiSI0hbHO0ZEnAmc2fpw+i5JUyJidLvjMFtSfi+3Rzt2Mc0G1q88Xg94ug1xmJlZJ9qRIO4CNpa0oaSVgP2BK9oQh5mZdaLXdzFFxJuSvgpMAvoB/xMR9/d2HMsJ76KzvsLv5TZQxDt2/5uZmflKajMzK3OCMDOzIieIXiTpLUnTJN0r6W5J2/fw8s+VtE8ePqsnrlCXtJOklyTdI2mmpO9L2jVvxzRJr+SyKdMknbfkW7HsknSUpPsl3Zefj23z+BslzZKkStvLJL3SwTJqz+1bleGvNRHHtpJO7qJNP0mTm9m+Lpa3oqTxuXzODEl3SNq1p5ZfWN/b8UsaJWlaq9bVwfr7S5rXm+vsiKTZkgZ3MH56fj9eLWktSVPy+2mWpLmV99f6pWW34zqI5dnrEbElQP7wnAB8tBUriohDenBxkyNiD0mrANOAqyrbcSPwzYiY0oPrW+ZI2g7YA9gqIv4uaSiwUqXJPGAH4Ob8YR5eWk5EHAccl5f5Su15Lqyvf0S82cEy7gDu6CzeiHgL+EjnW9WUE4AhwKYR8Yak4aTtbYklib+z564P+khEzJM0Hjiydi2JpEOA90fE1zub2T2I9lkNeBFA0iBJ1+dexXRJe+Xxq0j6Q+5xzJC0Xx6/taSbJE2VNCl/GBeTf7XW3gyvSDouL+d2SWvn8cMk/U7SXfmv0w90RLwKTAU26tFnom8YDjwXEX8HiIjnIqJ6fc/FpFO6Af4ZuLTZFUg6X9JPJN0AHC9pjKTbcu/uFkkb53ZjJV2Wh4+VdHZ+vzwq6bA8/u1fwLn99ZIuzb3B8yrr3DOPmyzptNpy6+JaFfgc8LVcPoeImBMRE/P0A/P7eoak46vrl/Sj/L6flHs+tTh3z+0OkfT7PP0hSd+rj78ulv6SfirpzvzL+ZDKNl4n6WLgnsJ8g5R64Hfm5/NTlfVPzOt/WNIJdfOdmD9Xt0laK4/bS6kHdY+kayrji69Fnvb5HO+9ks7J49bOr8mUHNeYPH6YpGvz83Y65YuP6/0JGNVAu8VFhP966Q94i/QL/EHgJWDrPL4/sFoeHgo8kl/0TwO/qsy/OrAicCswLI/bj3SqMMC5wD55+EZgdB4O4FN5eDzwvTx8IfDhPDwCmFmIeSdSjwFgTeBxYLPK9LfXszz/AYPya/tn4BfAR+ueo22B+0indl8DjARe6WKZr9Q9Ph+4DFih8n7ol4c/AVySh8cCl+XhY4HJpN7MWsDzOYb+wLxK+xdJSa4f6VqlMcBA0oWtG+T3429ry62Layvgrg62Yb38nhma37s3kXpa/fP78uO53ZXAH/P4rYEpefwhwFPAGsAqwAPAlnXxjwKm5eGvkH4pA7yLlAxG5G18BRjRQZzjgf3z8Br5dRyQ1/8wsCqwMqlM0DqV+HfL8/y0st41WHSG6JeBk7p4LbYgfScMye1q/y8BxuThkcCMPPwL4Lt5eK8cx+DCNs0GBufX7pfAcZVphwA/6+p97V1Mvau6i2k74DxJ7ye9gMdL2hFYSKpXtTYwHfixpJNIX9KTc/v3A9cq7dLuB8zpYr1vAFfl4anAx/PwWGBTLdo1vpqkVSNift38H5F0T47txPB1K+8QEa9I2pq022Nn4BJJR0bEubnJW8DNpIS+ckQ8Xnnem/HbiFiYhweT3kNd9eiuivTL/llJLwDDeGddo9sjYg6A0v78kcCbwEMR8UQefxFwUJPxbgv8X0Q8l5dxIbAjcDXp83BtbjcdeCnSdVLT8/prJkVErbd9GfBhYEYH69sFeJ+kWm9tdWDjPHxbRMzqZL7dtOj2AwNIiQXgutpnQtKDefyzOf4/5jZTWbTLawTwG0nvJiWpP1fWU3ot/pGU3F8AqP0nfT7fW3mfrCFpZdLzt3tue7mk+s9r1WTS53Ya0HSVMieINomI25T2Uw8jvdjDSD2KBZIeBwZExJ/zl87uwAmSrgF+D9wfEds1sboFkX82kL6oaq/7CsB2EfF6F/NPjog9mljfcinSfvEbgRvzl9w4Uq+u5mLS63dMdT5JxwGfzMsoHnOoeLUyfBzpy/MXkkaRvnRL/l4Zrr7+XbXpMINJuo7UK7gd+BawoaRVIu2GXKxpR8sg/XCpWViJYWFdjPUXa3V28ZaAr0TE9XXxjqXy3Ckd9P9CfrhLnm/viPhL3Xw70vHz90YH438OHB8R/5vXW73nTUfPc2mbBGyTE0o1JjpoX/KRiOj2wXQfg2gTSZuQfv0/T/qV82xODjuTKzBKWgd4LSLOB35M6so/BAzLPZDa2SObdTOMa4CvVmLq6svJOiDpvcrHALItgSfqmk0mHcy9qDoyIo6KiC0bSA71ViftfoF0DKCn3U/6Bbu+0rfSfrUJETE2x/zl/Ov6POBnklaE9N6VdAApgewsaU1J/UnHYW5qMo5dJA2WNJC0S+WWTtpOAr6S11V7XVaubxQRp9ae84h4Ns/39plikj7YZIxVqwNP5edsXAPtrwP2lzQkr3tIZXz1OEXt/fEn4IA87lOk3V8t4QTRu1ZWPq2MtH9xXP7VeQEwWtIU0gv/YG7/AeDO3P4o4Nj8a2If4CRJ95K6jt09XfZreb33SXqAtL/UumcQMEHSA5LuI90M65hqg0h+XNvd0gNOAn4kqbMvzG6LiNdIPyCuIyW3p0nHzkqOzNNm5t7TpaQfPbOBo0k9q2mkXVl/aDKUm0nHy+4BLoqIzk5pPYN0zGCapBnA6TS2p+QHwEClg+n3U/faNekYUk/xJuCZrhpHxH2kYyB/yp/1H+VJhwE7VD6fX8zjvw+MlXQ36RjhU7SIS22YWYckDcrHV0T68p0eEaf14vobOh3TWsM9CDPrzL/lX7UPkM7i+VWb47Fe5B6EmZkVuQdhZmZFThBmZlbkBGFmZkVOEMsxSVfkUwFrj4fkGi8P5/9r5PGfVqpSOlnSmnncRkp1bTpa9he0qJLkDC2qL3WupNeU6vfU2p4iKfKFg9VlfF6Lqk2+kZc3TdKJTWzj+pIuaaDdpGpMS0LJEUq1g2bkmA/oiWV3ss5JklbVUlJlVNJhrd5maz0fpO4DlKqsLqi/4rKLef6ZdD3F5hHx/jxuPPBCRJyoVHJgjYj4tqRbgV1JFzkNiIjTlMouHB0RDxeWvR7pHPCtIuIlSYNItaMek3Qu6YK/8RFxvqQVSOfHDwG27OgaAaWry0eXpmspq86pdEvdTwL/EhHzlaq37hkRLS+Hni8Qey4i3lH+uYP2In0PLOyy8TJA0pBKqQpbQu5B9A3/ADykVOnzfV01zl/Y/49UPKxqL2BCHp4A7J2HF5JqygwEFkj6CDCnlByytYD5pOJoRMQrEfFYZfpFLLoqdyfSlbFNfcErVcY8Q9K1wDm5RzNZqYLmVC26F8Pb9wpQJ5U5lWvq5/YzlKpu3i/pj5IG5DZjco/oVqUqpB1dsPVdoHaFMRExr5YcJH089yimS/qVpJUq6z9OqdruXZK2UqoE+hdJX8xtxkq6QeleEg9I+nn+gu/sngBHalFl06Mrz8kMSb8E7qau9HjetgfyPCflce+oUKp0T4YnJK2W20ipSunQ/Pp8PY+/Wanq6Z25V7V9Hr+KUjXheyVdpFS1dMvcC/q1FlWA7fReGJJWVqoYeyOpaJ71lK6q+fmvyyqeJ6FsGdAAAAXISURBVJN+Adf/1So77tzB9Fsry5jcQZuxefq3Oph+amUZq5IqNN5CuvL088AqncT8T1QqRObx8+ravZj/f5xUjOxKUhmBSaTeRUfPSb/cZhZwDrmSbJ52Lqnncjup6uWvSPfEeBwY2skyF5tOSm53kno0kJJXbXgT4I48XK30WazMmafVKl+OAhYAH8jjL2VRlc+ZpNo4kEqfTCvEuQYwt4NtGJjXuVF+fAHw1cr6v5iHTyNdNbwKqWjjX/P4scBr+XXrB/wfqX5QNf5qldPdSZU/RfoxeDXpqvtRpKT/oUKMa5NKbNT2LgyubFepQunPgc/m4R2Aqyuvz9fz8M2V9ntW2hwJ/DwPb0GqTbQlqcDfHysxvaNSaR7/wbz+v+TnbMvKtE0pf2amAavmNhM7mH5Ann5QB9NrVXMHd7KO97b7u6kn/toegP9a8KKmD8etwMuFaVsCV+bhkTSQIOrGjQMOJ5WDnkj6gh9YaCdgG+A7pPLlx+Tx55ISxBHAv5FKYK9A9xLEUZXHa5C+cGfkD+j8PL4+QZxemedaFpVTriaImZU2R+UvsqHAXyrjt6KcIIbQcYLYmlTZtPZ4V+A3lfWvnYcPrYvzaVIpj7F18x8K/Lgu/mqC+BnwWOVL6xFSzaZRwMMdxLgiqbLqWaQfESvm8Vvk52s6qTpprQT8jpXh04DPV16faoLYNg+vCzyYh68iFZOrrfs+0vtzTeBR4JT8HKkQ5xHA66Ty3u9q92eur/65musSUrq1486FSRdH2pe/M+kXe73XIqLW1Z5MueDWNyPiOknfIhfnqvOniKgWGNuA9AXwGeBeyvVktgO2zvv0+wNrSboxInYCnpE0PCLmKN2E6Nm6bR1IShC7kgr97QX8a45tsStsI32K7yTVkrqW1JOoxnMxaffGhIhYmPeUoHQTlVrNmd1j8Zvu1KtWDv0G6df5gaQvuXfczjNrRWXT84DNgVkRsaekBZJGxDtLS3dV37tazbQaQ7W6abOVTY+NiLPr4h3F4s/dooWlgpGjSb3G/UlJfBc6rlA6GThX6eSFPYH/7GLbqs958fmIiOclbQ7sRqoX9mlSMqyaQOpFHQZ8TOkmO1dHPhaldLvdCzuI5SORjg1NpHwTnR9FxAWSDiLtiq33UETsl3fp3djBOvaLiIc6mLbMcIJYQhHxH11Mv4H0q6izNp3eOjEifsSiAl7vIGkk6RffUNIX8Q4R8XwHyzqdVMCsNt9VOTkAXEFKACfm/5fXzX4EcEr+ElmZ9OW0kLTrpBrPOsC7I+LuPOodlU0jYpako0iF4Krjf076MmrW6sAjERGSxtHYXbYaFhFz8xf/6Ei3V92/Mq3+HgknAr+Q9JlYdJB6X1IPZ2NJ74mIR0nJrNnKpmMkjSAVaPsX0q/2jkwCvifp4oh4Venkgb91tnClM7kGRMRVku4gldiADiqU5uf7clJv5d5orrT0zXkbJkv6AKnni6RhwN8i4reSHiPd7GYxEfEMqTLuCZI+ChwMnCLp1Ig4JSJqNxbqUETs08X080hVajuaPq+rdSzrnCD6hrdId5i6cwmXcyLpRicHk44f7FubkL/0R0fEMXnUT0jHEeax6GB2zYqkGx2tQ/pCmkuhUmxEnLGE8Vb9NzBR0mdISefvXbTvji+QDojPJ5Vc7qiy6Wmk4wdTJb1BOqYxPiJey8/tpZL6ke4b3Wxto1tJz/1mpF+vV3TUMP/a3wS4PffQ5pN6fJ1ZPcf3LtKuv9ov6GNIFUpnk3qG1QPblwC3kRJeM04j3fDoPlJvcgbpOV0fODsnowC+3dlCIuIm4CZJqwOjm4zBOuHTXM0apFzZNA8fRbo15Dd6cf1jSQe16xPyMknplNz+EfE3pXtpXANsHEvRKcvLO/cgzBq3p6QjSJ+bx2nNTXqWJ4OA63OiEPAlJ4eli3sQZmZW5AvlzMysyAnCzMyKnCDMzKzICcLMzIqcIMzMrOj/A/rLJSMoV+ZEAAAAAElFTkSuQmCC" - } - }, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Training Throughput Example Plot\n", - "\n", - "![throughput.png](attachment:throughput.png)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**Note:** For this example, the compiler delivers higher throughput for an ML model as measured by samples per second. However, you might not see an improvement in the total training time for your model. The total training time depends on several other factors, such as key components of the Trainer and TFTrainer APIs." - ] - }, { "cell_type": "markdown", "metadata": {}, diff --git a/sagemaker-training-compiler/huggingface/pytorch_single_gpu_single_node/roberta-base/roberta-base.ipynb b/sagemaker-training-compiler/huggingface/pytorch_single_gpu_single_node/roberta-base/roberta-base.ipynb index 84c5335500..ae6f16b5c6 100644 --- a/sagemaker-training-compiler/huggingface/pytorch_single_gpu_single_node/roberta-base/roberta-base.ipynb +++ b/sagemaker-training-compiler/huggingface/pytorch_single_gpu_single_node/roberta-base/roberta-base.ipynb @@ -67,7 +67,7 @@ "metadata": {}, "outputs": [], "source": [ - "!pip install \"sagemaker>=2.108.0\" botocore boto3 awscli --upgrade" + "!pip install \"sagemaker>=2.108.0\" botocore boto3 awscli \"torch==1.11.0\" --upgrade" ] }, { @@ -99,7 +99,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Copy and run the following code if you need to upgrade ipywidgets for `datasets` library and restart kernel. This is only needed when prerpocessing is done in the notebook.\n", + "Copy and run the following code if you need to upgrade `ipywidgets` for `datasets` library and restart kernel. This is only needed when prepocessing is done in the notebook.\n", "\n", "```python\n", "%%capture\n", @@ -164,7 +164,7 @@ "\n", "If you'd like to try other training datasets later, you can simply use this method.\n", "\n", - "For this example notebook, we prepared the `SST2` dataset in the public SageMaker sample file S3 bucket. The following code cells show how you can directly load the dataset and convert to a HuggingFace DatasetDict." + "For this example notebook, we prepared the `SST2` dataset in the public SageMaker sample file S3 bucket. The following code cells show how you can directly load the dataset and convert to a `HuggingFace DatasetDict`." ] }, { @@ -302,7 +302,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Set up an option for fine-tuning or full training. Set `FINE_TUNING = 1` for fine-tuning and using `fine_tune_with_huggingface.py`. Set `FINE_TUNING = 0` for full training and using `full_train_roberta_with_huggingface.py`." + "Set up an option for fine-tuning or full training. `FINE_TUNING = 1` is for fine-tuning and it will use `fine_tune_with_huggingface.py`. `FINE_TUNING = 0` is for full training and it will use `full_train_roberta_with_huggingface.py`." ] }, { @@ -318,7 +318,7 @@ "FULL_TRAINING = not FINE_TUNING\n", "\n", "# Fine tuning is typically faster and is done for fewer epochs\n", - "EPOCHS = 4 if FINE_TUNING else 100\n", + "EPOCHS = 7 if FINE_TUNING else 100\n", "\n", "TRAINING_SCRIPT = (\n", " \"fine_tune_with_huggingface.py\" if FINE_TUNING else \"full_train_roberta_with_huggingface.py\"\n", @@ -340,7 +340,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The `train_batch_size` in the following code cell is the maximum batch that can fit into the memory of an `ml.p3.2xlarge` instance. If you change the model, instance type, sequence length, and other parameters, you need to do some experiments to find the largest batch size that will fit into GPU memory." + "The `train_batch_size` in the following code cell is the maximum batch that can fit into the memory of the `ml.p3.2xlarge` instance. If you change the model, instance type, sequence length, and other parameters, you need to do some experiments to find the largest batch size that will fit into GPU memory." ] }, { @@ -628,9 +628,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Plot and compare throughputs of compiled training and native training\n", + "### Plot and compare throughput of compiled training and native training\n", "\n", - "Visualize average throughputs as reported by HuggingFace and see potential savings." + "Visualize average throughput as reported by HuggingFace and see potential savings." ] }, {