Skip to content

Commit

Permalink
Update Flax NNX MNIST tutorial
Browse files Browse the repository at this point in the history
  • Loading branch information
8bitmp3 committed Oct 9, 2024
1 parent 2ad9731 commit de9442c
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 98 deletions.
102 changes: 53 additions & 49 deletions docs_nnx/mnist_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs/nnx/mnist_tutorial.ipynb)\n",
"[![Open On GitHub](https://img.shields.io/badge/Open-on%20GitHub-blue?logo=GitHub)](https://github.com/google/flax/blob/main/docs/nnx/mnist_tutorial.ipynb)\n",
"\n",
"# MNIST Tutorial\n",
"# MNIST tutorial\n",
"\n",
"Welcome to Flax NNX! This tutorial will guide you through building and training a simple convolutional\n",
"neural network (CNN) on the MNIST dataset using the Flax NNX API. Flax NNX is a Python neural network library\n",
"built upon [JAX](https://github.com/google/jax) and currently offered as an experimental module within\n",
"[Flax](https://github.com/google/flax)."
"Welcome to Flax NNX! In this tutorial you will learn how to build and train a simple convolutional neural network (CNN) to classify handwritten digits on the MNIST dataset using the Flax NNX API.\n",
"\n",
"Flax NNX is a Python neural network library built upon [JAX](https://github.com/google/jax). If you have used the Flax Linen API before, check out [Why Flax NNX](https://flax.readthedocs.io/en/latest/why.html). You should have some knowledge of the main concepts of deep learning.\n",
"\n",
"Let’s get started!"
]
},
{
Expand All @@ -23,8 +24,7 @@
"source": [
"## 1. Install Flax\n",
"\n",
"If `flax` is not installed in your environment, you can install it from PyPI, uncomment and run the\n",
"following cell:"
"If `flax` is not installed in your Python environment, use `pip` to install the package from PyPI - below, you’d need to uncomment and run the cell if you’re working from a Jupyter notebook/Google Colab:"
]
},
{
Expand All @@ -46,11 +46,9 @@
"id": "3",
"metadata": {},
"source": [
"## 2. Load the MNIST Dataset\n",
"## 2. Load the MNIST dataset\n",
"\n",
"First, the MNIST dataset is loaded and prepared for training and testing using\n",
"Tensorflow Datasets. Image values are normalized, the data is shuffled and divided\n",
"into batches, and samples are prefetched to enhance performance."
"First, you need to load the MNIST dataset and then prepare the training and testing sets via Tensorflow Datasets (TFDS). You normalize image values, shuffle the data and divide it into batches, and prefetch samples to enhance performance."
]
},
{
Expand All @@ -72,10 +70,10 @@
}
],
"source": [
"import tensorflow_datasets as tfds # TFDS for MNIST\n",
"import tensorflow as tf # TensorFlow operations\n",
"import tensorflow_datasets as tfds # TFDS for MNIST.\n",
"import tensorflow as tf # TensorFlow operations.\n",
"\n",
"tf.random.set_seed(0) # set random seed for reproducibility\n",
"tf.random.set_seed(0) # Set the random seed for reproducibility.\n",
"\n",
"train_steps = 1200\n",
"eval_every = 200\n",
Expand All @@ -95,13 +93,13 @@
" 'image': tf.cast(sample['image'], tf.float32) / 255,\n",
" 'label': sample['label'],\n",
" }\n",
") # normalize test set\n",
") # Normalize the test set.\n",
"\n",
"# create shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from\n",
"# Create a shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from.\n",
"train_ds = train_ds.repeat().shuffle(1024)\n",
"# group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency\n",
"# Group into batches of `batch_size` and skip incomplete batches, prefetch the next sample to improve latency.\n",
"train_ds = train_ds.batch(batch_size, drop_remainder=True).take(train_steps).prefetch(1)\n",
"# group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency\n",
"# Group into batches of `batch_size` and skip incomplete batches, prefetch the next sample to improve latency.\n",
"test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1)"
]
},
Expand All @@ -110,9 +108,9 @@
"id": "5",
"metadata": {},
"source": [
"## 3. Define the Network with Flax NNX\n",
"## 3. Define the model with Flax NNX\n",
"\n",
"Create a convolutional neural network with Flax NNX by subclassing `nnx.Module`."
"Create a CNN for classification with Flax NNX by subclassing `nnx.Module`:"
]
},
{
Expand Down Expand Up @@ -156,7 +154,9 @@
" x = self.linear2(x)\n",
" return x\n",
"\n",
"# Instantiate the model.\n",
"model = CNN(rngs=nnx.Rngs(0))\n",
"# Visualize it.\n",
"nnx.display(model)"
]
},
Expand All @@ -165,9 +165,9 @@
"id": "7",
"metadata": {},
"source": [
"### Run model\n",
"### Run the model\n",
"\n",
"Let's put our model to the test! We'll perform a forward pass with arbitrary data and print the results."
"Let's put the CNN model to the test! Here, you’ll perform a forward pass with arbitrary data and print the results."
]
},
{
Expand Down Expand Up @@ -203,9 +203,9 @@
"id": "9",
"metadata": {},
"source": [
"## 4. Create Optimizer and Metrics\n",
"## 4. Create the optimizer and define some metrics\n",
"\n",
"In Flax NNX, we create an `Optimizer` object to manage the model's parameters and apply gradients during training. `Optimizer` receives the model's reference so it can update its parameters, and an `optax` optimizer to define the update rules. Additionally, we'll define a `MultiMetric` object to keep track of the `Accuracy` and the `Average` loss."
"In Flax NNX, you need to create an `nnx.Optimizer` object to manage the model's parameters and apply gradients during training. `nnx.Optimizer` receives the model's reference, so that it can update its parameters, and an [Optax](https://optax.readthedocs.io/) optimizer to define the update rules. Additionally, you will define an `nnx.MultiMetric` object to keep track of the `Accuracy` and the `Average` loss."
]
},
{
Expand Down Expand Up @@ -247,9 +247,13 @@
"id": "13",
"metadata": {},
"source": [
"## 5. Define step functions\n",
"## 5. Define training step functions\n",
"\n",
"In this section, you will define a loss function using the cross entropy loss ([`optax.softmax_cross_entropy_with_integer_labels()`](https://optax.readthedocs.io/en/latest/api/losses.html#optax.softmax_cross_entropy_with_integer_labels)) that the CNN model will optimize over.\n",
"\n",
"In addition to the `loss`, during training and testing you will also get the `logits`, which will be used to calculate the accuracy metric. \n",
"\n",
"We define a loss function using cross entropy loss (see more details in [`optax.softmax_cross_entropy_with_integer_labels()`](https://optax.readthedocs.io/en/latest/api/losses.html#optax.softmax_cross_entropy_with_integer_labels)) that our model will optimize over. In addition to the loss, the logits are also outputted since they will be used to calculate the accuracy metric during training and testing. During training, we'll use `nnx.value_and_grad` to compute the gradients and update the model's parameters using the optimizer. During both training and testing, the loss and logits are used to calculate the metrics."
"During training - the `train_step` - you will use `nnx.value_and_grad` to compute the gradients and update the model's parameters using the `optimizer` you have already defined. And during both training and testing (the `eval_step`), the `loss` and `logits` will be used to calculate the metrics."
]
},
{
Expand Down Expand Up @@ -285,24 +289,24 @@
"id": "17",
"metadata": {},
"source": [
"The [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) decorator traces the `train_step` function for just-in-time compilation with\n",
"In the code above, the [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) transformation decorator traces the `train_step` function for just-in-time compilation with\n",
"[XLA](https://www.tensorflow.org/xla), optimizing performance on\n",
"hardware accelerators. `nnx.jit` is similar to [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit),\n",
"hardware accelerators, such as Google TPUs and GPUs. `nnx.jit` is similar to [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit),\n",
"except it can transforms functions that contain Flax NNX objects as inputs and outputs.\n",
"\n",
"**NOTE**: in the above code we performed serveral inplace updates to the model, optimizer, and metrics, and we did not explicitely return the state updates. This is because Flax NNX transforms respect reference semantics for Flax NNX objects, and will propagate the state updates of the objects passed as input arguments. This is a key feature of Flax NNX that allows for a more concise and readable code."
"> **Note:** The code shows how to perform several in-place updates to the model, the optimizer, and the metrics, but _state updates_ were not explicitly returned. This is because Flax NNX transformations respect _reference semantics_ for Flax NNX objects, and will propagate the state updates of the objects passed as input arguments. This is a key feature of Flax NNX that allows for a more concise and readable code. You can learn more in [Why Flax NNX](https://flax.readthedocs.io/en/latest/why.html)."
]
},
{
"cell_type": "markdown",
"id": "21",
"metadata": {},
"source": [
"## 6. Train and Evaluate\n",
"## 6. Train and evaluate the model\n",
"\n",
"Now we train a model using batches of data for 10 epochs, evaluate its performance\n",
"on the test set after each epoch, and log the training and testing metrics (loss and\n",
"accuracy) throughout the process. Typically this leads to a model with around 99% accuracy."
"Now, you can train the CNN model using batches of data for 10 epochs, evaluate the model’s performance\n",
"on the test set after each epoch, and log the training and testing metrics (the loss and\n",
"the accuracy) during the process. Typically this leads to the model achieving around 99% accuracy."
]
},
{
Expand Down Expand Up @@ -422,25 +426,25 @@
"\n",
"for step, batch in enumerate(train_ds.as_numpy_iterator()):\n",
" # Run the optimization for one step and make a stateful update to the following:\n",
" # - the train state's model parameters\n",
" # - the optimizer state\n",
" # - the training loss and accuracy batch metrics\n",
" # - The train state's model parameters\n",
" # - The optimizer state\n",
" # - The training loss and accuracy batch metrics\n",
" train_step(model, optimizer, metrics, batch)\n",
"\n",
" if step > 0 and (step % eval_every == 0 or step == train_steps - 1): # one training epoch has passed\n",
" # Log training metrics\n",
" for metric, value in metrics.compute().items(): # compute metrics\n",
" metrics_history[f'train_{metric}'].append(value) # record metrics\n",
" metrics.reset() # reset metrics for test set\n",
" if step > 0 and (step % eval_every == 0 or step == train_steps - 1): # One training epoch has passed.\n",
" # Log the training metrics.\n",
" for metric, value in metrics.compute().items(): # Compute the metrics.\n",
" metrics_history[f'train_{metric}'].append(value) # Record the metrics.\n",
" metrics.reset() # Reset the metrics for test set.\n",
"\n",
" # Compute metrics on the test set after each training epoch\n",
" # Compute the metrics on the test set after each training epoch.\n",
" for test_batch in test_ds.as_numpy_iterator():\n",
" eval_step(model, metrics, test_batch)\n",
"\n",
" # Log test metrics\n",
" # Log the test metrics.\n",
" for metric, value in metrics.compute().items():\n",
" metrics_history[f'test_{metric}'].append(value)\n",
" metrics.reset() # reset metrics for next training epoch\n",
" metrics.reset() # Reset the metrics for the next training epoch.\n",
"\n",
" print(\n",
" f\"[train] step: {step}, \"\n",
Expand All @@ -459,9 +463,9 @@
"id": "23",
"metadata": {},
"source": [
"## 7. Visualize Metrics\n",
"## 7. Visualize the metrics\n",
"\n",
"Use Matplotlib to create plots for loss and accuracy."
"With Matplotlib, you can create plots for the loss and the accuracy:"
]
},
{
Expand Down Expand Up @@ -503,9 +507,9 @@
"id": "25",
"metadata": {},
"source": [
"## 10. Perform inference on test set\n",
"## 10. Perform inference on the test set\n",
"\n",
"Define a jitted inference function, `pred_step`, to generate predictions on the test set using the learned model parameters. This will enable you to visualize test images alongside their predicted labels for a qualitative assessment of model performance."
"Create a `jit`ted model inference function (with `nnx.jit`) - `pred_step` - to generate predictions on the test set using the learned model parameters. This will enable you to visualize test images alongside their predicted labels for a qualitative assessment of model performance."
]
},
{
Expand Down Expand Up @@ -556,7 +560,7 @@
"id": "28",
"metadata": {},
"source": [
"Congratulations! You made it to the end of the annotated MNIST example."
"Congratulations! You have learned how to use Flax NNX to train a simple classification model end-to-end on the MNIST dataset."
]
}
],
Expand Down
Loading

0 comments on commit de9442c

Please sign in to comment.