From de9442cdd52e64d57a7ed90912c326ee41c8b911 Mon Sep 17 00:00:00 2001 From: 8bitmp3 <19637339+8bitmp3@users.noreply.github.com> Date: Wed, 9 Oct 2024 16:50:55 +0000 Subject: [PATCH] Update Flax NNX MNIST tutorial --- docs_nnx/mnist_tutorial.ipynb | 102 ++++++++++++++++++---------------- docs_nnx/mnist_tutorial.md | 102 ++++++++++++++++++---------------- 2 files changed, 106 insertions(+), 98 deletions(-) diff --git a/docs_nnx/mnist_tutorial.ipynb b/docs_nnx/mnist_tutorial.ipynb index a01acbcf12..4fd3d20951 100644 --- a/docs_nnx/mnist_tutorial.ipynb +++ b/docs_nnx/mnist_tutorial.ipynb @@ -8,12 +8,13 @@ "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs/nnx/mnist_tutorial.ipynb)\n", "[![Open On GitHub](https://img.shields.io/badge/Open-on%20GitHub-blue?logo=GitHub)](https://github.com/google/flax/blob/main/docs/nnx/mnist_tutorial.ipynb)\n", "\n", - "# MNIST Tutorial\n", + "# MNIST tutorial\n", "\n", - "Welcome to Flax NNX! This tutorial will guide you through building and training a simple convolutional\n", - "neural network (CNN) on the MNIST dataset using the Flax NNX API. Flax NNX is a Python neural network library\n", - "built upon [JAX](https://github.com/google/jax) and currently offered as an experimental module within\n", - "[Flax](https://github.com/google/flax)." + "Welcome to Flax NNX! In this tutorial you will learn how to build and train a simple convolutional neural network (CNN) to classify handwritten digits on the MNIST dataset using the Flax NNX API.\n", + "\n", + "Flax NNX is a Python neural network library built upon [JAX](https://github.com/google/jax). If you have used the Flax Linen API before, check out [Why Flax NNX](https://flax.readthedocs.io/en/latest/why.html). You should have some knowledge of the main concepts of deep learning.\n", + "\n", + "Let’s get started!" ] }, { @@ -23,8 +24,7 @@ "source": [ "## 1. 
Install Flax\n", "\n", - "If `flax` is not installed in your environment, you can install it from PyPI, uncomment and run the\n", - "following cell:" + "If `flax` is not installed in your Python environment, use `pip` to install the package from PyPI - below, you’d need to uncomment and run the cell if you’re working from a Jupyter notebook/Google Colab:" ] }, { @@ -46,11 +46,9 @@ "id": "3", "metadata": {}, "source": [ - "## 2. Load the MNIST Dataset\n", + "## 2. Load the MNIST dataset\n", "\n", - "First, the MNIST dataset is loaded and prepared for training and testing using\n", - "Tensorflow Datasets. Image values are normalized, the data is shuffled and divided\n", - "into batches, and samples are prefetched to enhance performance." + "First, you need to load the MNIST dataset and then prepare the training and testing sets via TensorFlow Datasets (TFDS). You normalize image values, shuffle the data and divide it into batches, and prefetch samples to enhance performance." ] }, { @@ -72,10 +70,10 @@ } ], "source": [ - "import tensorflow_datasets as tfds # TFDS for MNIST\n", - "import tensorflow as tf # TensorFlow operations\n", + "import tensorflow_datasets as tfds # TFDS for MNIST.\n", + "import tensorflow as tf # TensorFlow operations.\n", "\n", - "tf.random.set_seed(0) # set random seed for reproducibility\n", + "tf.random.set_seed(0) # Set the random seed for reproducibility.\n", "\n", "train_steps = 1200\n", "eval_every = 200\n", @@ -95,13 +93,13 @@ " 'image': tf.cast(sample['image'], tf.float32) / 255,\n", " 'label': sample['label'],\n", " }\n", - ") # normalize test set\n", + ") # Normalize the test set.\n", "\n", - "# create shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from\n", + "# Create a shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from.\n", "train_ds = train_ds.repeat().shuffle(1024)\n", - "# group into batches of batch_size and skip incomplete batch, prefetch the next sample to 
improve latency\n", + "# Group into batches of `batch_size` and skip incomplete batches, prefetch the next sample to improve latency.\n", "train_ds = train_ds.batch(batch_size, drop_remainder=True).take(train_steps).prefetch(1)\n", - "# group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency\n", + "# Group into batches of `batch_size` and skip incomplete batches, prefetch the next sample to improve latency.\n", "test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1)" ] }, @@ -110,9 +108,9 @@ "id": "5", "metadata": {}, "source": [ - "## 3. Define the Network with Flax NNX\n", + "## 3. Define the model with Flax NNX\n", "\n", - "Create a convolutional neural network with Flax NNX by subclassing `nnx.Module`." + "Create a CNN for classification with Flax NNX by subclassing `nnx.Module`:" ] }, { @@ -156,7 +154,9 @@ " x = self.linear2(x)\n", " return x\n", "\n", + "# Instantiate the model.\n", "model = CNN(rngs=nnx.Rngs(0))\n", + "# Visualize it.\n", "nnx.display(model)" ] }, @@ -165,9 +165,9 @@ "id": "7", "metadata": {}, "source": [ - "### Run model\n", + "### Run the model\n", "\n", - "Let's put our model to the test! We'll perform a forward pass with arbitrary data and print the results." + "Let's put the CNN model to the test! Here, you’ll perform a forward pass with arbitrary data and print the results." ] }, { @@ -203,9 +203,9 @@ "id": "9", "metadata": {}, "source": [ - "## 4. Create Optimizer and Metrics\n", + "## 4. Create the optimizer and define some metrics\n", "\n", - "In Flax NNX, we create an `Optimizer` object to manage the model's parameters and apply gradients during training. `Optimizer` receives the model's reference so it can update its parameters, and an `optax` optimizer to define the update rules. Additionally, we'll define a `MultiMetric` object to keep track of the `Accuracy` and the `Average` loss." 
+ "In Flax NNX, you need to create an `nnx.Optimizer` object to manage the model's parameters and apply gradients during training. `nnx.Optimizer` receives the model's reference, so that it can update its parameters, and an [Optax](https://optax.readthedocs.io/) optimizer to define the update rules. Additionally, you will define an `nnx.MultiMetric` object to keep track of the `Accuracy` and the `Average` loss." ] }, { @@ -247,9 +247,13 @@ "id": "13", "metadata": {}, "source": [ - "## 5. Define step functions\n", + "## 5. Define training step functions\n", + "\n", + "In this section, you will define a loss function using the cross entropy loss ([`optax.softmax_cross_entropy_with_integer_labels()`](https://optax.readthedocs.io/en/latest/api/losses.html#optax.softmax_cross_entropy_with_integer_labels)) that the CNN model will optimize over.\n", + "\n", + "In addition to the `loss`, during training and testing you will also get the `logits`, which will be used to calculate the accuracy metric. \n", "\n", - "We define a loss function using cross entropy loss (see more details in [`optax.softmax_cross_entropy_with_integer_labels()`](https://optax.readthedocs.io/en/latest/api/losses.html#optax.softmax_cross_entropy_with_integer_labels)) that our model will optimize over. In addition to the loss, the logits are also outputted since they will be used to calculate the accuracy metric during training and testing. During training, we'll use `nnx.value_and_grad` to compute the gradients and update the model's parameters using the optimizer. During both training and testing, the loss and logits are used to calculate the metrics." + "During training - the `train_step` - you will use `nnx.value_and_grad` to compute the gradients and update the model's parameters using the `optimizer` you have already defined. And during both training and testing (the `eval_step`), the `loss` and `logits` will be used to calculate the metrics." 
] }, { @@ -285,12 +289,12 @@ "id": "17", "metadata": {}, "source": [ - "The [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) decorator traces the `train_step` function for just-in-time compilation with\n", + "In the code above, the [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) transformation decorator traces the `train_step` function for just-in-time compilation with\n", "[XLA](https://www.tensorflow.org/xla), optimizing performance on\n", - "hardware accelerators. `nnx.jit` is similar to [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit),\n", + "hardware accelerators, such as Google TPUs and GPUs. `nnx.jit` is similar to [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit),\n", "except it can transforms functions that contain Flax NNX objects as inputs and outputs.\n", "\n", - "**NOTE**: in the above code we performed serveral inplace updates to the model, optimizer, and metrics, and we did not explicitely return the state updates. This is because Flax NNX transforms respect reference semantics for Flax NNX objects, and will propagate the state updates of the objects passed as input arguments. This is a key feature of Flax NNX that allows for a more concise and readable code." + "> **Note:** The code above performs several in-place updates to the model, the optimizer, and the metrics, but does not explicitly return the _state updates_. This is because Flax NNX transformations respect _reference semantics_ for Flax NNX objects, and will propagate the state updates of the objects passed as input arguments. This is a key feature of Flax NNX that allows for more concise and readable code. You can learn more in [Why Flax NNX](https://flax.readthedocs.io/en/latest/why.html)." ] }, { @@ -298,11 +302,11 @@ "id": "21", "metadata": {}, "source": [ - "## 6. Train and Evaluate\n", + "## 6. 
Train and evaluate the model\n", "\n", - "Now we train a model using batches of data for 10 epochs, evaluate its performance\n", - "on the test set after each epoch, and log the training and testing metrics (loss and\n", - "accuracy) throughout the process. Typically this leads to a model with around 99% accuracy." + "Now, you can train the CNN model using batches of data for 10 epochs, evaluate the model’s performance\n", + "on the test set after each epoch, and log the training and testing metrics (the loss and\n", + "the accuracy) during the process. Typically this leads to the model achieving around 99% accuracy." ] }, { @@ -422,25 +426,25 @@ "\n", "for step, batch in enumerate(train_ds.as_numpy_iterator()):\n", " # Run the optimization for one step and make a stateful update to the following:\n", - " # - the train state's model parameters\n", - " # - the optimizer state\n", - " # - the training loss and accuracy batch metrics\n", + " # - The train state's model parameters\n", + " # - The optimizer state\n", + " # - The training loss and accuracy batch metrics\n", " train_step(model, optimizer, metrics, batch)\n", "\n", - " if step > 0 and (step % eval_every == 0 or step == train_steps - 1): # one training epoch has passed\n", - " # Log training metrics\n", - " for metric, value in metrics.compute().items(): # compute metrics\n", - " metrics_history[f'train_{metric}'].append(value) # record metrics\n", - " metrics.reset() # reset metrics for test set\n", + " if step > 0 and (step % eval_every == 0 or step == train_steps - 1): # One training epoch has passed.\n", + " # Log the training metrics.\n", + " for metric, value in metrics.compute().items(): # Compute the metrics.\n", + " metrics_history[f'train_{metric}'].append(value) # Record the metrics.\n", + " metrics.reset() # Reset the metrics for the test set.\n", "\n", - " # Compute metrics on the test set after each training epoch\n", + " # Compute the metrics on the test set after each training epoch.\n", " for 
test_batch in test_ds.as_numpy_iterator():\n", " eval_step(model, metrics, test_batch)\n", "\n", - " # Log test metrics\n", + " # Log the test metrics.\n", " for metric, value in metrics.compute().items():\n", " metrics_history[f'test_{metric}'].append(value)\n", - " metrics.reset() # reset metrics for next training epoch\n", + " metrics.reset() # Reset the metrics for the next training epoch.\n", "\n", " print(\n", " f\"[train] step: {step}, \"\n", @@ -459,9 +463,9 @@ "id": "23", "metadata": {}, "source": [ - "## 7. Visualize Metrics\n", + "## 7. Visualize the metrics\n", "\n", - "Use Matplotlib to create plots for loss and accuracy." + "With Matplotlib, you can create plots for the loss and the accuracy:" ] }, { @@ -503,9 +507,9 @@ "id": "25", "metadata": {}, "source": [ - "## 10. Perform inference on test set\n", + "## 8. Perform inference on the test set\n", "\n", - "Define a jitted inference function, `pred_step`, to generate predictions on the test set using the learned model parameters. This will enable you to visualize test images alongside their predicted labels for a qualitative assessment of model performance." + "Create a `jit`-compiled model inference function, `pred_step` (using `nnx.jit`), to generate predictions on the test set using the learned model parameters. This will enable you to visualize test images alongside their predicted labels for a qualitative assessment of model performance." ] }, { @@ -556,7 +560,7 @@ "id": "28", "metadata": {}, "source": [ - "Congratulations! You made it to the end of the annotated MNIST example." + "Congratulations! You have learned how to use Flax NNX to train a simple classification model end-to-end on the MNIST dataset." 
] } ], diff --git a/docs_nnx/mnist_tutorial.md b/docs_nnx/mnist_tutorial.md index 740395331a..d705ed9cea 100644 --- a/docs_nnx/mnist_tutorial.md +++ b/docs_nnx/mnist_tutorial.md @@ -12,19 +12,19 @@ jupytext: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs/nnx/mnist_tutorial.ipynb) [![Open On GitHub](https://img.shields.io/badge/Open-on%20GitHub-blue?logo=GitHub)](https://github.com/google/flax/blob/main/docs/nnx/mnist_tutorial.ipynb) -# MNIST Tutorial +# MNIST tutorial -Welcome to Flax NNX! This tutorial will guide you through building and training a simple convolutional -neural network (CNN) on the MNIST dataset using the Flax NNX API. Flax NNX is a Python neural network library -built upon [JAX](https://github.com/google/jax) and currently offered as an experimental module within -[Flax](https://github.com/google/flax). +Welcome to Flax NNX! In this tutorial you will learn how to build and train a simple convolutional neural network (CNN) to classify handwritten digits on the MNIST dataset using the Flax NNX API. + +Flax NNX is a Python neural network library built upon [JAX](https://github.com/google/jax). If you have used the Flax Linen API before, check out [Why Flax NNX](https://flax.readthedocs.io/en/latest/why.html). You should have some knowledge of the main concepts of deep learning. + +Let’s get started! +++ ## 1. Install Flax -If `flax` is not installed in your environment, you can install it from PyPI, uncomment and run the -following cell: +If `flax` is not installed in your Python environment, use `pip` to install the package from PyPI - below, you’d need to uncomment and run the cell if you’re working from a Jupyter notebook/Google Colab: ```{code-cell} ipython3 :tags: [skip-execution] @@ -32,17 +32,15 @@ following cell: # !pip install flax ``` -## 2. Load the MNIST Dataset +## 2. 
Load the MNIST dataset -First, the MNIST dataset is loaded and prepared for training and testing using -Tensorflow Datasets. Image values are normalized, the data is shuffled and divided -into batches, and samples are prefetched to enhance performance. +First, you need to load the MNIST dataset and then prepare the training and testing sets via TensorFlow Datasets (TFDS). You normalize image values, shuffle the data and divide it into batches, and prefetch samples to enhance performance. ```{code-cell} ipython3 -import tensorflow_datasets as tfds # TFDS for MNIST -import tensorflow as tf # TensorFlow operations +import tensorflow_datasets as tfds # TFDS for MNIST. +import tensorflow as tf # TensorFlow operations. -tf.random.set_seed(0) # set random seed for reproducibility +tf.random.set_seed(0) # Set the random seed for reproducibility. train_steps = 1200 eval_every = 200 @@ -62,19 +60,19 @@ test_ds = test_ds.map( 'image': tf.cast(sample['image'], tf.float32) / 255, 'label': sample['label'], } -) # normalize test set +) # Normalize the test set. -# create shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from +# Create a shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from. train_ds = train_ds.repeat().shuffle(1024) -# group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency +# Group into batches of `batch_size` and skip incomplete batches, prefetch the next sample to improve latency. train_ds = train_ds.batch(batch_size, drop_remainder=True).take(train_steps).prefetch(1) -# group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency +# Group into batches of `batch_size` and skip incomplete batches, prefetch the next sample to improve latency. test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1) ``` -## 3. Define the Network with Flax NNX +## 3. 
Define the model with Flax NNX -Create a convolutional neural network with Flax NNX by subclassing `nnx.Module`. +Create a CNN for classification with Flax NNX by subclassing `nnx.Module`: ```{code-cell} ipython3 from flax import nnx # Flax NNX API @@ -98,13 +96,15 @@ class CNN(nnx.Module): x = self.linear2(x) return x +# Instantiate the model. model = CNN(rngs=nnx.Rngs(0)) +# Visualize it. nnx.display(model) ``` -### Run model +### Run the model -Let's put our model to the test! We'll perform a forward pass with arbitrary data and print the results. +Let's put the CNN model to the test! Here, you’ll perform a forward pass with arbitrary data and print the results. ```{code-cell} ipython3 :outputId: 2c580f41-bf5d-40ec-f1cf-ab7f319a84da @@ -115,9 +115,9 @@ y = model(jnp.ones((1, 28, 28, 1))) nnx.display(y) ``` -## 4. Create Optimizer and Metrics +## 4. Create the optimizer and define some metrics -In Flax NNX, we create an `Optimizer` object to manage the model's parameters and apply gradients during training. `Optimizer` receives the model's reference so it can update its parameters, and an `optax` optimizer to define the update rules. Additionally, we'll define a `MultiMetric` object to keep track of the `Accuracy` and the `Average` loss. +In Flax NNX, you need to create an `nnx.Optimizer` object to manage the model's parameters and apply gradients during training. `nnx.Optimizer` receives the model's reference, so that it can update its parameters, and an [Optax](https://optax.readthedocs.io/) optimizer to define the update rules. Additionally, you will define an `nnx.MultiMetric` object to keep track of the `Accuracy` and the `Average` loss. ```{code-cell} ipython3 import optax @@ -134,9 +134,13 @@ metrics = nnx.MultiMetric( nnx.display(optimizer) ``` -## 5. Define step functions +## 5. 
Define training step functions + +In this section, you will define a loss function using the cross entropy loss ([`optax.softmax_cross_entropy_with_integer_labels()`](https://optax.readthedocs.io/en/latest/api/losses.html#optax.softmax_cross_entropy_with_integer_labels)) that the CNN model will optimize over. + +In addition to the `loss`, during training and testing you will also get the `logits`, which will be used to calculate the accuracy metric. -We define a loss function using cross entropy loss (see more details in [`optax.softmax_cross_entropy_with_integer_labels()`](https://optax.readthedocs.io/en/latest/api/losses.html#optax.softmax_cross_entropy_with_integer_labels)) that our model will optimize over. In addition to the loss, the logits are also outputted since they will be used to calculate the accuracy metric during training and testing. During training, we'll use `nnx.value_and_grad` to compute the gradients and update the model's parameters using the optimizer. During both training and testing, the loss and logits are used to calculate the metrics. +During training - the `train_step` - you will use `nnx.value_and_grad` to compute the gradients and update the model's parameters using the `optimizer` you have already defined. And during both training and testing (the `eval_step`), the `loss` and `logits` will be used to calculate the metrics. 
```{code-cell} ipython3 def loss_fn(model: CNN, batch): @@ -160,20 +164,20 @@ def eval_step(model: CNN, metrics: nnx.MultiMetric, batch): metrics.update(loss=loss, logits=logits, labels=batch['label']) # inplace updates ``` -The [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) decorator traces the `train_step` function for just-in-time compilation with +In the code above, the [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) transformation decorator traces the `train_step` function for just-in-time compilation with [XLA](https://www.tensorflow.org/xla), optimizing performance on -hardware accelerators. `nnx.jit` is similar to [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit), +hardware accelerators, such as Google TPUs and GPUs. `nnx.jit` is similar to [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit), except it can transforms functions that contain Flax NNX objects as inputs and outputs. -**NOTE**: in the above code we performed serveral inplace updates to the model, optimizer, and metrics, and we did not explicitely return the state updates. This is because Flax NNX transforms respect reference semantics for Flax NNX objects, and will propagate the state updates of the objects passed as input arguments. This is a key feature of Flax NNX that allows for a more concise and readable code. +> **Note:** The code above performs several in-place updates to the model, the optimizer, and the metrics, but does not explicitly return the _state updates_. This is because Flax NNX transformations respect _reference semantics_ for Flax NNX objects, and will propagate the state updates of the objects passed as input arguments. This is a key feature of Flax NNX that allows for more concise and readable code. You can learn more in [Why Flax NNX](https://flax.readthedocs.io/en/latest/why.html). 
+++ -## 6. Train and Evaluate +## 6. Train and evaluate the model -Now we train a model using batches of data for 10 epochs, evaluate its performance -on the test set after each epoch, and log the training and testing metrics (loss and -accuracy) throughout the process. Typically this leads to a model with around 99% accuracy. +Now, you can train the CNN model using batches of data for 10 epochs, evaluate the model’s performance +on the test set after each epoch, and log the training and testing metrics (the loss and +the accuracy) during the process. Typically this leads to the model achieving around 99% accuracy. ```{code-cell} ipython3 :outputId: 258a2c76-2c8f-4a9e-d48b-dde57c342a87 @@ -187,25 +191,25 @@ metrics_history = { for step, batch in enumerate(train_ds.as_numpy_iterator()): # Run the optimization for one step and make a stateful update to the following: - # - the train state's model parameters - # - the optimizer state - # - the training loss and accuracy batch metrics + # - The train state's model parameters + # - The optimizer state + # - The training loss and accuracy batch metrics train_step(model, optimizer, metrics, batch) - if step > 0 and (step % eval_every == 0 or step == train_steps - 1): # one training epoch has passed - # Log training metrics - for metric, value in metrics.compute().items(): # compute metrics - metrics_history[f'train_{metric}'].append(value) # record metrics - metrics.reset() # reset metrics for test set + if step > 0 and (step % eval_every == 0 or step == train_steps - 1): # One training epoch has passed. + # Log the training metrics. + for metric, value in metrics.compute().items(): # Compute the metrics. + metrics_history[f'train_{metric}'].append(value) # Record the metrics. + metrics.reset() # Reset the metrics for the test set. - # Compute metrics on the test set after each training epoch + # Compute the metrics on the test set after each training epoch. 
for test_batch in test_ds.as_numpy_iterator(): eval_step(model, metrics, test_batch) - # Log test metrics + # Log the test metrics. for metric, value in metrics.compute().items(): metrics_history[f'test_{metric}'].append(value) - metrics.reset() # reset metrics for next training epoch + metrics.reset() # Reset the metrics for the next training epoch. print( f"[train] step: {step}, " @@ -219,9 +223,9 @@ for step, batch in enumerate(train_ds.as_numpy_iterator()): ) ``` -## 7. Visualize Metrics +## 7. Visualize the metrics -Use Matplotlib to create plots for loss and accuracy. +With Matplotlib, you can create plots for the loss and the accuracy: ```{code-cell} ipython3 :outputId: 431a2fcd-44fa-4202-f55a-906555f060ac @@ -240,9 +244,9 @@ ax2.legend() plt.show() ``` -## 10. Perform inference on test set +## 8. Perform inference on the test set -Define a jitted inference function, `pred_step`, to generate predictions on the test set using the learned model parameters. This will enable you to visualize test images alongside their predicted labels for a qualitative assessment of model performance. +Create a `jit`-compiled model inference function, `pred_step` (using `nnx.jit`), to generate predictions on the test set using the learned model parameters. This will enable you to visualize test images alongside their predicted labels for a qualitative assessment of model performance. ```{code-cell} ipython3 @nnx.jit @@ -264,4 +268,4 @@ for i, ax in enumerate(axs.flatten()): ax.axis('off') ``` -Congratulations! You made it to the end of the annotated MNIST example. +Congratulations! You have learned how to use Flax NNX to train a simple classification model end-to-end on the MNIST dataset.