Skip to content

Commit

Permalink
Interactive 3D visualization of tensors (aws#339)
Browse files Browse the repository at this point in the history
* adding files for tensor visualization

* updated setup.py and tensor_plot.py

* Updated min and max ranges in tensor_plot.py

* fixed tensorplot.py

* refactored TensorPlot class

* refactored TensorPlot class

* updated notebook

* updated notebook

* refactored code and updated notebook

* Update to use new tensors APIs

* Remove sys.path prints

* fixed bug in notebook

* Use the correct path again
  • Loading branch information
NRauschmayr authored and rahul003 committed Nov 13, 2019
1 parent 6514203 commit 53ed26e
Show file tree
Hide file tree
Showing 3 changed files with 644 additions and 0 deletions.
313 changes: 313 additions & 0 deletions examples/mxnet/notebooks/mxnet-tensor-plot.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,313 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Visualizing Tensors "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Overview"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Tornasole is a new capability of Amazon SageMaker that allows debugging machine learning models. \n",
"It lets you go beyond just looking at scalars like losses and accuracies during training and gives \n",
"you full visibility into all the tensors 'flowing through the graph' during training. Tornasole helps you to monitor your training in near real time using rules and would provide you alerts, once it has detected an inconsistency in the training flow.\n",
"\n",
"Using Tornasole is a two step process: Saving tensors and Analysis. In this notebook we will run an MXNet training job and configure Tornasole to store all tensors from this job. Afterwards we will visualize those tensors in our notebook.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Dependencies\n",
"Before we begin, let us install the library plotly if it is not already present in the environment.\n",
"If the below cell installs the library for the first time, you'll have to restart the kernel and come back to the notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"! pip install plotly"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Configure and run the training job\n",
"\n",
"Now we'll call the Sagemaker MXNet Estimator to kick off a training job along with the VanishingGradient rule to monitor the job.\n",
"\n",
"The 'entry_point_script' points to the MXNet training script that has the TornasoleHook integrated.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"entry_point_script = '../scripts/mnist_gluon_save_all_demo.py'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import boto3\n",
"import os\n",
"import sagemaker\n",
"from sagemaker.mxnet import MXNet\n",
"\n",
"\n",
"REGION='us-west-2'\n",
"TAG='latest'\n",
"\n",
"docker_image_name= '072677473360.dkr.ecr.{}.amazonaws.com/tornasole-preprod-mxnet-1.4.1-cpu:{}'.format(REGION, TAG)\n",
"\n",
"estimator = MXNet(role=sagemaker.get_execution_role(),\n",
" base_job_name='mxnet-trsl-test-nb',\n",
" train_instance_count=1,\n",
" train_instance_type='ml.m4.xlarge',\n",
" image_name=docker_image_name,\n",
" entry_point=entry_point_script,\n",
" framework_version='1.4.1',\n",
" debug=True,\n",
" py_version='py3')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Start the training job:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"estimator.fit()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Get S3 location of tensors\n",
"\n",
"We can check the status of the training job by running the following command:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"job_name = estimator.latest_training_job.name\n",
"\n",
"client = estimator.sagemaker_session.sagemaker_client\n",
"\n",
"description = client.describe_training_job(TrainingJobName=job_name)\n",
"\n",
"print('downloading tensors from training job: ', job_name)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can retrieve the S3 location of the tensors by accessing the dictionary `description`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"path = description['DebugConfig']['DebugHookConfig']['S3OutputPath']\n",
"\n",
"print('Tensors are stored in: ', path)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Download tensors from S3\n",
"\n",
"Now we will download the tensors from S3, so that we can visualize them in our notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"folder_name = path.split(\"/\")[-1]\n",
"os.system(\"aws s3 cp --recursive \" + path + \" \" + folder_name)\n",
"print('Downloading tensors into folder: ', folder_name)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Visualize\n",
"The main purpose of this class (TensorPlot) is to visualise the tensors in your network. This could be to determine dead or saturated activations, or the features maps the network.\n",
"\n",
"To use this class (TensorPlot), you will need to supply the argument regex with the tensors you are interested in. e.g., if you are interested in activation outputs, then you need to supply the following regex .*relu|.*tanh|.*sigmoid.\n",
"\n",
"Another important argument is the `sample_batch_id`, which allows you to specify the index of the batch size to display. For example, given an input tensor of size (batch_size, channel, width, height), `sample_batch_id = n` will display (n, channel, width, height). If you set sample_batch_id = -1 then the tensors will be summed over the batch dimension (i.e., `np.sum(tensor, axis=0)`). If batch_sample_id is None then each sample will be plotted as separate layer in the figure.\n",
"\n",
"Here are some interesting use cases:\n",
"\n",
"1) If you want to determine dead or saturated activations for instance ReLus that are always outputting zero, then you would want to sum the batch dimension (sample_batch_id=-1). The sum gives an indication which parts of the network are inactive across a batch.\n",
"\n",
"2) If you are interested in the feature maps for the first image in the batch, then you should provide batch_sample_id=0. This can be helpful if your model is not performing well for certain set of samples and you want to understand which activations are leading to misprediction.\n",
"\n",
"An example visualization of layer outputs:\n",
"![](tensorplot.gif)\n",
"\n",
"\n",
"`TensorPlot` normalizes tensor values to the range 0 to 1 which means colorscales are the same across layers. Blue indicates value close to 0 and yellow indicates values close to 1. This class has been designed to plot convolutional networks that take 2D images as input and predict classes or produce output images. You can use this for other types of networks like RNNs, but you may have to adjust the class as it is currently neglecting tensors that have more than 4 dimensions.\n",
"\n",
"Let's plot Relu output activations for the given MNIST training example."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
" <script type=\"text/javascript\">\n",
" window.PlotlyConfig = {MathJaxConfig: 'local'};\n",
" if (window.MathJax) {MathJax.Hub.Config({SVG: {font: \"STIX-Web\"}});}\n",
" if (typeof require !== 'undefined') {\n",
" require.undef(\"plotly\");\n",
" requirejs.config({\n",
" paths: {\n",
" 'plotly': ['https://cdn.plot.ly/plotly-latest.min']\n",
" }\n",
" });\n",
" require(['plotly'], function(Plotly) {\n",
" window._Plotly = Plotly;\n",
" });\n",
" }\n",
" </script>\n",
" "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[2019-11-13 10:52:31.483 186590da3c0d.ant.amazon.com:3874 INFO local_trial.py:35] Loading trial at path /tmp/mxnet3/\n",
"[2019-11-13 10:52:31.516 186590da3c0d.ant.amazon.com:3874 INFO trial.py:191] Training has ended, will refresh one final time in 1 sec.\n",
"[2019-11-13 10:52:32.519 186590da3c0d.ant.amazon.com:3874 INFO trial.py:203] Loaded all steps\n"
]
}
],
"source": [
"import tensor_plot \n",
"\n",
"visualization = tensor_plot.TensorPlot(\n",
" regex=\".*relu_output\", \n",
" path=folder_name,\n",
" steps=10, \n",
" batch_sample_id=0,\n",
" color_channel = 1,\n",
" title=\"Relu outputs\",\n",
" label=\".*sequential0_input_0\",\n",
" prediction=\".*sequential0_output_0\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If we plot too many layers, it can crash the notebook. If you encounter performance or out of memroy issues, then either try to reduce the layers to plot by changing the `regex` or run this Notebook in JupyterLab instead of Jupyter. "
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<iframe\n",
" scrolling=\"no\"\n",
" width=\"1020px\"\n",
" height=\"820\"\n",
" src=\"iframe_figures/figure_2.html\"\n",
" frameborder=\"0\"\n",
" allowfullscreen\n",
"></iframe>\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"visualization.fig.show(renderer=\"iframe\")"
]
}
],
"metadata": {
"hide_input": false,
"kernelspec": {
"display_name": "Python [conda env:.conda-tf1x] *",
"language": "python",
"name": "conda-env-.conda-tf1x-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Loading

0 comments on commit 53ed26e

Please sign in to comment.