diff --git a/docs_nnx/guides/quick_start.ipynb b/docs_nnx/guides/quick_start.ipynb deleted file mode 100644 index 1c8f297726..0000000000 --- a/docs_nnx/guides/quick_start.ipynb +++ /dev/null @@ -1,568 +0,0 @@ -{ - "cells": [ - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# NNX\n", - "\n", - "Welcome to NNX!\n", - "\n", - "NNX is an open source Python library for **N**eural **N**etwork in JA**X**. Its main feature is, much like PyTorch, allowing Python object semantics and reference sharing, which brings simplicity and familiarity, and easily crossing over into the functional world through a set of simple APIs.\n", - "\n", - "This tutorial demonstrates how to construct a simple convolutional neural network (CNN) using NNX and train the network for image classification on the MNIST dataset." - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Installation" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "! pip install -q nnx" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Load the MNIST dataset\n", - "We will use the `datasets` library to load MNIST and convert it to NumPy arrays." - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/cris/nnx/.venv/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n", - "Found cached dataset mnist (/home/cris/.cache/huggingface/datasets/mnist/mnist/1.0.0/9d494b7f466d6931c64fb39d58bb1249a4d85c9eb9865d9bc20960b999e2a332)\n", - "100%|██████████| 2/2 [00:00<00:00, 499.95it/s]\n" - ] - } - ], - "source": [ - "import datasets\n", - "import numpy as np\n", - "\n", - "dataset = datasets.load_dataset(\"mnist\")\n", - "X_train = np.array(np.stack(dataset[\"train\"][\"image\"]), dtype=np.uint8)[\n", - " ..., None\n", - "]\n", - "y_train = np.array(dataset[\"train\"][\"label\"], dtype=np.uint8)\n", - "X_test = np.array(np.stack(dataset[\"test\"][\"image\"]), dtype=np.uint8)[..., None]\n", - "y_test = np.array(dataset[\"test\"][\"label\"], dtype=np.uint8)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's visualize a few examples from the dataset using matplotlib:" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "import matplotlib.pyplot as plt\n", - "\n", - "# plot a 3x3 grid of MNIST digits\n", - "idxs = np.random.randint(0, len(X_train), size=(3, 3))\n", - "fig, axes = plt.subplots(3, 3, figsize=(3 * 2, 3 * 2))\n", - "\n", - "for i in range(3):\n", - " for j in range(3):\n", - " axes[i, j].imshow(X_train[idxs[i, j]], cmap=\"gray\")\n", - " axes[i, j].axis(\"off\")\n", - " axes[i, j].set_title(f\"Label: {y_train[idxs[i, j]]}\")\n", - "\n", - "plt.show()" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Defining the Model\n", - "\n", - "To create a convolutional neural network using NNX define a `nnx.Module` subclass. We define the model by subclassing `nnx.Module` and defining a `forward` method that returns the model output. Like in PyTorch, the `__init__` method instantiates all the modules that will be used in the model. The `__call__` in this case\n", - "will define the forward computation. " - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "No GPU/TPU found, falling back to CPU. 
(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" - ] - }, - { - "data": { - "text/plain": [ - "(1, 10)" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import jax\n", - "import jax.numpy as jnp\n", - "from flax import nnx\n", - "\n", - "\n", - "class CNN(nnx.Module):\n", - "\n", - " def __init__(self, *, rngs: nnx.Rngs):\n", - " self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)\n", - " self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)\n", - " self.linear1 = nnx.Linear(7 * 7 * 64, 256, rngs=rngs)\n", - " self.linear2 = nnx.Linear(256, 10, rngs=rngs)\n", - " self.num_calls = nnx.var(\"counts\", 0)\n", - "\n", - " def __call__(self, x: jax.Array) -> jax.Array:\n", - " self.num_calls += 1\n", - " x = self.conv1(x)\n", - " x = nnx.relu(x)\n", - " x = nnx.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n", - " x = self.conv2(x)\n", - " x = nnx.relu(x)\n", - " x = nnx.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n", - " x = x.reshape((x.shape[0], -1)) # flatten\n", - " x = self.linear1(x)\n", - " x = nnx.relu(x)\n", - " x = self.linear2(x)\n", - " return x\n", - "\n", - "\n", - "model = CNN(rngs=nnx.Rngs(0))\n", - "\n", - "y = model(X_train[:1])\n", - "y.shape" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "One notable difference with other frameworks is that `__init__`, by convention, accepts a `rngs: nnx.Rngs` keyword-only argument. This object is passed around to generate PRNG keys as random state is explicit in JAX.\n", - "\n", - "One of the nice things about NNX is that Modules contain their own state, are fully inspectable, and you can run them eagerly. 
For example, we can easily check out the kernel shape of the first `Conv` layer:" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(3, 3, 1, 32)" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "model.conv1.kernel.shape" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can also view the entire `State` of the model using the `.filter()` method. TODO: talk about collections." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "State({\n", - " 'conv1/bias': Variable(\n", - " collection='params',\n", - " value=(32,)\n", - " ),\n", - " 'conv1/kernel': Variable(\n", - " collection='params',\n", - " value=(3, 3, 1, 32)\n", - " ),\n", - " 'conv2/bias': Variable(\n", - " collection='params',\n", - " value=(64,)\n", - " ),\n", - " 'conv2/kernel': Variable(\n", - " collection='params',\n", - " value=(3, 3, 32, 64)\n", - " ),\n", - " 'linear1/bias': Variable(\n", - " collection='params',\n", - " value=(256,)\n", - " ),\n", - " 'linear1/kernel': Variable(\n", - " collection='params',\n", - " value=(3136, 256)\n", - " ),\n", - " 'linear2/bias': Variable(\n", - " collection='params',\n", - " value=(10,)\n", - " ),\n", - " 'linear2/kernel': Variable(\n", - " collection='params',\n", - " value=(256, 10)\n", - " )\n", - "})" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "jax.tree.map(jnp.shape, model.extract(nnx.Param))" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Training in eager mode\n", - "\n", - "For pedagogical purposes, we first train the model in eager mode. 
This will be useful for taking a look at some of NNX's features, it is more approachable for new users, and it is great for debugging, but it is not the recommended way to train models in JAX.\n", - "\n", - "Here we will run a simple `for` loop for just 10 iterations, at each step we will sample a batch of data, define a `loss_fn` to compute the loss, and use `nnx.value_and_grad` to compute the gradients of the loss with respect to the model parameters. Using the gradients we will update the parameters using stochastic gradient descent (SGD) via a simple `tree.map` operation. Finally, we will update the model's parameters using the `.update` method." - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Step 0: loss=58.7676\n", - "Step 1: loss=80.0420\n", - "Step 2: loss=108.3005\n", - "Step 3: loss=26.6188\n", - "Step 4: loss=10.7236\n", - "Step 5: loss=4.7499\n", - "Step 6: loss=3.9177\n", - "Step 7: loss=2.9419\n", - "Step 8: loss=2.4733\n", - "Step 9: loss=1.8060\n" - ] - } - ], - "source": [ - "import optax\n", - "\n", - "for step in range(10):\n", - " idxs = np.random.randint(0, len(X_train), size=32)\n", - " x = jnp.array(X_train[idxs])\n", - " y = jnp.array(y_train[idxs])\n", - "\n", - " def loss_fn(model: CNN):\n", - " logits = model(x)\n", - " return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()\n", - "\n", - " loss, grads = nnx.value_and_grad(loss_fn, wrt=\"params\")(model)\n", - " params = model.extract(\"params\")\n", - " params = jax.tree.map(lambda w, g: w - 0.001 * g, params, grads)\n", - "\n", - " model.update(params)\n", - " print(f\"Step {step}: loss={loss:.4f}\")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The loss is going down 🎉.\n", - "\n", - "### Training with the Functional API\n", - "\n", - "Now that we have a working model, let's see how to train it with `jax.jit` 
using NNX's Functional API. The `Module.split` method allows you to convert a Module into pytrees with functional semantics, this allows you to integrate with JAX's functional APIs like `jax.jit` and `jax.grad`.\n", - "\n", - "In this next example we will use the `.split` method to split the model into `params: State` and `graphdef: GraphDef` objects. We pass the `\"params\"` filter to check that the Module's state only contains `Variables` with the `params` collection. Having `params` and `graphdef` it's pretty easy to implement a jitted `train_step` much like you would in Flax or Haiku. `GraphDef` exposes an `apply` method which accepts some `State` and creates a function that runs the Module's `__call__` method. This function then returns the output of the Module along with the updated state." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "graphdef, params = model.split(\"params\")\n", - "\n", - "\n", - "@jax.jit\n", - "def train_step(params: nnx.State, x, y):\n", - " def loss_fn(params):\n", - " logits, _updates = graphdef.apply(params)(x)\n", - " return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()\n", - "\n", - " loss, grads = jax.value_and_grad(loss_fn)(params)\n", - " params = jax.tree.map(lambda w, g: w - 0.001 * g, params, grads)\n", - "\n", - " return loss, params" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Using `train_step` we can run a few more iterations and see that the loss is still going down, however, this time execution should be much faster." 
- ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Step 0: loss=1.4396\n", - "Step 1: loss=1.4127\n", - "Step 2: loss=1.8718\n", - "Step 3: loss=1.7080\n", - "Step 4: loss=1.7984\n", - "Step 5: loss=1.0350\n", - "Step 6: loss=1.2076\n", - "Step 7: loss=0.9081\n", - "Step 8: loss=0.8217\n", - "Step 9: loss=0.6687\n" - ] - } - ], - "source": [ - "for step in range(10):\n", - " idxs = np.random.randint(0, len(X_train), size=32)\n", - " x = jnp.array(X_train[idxs])\n", - " y = jnp.array(y_train[idxs])\n", - "\n", - " loss, params = train_step(params, x, y)\n", - " print(f\"Step {step}: loss={loss:.4f}\")\n", - "\n", - "model.update(params)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Realistic Training using TrainState\n", - "\n", - "For real training scenarios, we recommend using `TrainState` to manage the state of your training loop. `TrainState` manages the `params` of your network along with other types of state, and uses `optax` to update the parameters according to the gradients.\n", - "\n", - "Next, we will define a `train_step` function that accepts a `TrainState` and a batch of data, and returns a new `TrainState` with updated parameters. The `apply_gradients` method will return a new `state` with the updated parameters. Flax users should be familiar with this API. In this case we will also define an `eval_step` function that will be used to evaluate the model on the test set and return some metrics." 
- ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "state = nnx.TrainState(\n", - " graphdef,\n", - " params=params,\n", - " tx=optax.adam(0.001),\n", - ")\n", - "\n", - "\n", - "@jax.jit\n", - "def train_step(state: nnx.TrainState, x, y):\n", - " def loss_fn(params):\n", - " logits, _updates = state.apply_fn(params)(x)\n", - " return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()\n", - "\n", - " grads = jax.grad(loss_fn)(state.params)\n", - "\n", - " state = state.apply_gradients(grads=grads)\n", - "\n", - " return state\n", - "\n", - "\n", - "@jax.jit\n", - "def eval_step(state: nnx.TrainState, x, y):\n", - " logits, _updates = state.apply_fn(state.params)(x)\n", - " metrics = {\n", - " 'accuracy': jnp.mean(jnp.argmax(logits, axis=-1) == y),\n", - " 'loss': optax.softmax_cross_entropy_with_integer_labels(logits, y).mean(),\n", - " }\n", - " return metrics" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now let's create a simple training loop that runs for 1000 iterations and prints the metrics every 100 steps. At the end of training we will compute the final metrics." 
- ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Step 0: {'accuracy': Array(0.63119996, dtype=float32), 'loss': Array(1.1837534, dtype=float32)}\n", - "Step 100: {'accuracy': Array(0.9492, dtype=float32), 'loss': Array(0.16359854, dtype=float32)}\n", - "Step 200: {'accuracy': Array(0.9564, dtype=float32), 'loss': Array(0.14198248, dtype=float32)}\n", - "Step 300: {'accuracy': Array(0.96279997, dtype=float32), 'loss': Array(0.12757339, dtype=float32)}\n", - "Step 400: {'accuracy': Array(0.97169995, dtype=float32), 'loss': Array(0.09900841, dtype=float32)}\n", - "Step 500: {'accuracy': Array(0.96889997, dtype=float32), 'loss': Array(0.10143881, dtype=float32)}\n", - "Step 600: {'accuracy': Array(0.9745, dtype=float32), 'loss': Array(0.08513925, dtype=float32)}\n", - "Step 700: {'accuracy': Array(0.96379995, dtype=float32), 'loss': Array(0.11632324, dtype=float32)}\n", - "Step 800: {'accuracy': Array(0.97679996, dtype=float32), 'loss': Array(0.07204168, dtype=float32)}\n", - "Step 900: {'accuracy': Array(0.9765, dtype=float32), 'loss': Array(0.08413408, dtype=float32)}\n", - "Final metrics: {'accuracy': Array(0.9819, dtype=float32), 'loss': Array(0.05711861, dtype=float32)}\n" - ] - } - ], - "source": [ - "total_steps = 1000\n", - "eval_every = 100\n", - "\n", - "for step in range(total_steps):\n", - " if step % eval_every == 0:\n", - " metrics = eval_step(state, jnp.array(X_test), jnp.array(y_test))\n", - " print(f\"Step {step}: {metrics}\")\n", - "\n", - " idxs = np.random.randint(0, len(X_train), size=32)\n", - " x = jnp.array(X_train[idxs])\n", - " y = jnp.array(y_train[idxs])\n", - "\n", - " state = train_step(state, x, y)\n", - "\n", - "metrics = eval_step(state, jnp.array(X_test), jnp.array(y_test))\n", - "print(f\"Final metrics: {metrics}\")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 
Inference\n", - "\n", - "Finally, now that we have a trained model, let's use it to make some predictions. We will update the `model` object with the trained parameters and use it to make predictions on the test set." - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "model.update(state.params)\n", - "\n", - "# plot a 3x3 grid of MNIST digits\n", - "idxs = np.random.randint(0, len(X_test), size=(3, 3))\n", - "fig, axes = plt.subplots(3, 3, figsize=(3 * 2, 3 * 2))\n", - "\n", - "for i in range(3):\n", - " for j in range(3):\n", - " logits = model(jnp.array([X_test[idxs[i, j]]]))\n", - " axes[i, j].imshow(X_test[idxs[i, j]], cmap=\"gray\")\n", - " axes[i, j].axis(\"off\")\n", - " axes[i, j].set_title(f\"Prediction: {jnp.argmax(logits)}\")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Awesome! We hope you've enjoyed this tutorial and learned the basics of NNX." - ] - } - ], - "metadata": { - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.16" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/docs_nnx/quick_start.ipynb b/docs_nnx/quick_start.ipynb deleted file mode 100644 index 32530b9bed..0000000000 --- a/docs_nnx/quick_start.ipynb +++ /dev/null @@ -1,701 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "6eea21b3", - "metadata": {}, - "source": [ - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs/quick_start.ipynb)\n", - "[![Open On GitHub](https://img.shields.io/badge/Open-on%20GitHub-blue?logo=GitHub)](https://github.com/google/flax/blob/main/docs/quick_start.ipynb)\n", - "\n", - "# Quick start\n", - "\n", - "Welcome to Flax!\n", - "\n", - "Flax is an open source Python neural network library built on top of [JAX](https://github.com/google/jax). 
This tutorial demonstrates how to construct a simple convolutional neural\n", - "network (CNN) using the [Flax](https://flax.readthedocs.io) Linen API and train\n", - "the network for image classification on the MNIST dataset." - ] - }, - { - "cell_type": "markdown", - "id": "nwJWKIhdwxDo", - "metadata": {}, - "source": [ - "## 1. Install Flax" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bb81587e", - "metadata": { - "tags": [ - "skip-execution" - ] - }, - "outputs": [], - "source": [ - "!pip install -q flax>=0.7.5" - ] - }, - { - "cell_type": "markdown", - "id": "b529fbef", - "metadata": {}, - "source": [ - "## 2. Loading data\n", - "\n", - "Flax can use any\n", - "data-loading pipeline and this example demonstrates how to utilize TFDS. Define a function that loads and prepares the MNIST dataset and converts the\n", - "samples to floating-point numbers." - ] - }, - { - "cell_type": "code", - "execution_count": 48, - "id": "bRlrHqZVXZvk", - "metadata": {}, - "outputs": [], - "source": [ - "import tensorflow_datasets as tfds # TFDS for MNIST\n", - "import tensorflow as tf # TensorFlow operations\n", - "\n", - "def get_datasets(num_epochs, batch_size):\n", - " \"\"\"Load MNIST train and test datasets into memory.\"\"\"\n", - " train_ds = tfds.load('mnist', split='train')\n", - " test_ds = tfds.load('mnist', split='test')\n", - "\n", - " train_ds = train_ds.map(lambda sample: {'image': tf.cast(sample['image'],\n", - " tf.float32) / 255.,\n", - " 'label': sample['label']}) # normalize train set\n", - " test_ds = test_ds.map(lambda sample: {'image': tf.cast(sample['image'],\n", - " tf.float32) / 255.,\n", - " 'label': sample['label']}) # normalize test set\n", - "\n", - " train_ds = train_ds.repeat(num_epochs).shuffle(1024) # create shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from\n", - " train_ds = train_ds.batch(batch_size, drop_remainder=True).prefetch(1) # group into batches of batch_size and skip 
incomplete batch, prefetch the next sample to improve latency\n", - " test_ds = test_ds.shuffle(1024) # create shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from\n", - " test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1) # group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency\n", - "\n", - " return train_ds, test_ds" - ] - }, - { - "cell_type": "markdown", - "id": "7057395a", - "metadata": {}, - "source": [ - "## 3. Define network\n", - "\n", - "Create a convolutional neural network with the Linen API by subclassing\n", - "[Flax Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html).\n", - "Because the architecture in this example is relatively simple—you're just\n", - "stacking layers—you can define the inlined submodules directly within the\n", - "`__call__` method and wrap it with the\n", - "[`@compact`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/decorators.html#flax.linen.compact)\n", - "decorator. To learn more about the Flax Linen `@compact` decorator, refer to the [`setup` vs `compact`](https://flax.readthedocs.io/en/latest/guides/setup_or_nncompact.html) guide." 
- ] - }, - { - "cell_type": "code", - "execution_count": 49, - "id": "cbc079cd", - "metadata": {}, - "outputs": [], - "source": [ - "from flax import linen as nn # Linen API\n", - "\n", - "class CNN(nn.Module):\n", - " \"\"\"A simple CNN model.\"\"\"\n", - "\n", - " @nn.compact\n", - " def __call__(self, x):\n", - " x = nn.Conv(features=32, kernel_size=(3, 3))(x)\n", - " x = nn.relu(x)\n", - " x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n", - " x = nn.Conv(features=64, kernel_size=(3, 3))(x)\n", - " x = nn.relu(x)\n", - " x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n", - " x = x.reshape((x.shape[0], -1)) # flatten\n", - " x = nn.Dense(features=256)(x)\n", - " x = nn.relu(x)\n", - " x = nn.Dense(features=10)(x)\n", - " return x" - ] - }, - { - "cell_type": "markdown", - "id": "hy7iRu7_zlx-", - "metadata": {}, - "source": [ - "### View model layers\n", - "\n", - "Create an instance of the Flax Module and use the [`Module.tabulate`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.tabulate) method to visualize a table of the model layers by passing an RNG key and template image input." 
- ] - }, - { - "cell_type": "code", - "execution_count": 50, - "id": "lDHfog81zLQa", - "metadata": { - "outputId": "2c580f41-bf5d-40ec-f1cf-ab7f319a84da" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "\u001b[3m CNN Summary \u001b[0m\n", - "┏━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1mpath \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mmodule\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1minputs \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1moutputs \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mflops \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mvjp_flops\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mparams \u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━┩\n", - "│ │ CNN │ \u001b[2mfloat32\u001b[0m[1… │ \u001b[2mfloat32\u001b[0m[… │ 8708106 │ 26957556 │ │\n", - "├─────────┼────────┼────────────┼───────────┼─────────┼───────────┼────────────┤\n", - "│ Conv_0 │ Conv │ \u001b[2mfloat32\u001b[0m[1… │ \u001b[2mfloat32\u001b[0m[… │ 455424 │ 1341472 │ bias: │\n", - "│ │ │ │ │ │ │ \u001b[2mfloat32\u001b[0m[3… │\n", - "│ │ │ │ │ │ │ kernel: │\n", - "│ │ │ │ │ │ │ \u001b[2mfloat32\u001b[0m[3… │\n", - "│ │ │ │ │ │ │ │\n", - "│ │ │ │ │ │ │ \u001b[1m320 \u001b[0m\u001b[1;2m(1.3 \u001b[0m │\n", - "│ │ │ │ │ │ │ \u001b[1;2mKB)\u001b[0m │\n", - "├─────────┼────────┼────────────┼───────────┼─────────┼───────────┼────────────┤\n", - "│ Conv_1 │ Conv │ \u001b[2mfloat32\u001b[0m[1… │ \u001b[2mfloat32\u001b[0m[… │ 6566144 │ 19704320 │ bias: │\n", - "│ │ │ │ │ │ │ \u001b[2mfloat32\u001b[0m[6… │\n", - "│ │ │ │ │ │ │ kernel: │\n", - "│ │ │ │ │ │ │ \u001b[2mfloat32\u001b[0m[3… │\n", - "│ │ │ │ │ │ │ │\n", - "│ │ │ │ │ │ │ \u001b[1m18,496 \u001b[0m │\n", - "│ │ │ │ │ │ │ \u001b[1;2m(74.0 KB)\u001b[0m │\n", - 
"├─────────┼────────┼────────────┼───────────┼─────────┼───────────┼────────────┤\n", - "│ Dense_0 │ Dense │ \u001b[2mfloat32\u001b[0m[1… │ \u001b[2mfloat32\u001b[0m[… │ 1605888 │ 5620224 │ bias: │\n", - "│ │ │ │ │ │ │ \u001b[2mfloat32\u001b[0m[2… │\n", - "│ │ │ │ │ │ │ kernel: │\n", - "│ │ │ │ │ │ │ \u001b[2mfloat32\u001b[0m[3… │\n", - "│ │ │ │ │ │ │ │\n", - "│ │ │ │ │ │ │ \u001b[1m803,072 \u001b[0m │\n", - "│ │ │ │ │ │ │ \u001b[1;2m(3.2 MB)\u001b[0m │\n", - "├─────────┼────────┼────────────┼───────────┼─────────┼───────────┼────────────┤\n", - "│ Dense_1 │ Dense │ \u001b[2mfloat32\u001b[0m[1… │ \u001b[2mfloat32\u001b[0m[… │ 5130 │ 17940 │ bias: │\n", - "│ │ │ │ │ │ │ \u001b[2mfloat32\u001b[0m[1… │\n", - "│ │ │ │ │ │ │ kernel: │\n", - "│ │ │ │ │ │ │ \u001b[2mfloat32\u001b[0m[2… │\n", - "│ │ │ │ │ │ │ │\n", - "│ │ │ │ │ │ │ \u001b[1m2,570 \u001b[0m │\n", - "│ │ │ │ │ │ │ \u001b[1;2m(10.3 KB)\u001b[0m │\n", - "├─────────┼────────┼────────────┼───────────┼─────────┼───────────┼────────────┤\n", - "│\u001b[1m \u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m Total\u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m824,458 \u001b[0m\u001b[1m \u001b[0m│\n", - "│\u001b[1m \u001b[0m│\u001b[1m \u001b[0m│\u001b[1m \u001b[0m│\u001b[1m \u001b[0m│\u001b[1m \u001b[0m│\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1;2m(3.3 MB)\u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\n", - "└─────────┴────────┴────────────┴───────────┴─────────┴───────────┴────────────┘\n", - "\u001b[1m \u001b[0m\n", - "\u001b[1m Total Parameters: 824,458 \u001b[0m\u001b[1;2m(3.3 MB)\u001b[0m\u001b[1m \u001b[0m\n", - "\n", - "\n" - ] - } - ], - "source": [ - "import jax\n", - "import jax.numpy as jnp # JAX NumPy\n", - "\n", - "cnn = CNN()\n", - 
"print(cnn.tabulate(jax.random.key(0), jnp.ones((1, 28, 28, 1)),\n", - " compute_flops=True, compute_vjp_flops=True))" - ] - }, - { - "cell_type": "markdown", - "id": "4b5ac16e", - "metadata": {}, - "source": [ - "## 4. Create a `TrainState`\n", - "\n", - "A common pattern in Flax is to create a single dataclass that represents the\n", - "entire training state, including step number, parameters, and optimizer state.\n", - "\n", - "Because this is such a common pattern, Flax provides the class\n", - "[`flax.training.train_state.TrainState`](https://flax.readthedocs.io/en/latest/flax.training.html#train-state)\n", - "that serves most basic usecases." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "qXr7JDpIxGNZ", - "metadata": { - "outputId": "1249b7fb-6787-41eb-b34c-61d736300844" - }, - "outputs": [], - "source": [ - "!pip install -q clu" - ] - }, - { - "cell_type": "code", - "execution_count": 52, - "id": "CJDaJNijyOji", - "metadata": {}, - "outputs": [], - "source": [ - "from clu import metrics\n", - "from flax.training import train_state # Useful dataclass to keep train state\n", - "from flax import struct # Flax dataclasses\n", - "import optax # Common loss functions and optimizers" - ] - }, - { - "cell_type": "markdown", - "id": "8b86b5f1", - "metadata": {}, - "source": [ - "We will be using the `clu` library for computing metrics. For more information on `clu`, refer to the [repo](https://github.com/google/CommonLoopUtils) and [notebook](https://colab.research.google.com/github/google/CommonLoopUtils/blob/master/clu_synopsis.ipynb#scrollTo=ueom-uBWLbeQ)." 
- ] - }, - { - "cell_type": "code", - "execution_count": 53, - "id": "7W0qf7FC9uG5", - "metadata": {}, - "outputs": [], - "source": [ - "@struct.dataclass\n", - "class Metrics(metrics.Collection):\n", - " accuracy: metrics.Accuracy\n", - " loss: metrics.Average.from_output('loss')" - ] - }, - { - "cell_type": "markdown", - "id": "f3ce5e4c", - "metadata": {}, - "source": [ - "You can then subclass `train_state.TrainState` so that it also contains metrics. This has the advantage that we only need\n", - "to pass around a single argument to functions like `train_step()` (see below) to calculate the loss, update the parameters and compute the metrics all at once." - ] - }, - { - "cell_type": "code", - "execution_count": 54, - "id": "e0102447", - "metadata": {}, - "outputs": [], - "source": [ - "class TrainState(train_state.TrainState):\n", - " metrics: Metrics\n", - "\n", - "def create_train_state(module, rng, learning_rate, momentum):\n", - " \"\"\"Creates an initial `TrainState`.\"\"\"\n", - " params = module.init(rng, jnp.ones([1, 28, 28, 1]))['params'] # initialize parameters by passing a template image\n", - " tx = optax.sgd(learning_rate, momentum)\n", - " return TrainState.create(\n", - " apply_fn=module.apply, params=params, tx=tx,\n", - " metrics=Metrics.empty())" - ] - }, - { - "cell_type": "markdown", - "id": "a15de484", - "metadata": {}, - "source": [ - "## 5. 
Training step\n", - "\n", - "A function that:\n", - "\n", - "- Evaluates the neural network given the parameters and a batch of input images\n", - " with [`TrainState.apply_fn`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.train_state.TrainState) (which contains the [`Module.apply`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.apply)\n", - " method (forward pass)).\n", - "- Computes the cross entropy loss, using the predefined [`optax.softmax_cross_entropy_with_integer_labels()`](https://optax.readthedocs.io/en/latest/api.html#optax.softmax_cross_entropy_with_integer_labels). Note that this function expects integer labels, so there is no need to convert labels to onehot encoding.\n", - "- Evaluates the gradient of the loss function using\n", - " [`jax.grad`](https://jax.readthedocs.io/en/latest/jax.html#jax.grad).\n", - "- Applies a\n", - " [pytree](https://jax.readthedocs.io/en/latest/pytrees.html#pytrees-and-jax-functions)\n", - " of gradients to the optimizer to update the model's parameters.\n", - "\n", - "Use JAX's [@jit](https://jax.readthedocs.io/en/latest/jax.html#jax.jit)\n", - "decorator to trace the entire `train_step` function and just-in-time compile\n", - "it with [XLA](https://www.tensorflow.org/xla) into fused device operations\n", - "that run faster and more efficiently on hardware accelerators." 
- ] - }, - { - "cell_type": "code", - "execution_count": 55, - "id": "9b0af486", - "metadata": {}, - "outputs": [], - "source": [ - "@jax.jit\n", - "def train_step(state, batch):\n", - " \"\"\"Train for a single step.\"\"\"\n", - " def loss_fn(params):\n", - " logits = state.apply_fn({'params': params}, batch['image'])\n", - " loss = optax.softmax_cross_entropy_with_integer_labels(\n", - " logits=logits, labels=batch['label']).mean()\n", - " return loss\n", - " grad_fn = jax.grad(loss_fn)\n", - " grads = grad_fn(state.params)\n", - " state = state.apply_gradients(grads=grads)\n", - " return state" - ] - }, - { - "cell_type": "markdown", - "id": "0ff5145f", - "metadata": {}, - "source": [ - "## 6. Metric computation\n", - "\n", - "Create a separate function for loss and accuracy metrics. Loss is calculated using the `optax.softmax_cross_entropy_with_integer_labels` function, while accuracy is calculated using `clu.metrics`." - ] - }, - { - "cell_type": "code", - "execution_count": 56, - "id": "961bf70b", - "metadata": {}, - "outputs": [], - "source": [ - "@jax.jit\n", - "def compute_metrics(*, state, batch):\n", - " logits = state.apply_fn({'params': state.params}, batch['image'])\n", - " loss = optax.softmax_cross_entropy_with_integer_labels(\n", - " logits=logits, labels=batch['label']).mean()\n", - " metric_updates = state.metrics.single_from_model_output(\n", - " logits=logits, labels=batch['label'], loss=loss)\n", - " metrics = state.metrics.merge(metric_updates)\n", - " state = state.replace(metrics=metrics)\n", - " return state" - ] - }, - { - "cell_type": "markdown", - "id": "497241c3", - "metadata": {}, - "source": [ - "## 7. 
Download data" - ] - }, - { - "cell_type": "code", - "execution_count": 57, - "id": "bff5393e", - "metadata": {}, - "outputs": [], - "source": [ - "num_epochs = 10\n", - "batch_size = 32\n", - "\n", - "train_ds, test_ds = get_datasets(num_epochs, batch_size)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "809ae1a0", - "metadata": {}, - "source": [ - "## 8. Seed randomness\n", - "\n", - "- Set the TF random seed to ensure dataset shuffling (with `tf.data.Dataset.shuffle`) is reproducible.\n", - "- Get one\n", - " [PRNGKey](https://jax.readthedocs.io/en/latest/_autosummary/jax.random.PRNGKey.html#jax.random.PRNGKey)\n", - " and use it for parameter initialization. (Learn\n", - " more about\n", - " [JAX PRNG design](https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html)\n", - " and [PRNG chains](https://flax.readthedocs.io/en/latest/philosophy.html#how-are-parameters-represented-and-how-do-we-handle-general-differentiable-algorithms-that-update-stateful-variables).)" - ] - }, - { - "cell_type": "code", - "execution_count": 58, - "id": "xC4MFyBsfT-U", - "metadata": {}, - "outputs": [], - "source": [ - "tf.random.set_seed(0)" - ] - }, - { - "cell_type": "code", - "execution_count": 59, - "id": "e4f6f4d3", - "metadata": {}, - "outputs": [], - "source": [ - "init_rng = jax.random.key(0)" - ] - }, - { - "cell_type": "markdown", - "id": "80fbb60b", - "metadata": {}, - "source": [ - "## 9. Initialize the `TrainState`\n", - "\n", - "Remember that the function `create_train_state` initializes the model parameters, optimizer and metrics\n", - "and puts them into the training state dataclass that is returned." 
- ] - }, - { - "cell_type": "code", - "execution_count": 60, - "id": "445fcab0", - "metadata": {}, - "outputs": [], - "source": [ - "learning_rate = 0.01\n", - "momentum = 0.9" - ] - }, - { - "cell_type": "code", - "execution_count": 61, - "id": "5221eafd", - "metadata": {}, - "outputs": [], - "source": [ - "state = create_train_state(cnn, init_rng, learning_rate, momentum)\n", - "del init_rng # Must not be used anymore." - ] - }, - { - "cell_type": "markdown", - "id": "b1c00230", - "metadata": {}, - "source": [ - "## 10. Train and evaluate\n", - "\n", - "Create a \"shuffled\" dataset by:\n", - "- Repeating the dataset equal to the number of training epochs\n", - "- Allocating a buffer of size 1024 (containing the first 1024 samples in the dataset) of which to randomly sample batches from\n", - " - Everytime a sample is randomly drawn from the buffer, the next sample in the dataset is loaded into the buffer\n", - "\n", - "Define a training loop that:\n", - "- Randomly samples batches from the dataset.\n", - "- Runs an optimization step for each training batch.\n", - "- Computes the mean training metrics across each batch in an epoch.\n", - "- Computes the metrics for the test set using the updated parameters.\n", - "- Records the train and test metrics for visualization.\n", - "\n", - "Once the training and testing is done after 10 epochs, the output should show that your model was able to achieve approximately 99% accuracy." 
- ] - }, - { - "cell_type": "code", - "execution_count": 62, - "id": "74295360", - "metadata": {}, - "outputs": [], - "source": [ - "# since train_ds is replicated num_epochs times in get_datasets(), we divide by num_epochs\n", - "num_steps_per_epoch = train_ds.cardinality().numpy() // num_epochs" - ] - }, - { - "cell_type": "code", - "execution_count": 63, - "id": "cRtnMZuQFlKl", - "metadata": {}, - "outputs": [], - "source": [ - "metrics_history = {'train_loss': [],\n", - " 'train_accuracy': [],\n", - " 'test_loss': [],\n", - " 'test_accuracy': []}" - ] - }, - { - "cell_type": "code", - "execution_count": 64, - "id": "2c40ce90", - "metadata": { - "outputId": "258a2c76-2c8f-4a9e-d48b-dde57c342a87" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "train epoch: 1, loss: 0.20290373265743256, accuracy: 93.87000274658203\n", - "test epoch: 1, loss: 0.07591685652732849, accuracy: 97.60617065429688\n", - "train epoch: 2, loss: 0.05760224163532257, accuracy: 98.28500366210938\n", - "test epoch: 2, loss: 0.050395529717206955, accuracy: 98.3974380493164\n", - "train epoch: 3, loss: 0.03897436335682869, accuracy: 98.83000183105469\n", - "test epoch: 3, loss: 0.04574578255414963, accuracy: 98.54767608642578\n", - "train epoch: 4, loss: 0.028721099719405174, accuracy: 99.15166473388672\n", - "test epoch: 4, loss: 0.035722777247428894, accuracy: 98.91827392578125\n", - "train epoch: 5, loss: 0.021948494017124176, accuracy: 99.37999725341797\n", - "test epoch: 5, loss: 0.035723842680454254, accuracy: 98.87820434570312\n", - "train epoch: 6, loss: 0.01705147698521614, accuracy: 99.54833221435547\n", - "test epoch: 6, loss: 0.03456473350524902, accuracy: 98.96835327148438\n", - "train epoch: 7, loss: 0.014007646590471268, accuracy: 99.6116714477539\n", - "test epoch: 7, loss: 0.04089202359318733, accuracy: 98.7880630493164\n", - "train epoch: 8, loss: 0.011265480890870094, accuracy: 99.73333740234375\n", - "test epoch: 8, loss: 
0.03337760642170906, accuracy: 98.93830108642578\n", - "train epoch: 9, loss: 0.00918484665453434, accuracy: 99.78334045410156\n", - "test epoch: 9, loss: 0.034478139132261276, accuracy: 98.96835327148438\n", - "train epoch: 10, loss: 0.007260234095156193, accuracy: 99.84166717529297\n", - "test epoch: 10, loss: 0.032822880893945694, accuracy: 99.07852172851562\n" - ] - } - ], - "source": [ - "for step,batch in enumerate(train_ds.as_numpy_iterator()):\n", - "\n", - " # Run optimization steps over training batches and compute batch metrics\n", - " state = train_step(state, batch) # get updated train state (which contains the updated parameters)\n", - " state = compute_metrics(state=state, batch=batch) # aggregate batch metrics\n", - "\n", - " if (step+1) % num_steps_per_epoch == 0: # one training epoch has passed\n", - " for metric,value in state.metrics.compute().items(): # compute metrics\n", - " metrics_history[f'train_{metric}'].append(value) # record metrics\n", - " state = state.replace(metrics=state.metrics.empty()) # reset train_metrics for next training epoch\n", - "\n", - " # Compute metrics on the test set after each training epoch\n", - " test_state = state\n", - " for test_batch in test_ds.as_numpy_iterator():\n", - " test_state = compute_metrics(state=test_state, batch=test_batch)\n", - "\n", - " for metric,value in test_state.metrics.compute().items():\n", - " metrics_history[f'test_{metric}'].append(value)\n", - "\n", - " print(f\"train epoch: {(step+1) // num_steps_per_epoch}, \"\n", - " f\"loss: {metrics_history['train_loss'][-1]}, \"\n", - " f\"accuracy: {metrics_history['train_accuracy'][-1] * 100}\")\n", - " print(f\"test epoch: {(step+1) // num_steps_per_epoch}, \"\n", - " f\"loss: {metrics_history['test_loss'][-1]}, \"\n", - " f\"accuracy: {metrics_history['test_accuracy'][-1] * 100}\")" - ] - }, - { - "cell_type": "markdown", - "id": "gfsecJzvzgCT", - "metadata": {}, - "source": [ - "## 11. 
Visualize metrics" - ] - }, - { - "cell_type": "code", - "execution_count": 65, - "id": "Zs5atiqIG9Kz", - "metadata": { - "outputId": "431a2fcd-44fa-4202-f55a-906555f060ac" - }, - "outputs": [ - { - "data": { - "image/png": "", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "import matplotlib.pyplot as plt # Visualization\n", - "\n", - "# Plot loss and accuracy in subplots\n", - "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))\n", - "ax1.set_title('Loss')\n", - "ax2.set_title('Accuracy')\n", - "for dataset in ('train','test'):\n", - " ax1.plot(metrics_history[f'{dataset}_loss'], label=f'{dataset}_loss')\n", - " ax2.plot(metrics_history[f'{dataset}_accuracy'], label=f'{dataset}_accuracy')\n", - "ax1.legend()\n", - "ax2.legend()\n", - "plt.show()\n", - "plt.clf()" - ] - }, - { - "cell_type": "markdown", - "id": "qQbKS0tV3sZ1", - "metadata": {}, - "source": [ - "## 12. Perform inference on test set\n", - "\n", - "Define a jitted inference function `pred_step`. Use the learned parameters to do model inference on the test set and visualize the images and their corresponding predicted labels." - ] - }, - { - "cell_type": "code", - "execution_count": 66, - "id": "DFwxgBQf44ks", - "metadata": {}, - "outputs": [], - "source": [ - "@jax.jit\n", - "def pred_step(state, batch):\n", - " logits = state.apply_fn({'params': state.params}, test_batch['image'])\n", - " return logits.argmax(axis=1)\n", - "\n", - "test_batch = test_ds.as_numpy_iterator().next()\n", - "pred = pred_step(state, test_batch)" - ] - }, - { - "cell_type": "code", - "execution_count": 67, - "id": "5d5nF3u44JFI", - "metadata": { - "outputId": "1db5a01c-9d70-4f7d-8c0d-0a3ad8252d3e" - }, - "outputs": [ - { - "data": { - "image/png": "", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "fig, axs = plt.subplots(5, 5, figsize=(12, 12))\n", - "for i, ax in enumerate(axs.flatten()):\n", - " ax.imshow(test_batch['image'][i, ..., 0], cmap='gray')\n", - " ax.set_title(f\"label={pred[i]}\")\n", - " ax.axis('off')" - ] - }, - { - "cell_type": "markdown", - "id": "edb528b6", - "metadata": {}, - "source": [ - "Congratulations! You made it to the end of the annotated MNIST example. You can revisit\n", - "the same example, but structured differently as a couple of Python modules, test\n", - "modules, config files, another Colab, and documentation in Flax's Git repo:\n", - "\n", - "[https://github.com/google/flax/tree/main/examples/mnist](https://github.com/google/flax/tree/main/examples/mnist)" - ] - } - ], - "metadata": { - "jupytext": { - "formats": "ipynb,md:myst", - "main_language": "python" - }, - "language_info": { - "name": "python", - "version": "3.9.6" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/docs_nnx/quick_start.md b/docs_nnx/quick_start.md deleted file mode 100644 index ac8a9fb860..0000000000 --- a/docs_nnx/quick_start.md +++ /dev/null @@ -1,355 +0,0 @@ ---- -jupytext: - formats: ipynb,md:myst - main_language: python - text_representation: - extension: .md - format_name: myst - format_version: 0.13 - jupytext_version: 1.13.8 ---- - -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs/quick_start.ipynb) -[![Open On GitHub](https://img.shields.io/badge/Open-on%20GitHub-blue?logo=GitHub)](https://github.com/google/flax/blob/main/docs/quick_start.ipynb) - -# Quick start - -Welcome to Flax! - -Flax is an open source Python neural network library built on top of [JAX](https://github.com/google/jax). 
This tutorial demonstrates how to construct a simple convolutional neural -network (CNN) using the [Flax](https://flax.readthedocs.io) Linen API and train -the network for image classification on the MNIST dataset. - -+++ - -## 1. Install Flax - -```{code-cell} -:tags: [skip-execution] - -!pip install -q flax>=0.7.5 -``` - -## 2. Loading data - -Flax can use any -data-loading pipeline and this example demonstrates how to utilize TFDS. Define a function that loads and prepares the MNIST dataset and converts the -samples to floating-point numbers. - -```{code-cell} -import tensorflow_datasets as tfds # TFDS for MNIST -import tensorflow as tf # TensorFlow operations - -def get_datasets(num_epochs, batch_size): - """Load MNIST train and test datasets into memory.""" - train_ds = tfds.load('mnist', split='train') - test_ds = tfds.load('mnist', split='test') - - train_ds = train_ds.map(lambda sample: {'image': tf.cast(sample['image'], - tf.float32) / 255., - 'label': sample['label']}) # normalize train set - test_ds = test_ds.map(lambda sample: {'image': tf.cast(sample['image'], - tf.float32) / 255., - 'label': sample['label']}) # normalize test set - - train_ds = train_ds.repeat(num_epochs).shuffle(1024) # create shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from - train_ds = train_ds.batch(batch_size, drop_remainder=True).prefetch(1) # group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency - test_ds = test_ds.shuffle(1024) # create shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from - test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1) # group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency - - return train_ds, test_ds -``` - -## 3. 
Define network - -Create a convolutional neural network with the Linen API by subclassing -[Flax Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html). -Because the architecture in this example is relatively simple—you're just -stacking layers—you can define the inlined submodules directly within the -`__call__` method and wrap it with the -[`@compact`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/decorators.html#flax.linen.compact) -decorator. To learn more about the Flax Linen `@compact` decorator, refer to the [`setup` vs `compact`](https://flax.readthedocs.io/en/latest/guides/setup_or_nncompact.html) guide. - -```{code-cell} -from flax import linen as nn # Linen API - -class CNN(nn.Module): - """A simple CNN model.""" - - @nn.compact - def __call__(self, x): - x = nn.Conv(features=32, kernel_size=(3, 3))(x) - x = nn.relu(x) - x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) - x = nn.Conv(features=64, kernel_size=(3, 3))(x) - x = nn.relu(x) - x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) - x = x.reshape((x.shape[0], -1)) # flatten - x = nn.Dense(features=256)(x) - x = nn.relu(x) - x = nn.Dense(features=10)(x) - return x -``` - -### View model layers - -Create an instance of the Flax Module and use the [`Module.tabulate`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.tabulate) method to visualize a table of the model layers by passing an RNG key and template image input. - -```{code-cell} -:outputId: 2c580f41-bf5d-40ec-f1cf-ab7f319a84da - -import jax -import jax.numpy as jnp # JAX NumPy - -cnn = CNN() -print(cnn.tabulate(jax.random.key(0), jnp.ones((1, 28, 28, 1)), - compute_flops=True, compute_vjp_flops=True)) -``` - -## 4. Create a `TrainState` - -A common pattern in Flax is to create a single dataclass that represents the -entire training state, including step number, parameters, and optimizer state. 
- -Because this is such a common pattern, Flax provides the class -[`flax.training.train_state.TrainState`](https://flax.readthedocs.io/en/latest/flax.training.html#train-state) -that serves most basic usecases. - -```{code-cell} -:outputId: 1249b7fb-6787-41eb-b34c-61d736300844 - -!pip install -q clu -``` - -```{code-cell} -from clu import metrics -from flax.training import train_state # Useful dataclass to keep train state -from flax import struct # Flax dataclasses -import optax # Common loss functions and optimizers -``` - -We will be using the `clu` library for computing metrics. For more information on `clu`, refer to the [repo](https://github.com/google/CommonLoopUtils) and [notebook](https://colab.research.google.com/github/google/CommonLoopUtils/blob/master/clu_synopsis.ipynb#scrollTo=ueom-uBWLbeQ). - -```{code-cell} -@struct.dataclass -class Metrics(metrics.Collection): - accuracy: metrics.Accuracy - loss: metrics.Average.from_output('loss') -``` - -You can then subclass `train_state.TrainState` so that it also contains metrics. This has the advantage that we only need -to pass around a single argument to functions like `train_step()` (see below) to calculate the loss, update the parameters and compute the metrics all at once. - -```{code-cell} -class TrainState(train_state.TrainState): - metrics: Metrics - -def create_train_state(module, rng, learning_rate, momentum): - """Creates an initial `TrainState`.""" - params = module.init(rng, jnp.ones([1, 28, 28, 1]))['params'] # initialize parameters by passing a template image - tx = optax.sgd(learning_rate, momentum) - return TrainState.create( - apply_fn=module.apply, params=params, tx=tx, - metrics=Metrics.empty()) -``` - -## 5. 
Training step - -A function that: - -- Evaluates the neural network given the parameters and a batch of input images - with [`TrainState.apply_fn`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.train_state.TrainState) (which contains the [`Module.apply`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.apply) - method (forward pass)). -- Computes the cross entropy loss, using the predefined [`optax.softmax_cross_entropy_with_integer_labels()`](https://optax.readthedocs.io/en/latest/api.html#optax.softmax_cross_entropy_with_integer_labels). Note that this function expects integer labels, so there is no need to convert labels to onehot encoding. -- Evaluates the gradient of the loss function using - [`jax.grad`](https://jax.readthedocs.io/en/latest/jax.html#jax.grad). -- Applies a - [pytree](https://jax.readthedocs.io/en/latest/pytrees.html#pytrees-and-jax-functions) - of gradients to the optimizer to update the model's parameters. - -Use JAX's [@jit](https://jax.readthedocs.io/en/latest/jax.html#jax.jit) -decorator to trace the entire `train_step` function and just-in-time compile -it with [XLA](https://www.tensorflow.org/xla) into fused device operations -that run faster and more efficiently on hardware accelerators. - -```{code-cell} -@jax.jit -def train_step(state, batch): - """Train for a single step.""" - def loss_fn(params): - logits = state.apply_fn({'params': params}, batch['image']) - loss = optax.softmax_cross_entropy_with_integer_labels( - logits=logits, labels=batch['label']).mean() - return loss - grad_fn = jax.grad(loss_fn) - grads = grad_fn(state.params) - state = state.apply_gradients(grads=grads) - return state -``` - -## 6. Metric computation - -Create a separate function for loss and accuracy metrics. Loss is calculated using the `optax.softmax_cross_entropy_with_integer_labels` function, while accuracy is calculated using `clu.metrics`. 
- -```{code-cell} -@jax.jit -def compute_metrics(*, state, batch): - logits = state.apply_fn({'params': state.params}, batch['image']) - loss = optax.softmax_cross_entropy_with_integer_labels( - logits=logits, labels=batch['label']).mean() - metric_updates = state.metrics.single_from_model_output( - logits=logits, labels=batch['label'], loss=loss) - metrics = state.metrics.merge(metric_updates) - state = state.replace(metrics=metrics) - return state -``` - -## 7. Download data - -```{code-cell} -num_epochs = 10 -batch_size = 32 - -train_ds, test_ds = get_datasets(num_epochs, batch_size) -``` - -## 8. Seed randomness - -- Set the TF random seed to ensure dataset shuffling (with `tf.data.Dataset.shuffle`) is reproducible. -- Get one - [PRNGKey](https://jax.readthedocs.io/en/latest/_autosummary/jax.random.PRNGKey.html#jax.random.PRNGKey) - and use it for parameter initialization. (Learn - more about - [JAX PRNG design](https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html) - and [PRNG chains](https://flax.readthedocs.io/en/latest/philosophy.html#how-are-parameters-represented-and-how-do-we-handle-general-differentiable-algorithms-that-update-stateful-variables).) - -```{code-cell} -tf.random.set_seed(0) -``` - -```{code-cell} -init_rng = jax.random.key(0) -``` - -## 9. Initialize the `TrainState` - -Remember that the function `create_train_state` initializes the model parameters, optimizer and metrics -and puts them into the training state dataclass that is returned. - -```{code-cell} -learning_rate = 0.01 -momentum = 0.9 -``` - -```{code-cell} -state = create_train_state(cnn, init_rng, learning_rate, momentum) -del init_rng # Must not be used anymore. -``` - -## 10. 
Train and evaluate - -Create a "shuffled" dataset by: -- Repeating the dataset equal to the number of training epochs -- Allocating a buffer of size 1024 (containing the first 1024 samples in the dataset) of which to randomly sample batches from - - Everytime a sample is randomly drawn from the buffer, the next sample in the dataset is loaded into the buffer - -Define a training loop that: -- Randomly samples batches from the dataset. -- Runs an optimization step for each training batch. -- Computes the mean training metrics across each batch in an epoch. -- Computes the metrics for the test set using the updated parameters. -- Records the train and test metrics for visualization. - -Once the training and testing is done after 10 epochs, the output should show that your model was able to achieve approximately 99% accuracy. - -```{code-cell} -# since train_ds is replicated num_epochs times in get_datasets(), we divide by num_epochs -num_steps_per_epoch = train_ds.cardinality().numpy() // num_epochs -``` - -```{code-cell} -metrics_history = {'train_loss': [], - 'train_accuracy': [], - 'test_loss': [], - 'test_accuracy': []} -``` - -```{code-cell} -:outputId: 258a2c76-2c8f-4a9e-d48b-dde57c342a87 - -for step,batch in enumerate(train_ds.as_numpy_iterator()): - - # Run optimization steps over training batches and compute batch metrics - state = train_step(state, batch) # get updated train state (which contains the updated parameters) - state = compute_metrics(state=state, batch=batch) # aggregate batch metrics - - if (step+1) % num_steps_per_epoch == 0: # one training epoch has passed - for metric,value in state.metrics.compute().items(): # compute metrics - metrics_history[f'train_{metric}'].append(value) # record metrics - state = state.replace(metrics=state.metrics.empty()) # reset train_metrics for next training epoch - - # Compute metrics on the test set after each training epoch - test_state = state - for test_batch in test_ds.as_numpy_iterator(): - test_state = 
compute_metrics(state=test_state, batch=test_batch) - - for metric,value in test_state.metrics.compute().items(): - metrics_history[f'test_{metric}'].append(value) - - print(f"train epoch: {(step+1) // num_steps_per_epoch}, " - f"loss: {metrics_history['train_loss'][-1]}, " - f"accuracy: {metrics_history['train_accuracy'][-1] * 100}") - print(f"test epoch: {(step+1) // num_steps_per_epoch}, " - f"loss: {metrics_history['test_loss'][-1]}, " - f"accuracy: {metrics_history['test_accuracy'][-1] * 100}") -``` - -## 11. Visualize metrics - -```{code-cell} -:outputId: 431a2fcd-44fa-4202-f55a-906555f060ac - -import matplotlib.pyplot as plt # Visualization - -# Plot loss and accuracy in subplots -fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5)) -ax1.set_title('Loss') -ax2.set_title('Accuracy') -for dataset in ('train','test'): - ax1.plot(metrics_history[f'{dataset}_loss'], label=f'{dataset}_loss') - ax2.plot(metrics_history[f'{dataset}_accuracy'], label=f'{dataset}_accuracy') -ax1.legend() -ax2.legend() -plt.show() -plt.clf() -``` - -## 12. Perform inference on test set - -Define a jitted inference function `pred_step`. Use the learned parameters to do model inference on the test set and visualize the images and their corresponding predicted labels. - -```{code-cell} -@jax.jit -def pred_step(state, batch): - logits = state.apply_fn({'params': state.params}, test_batch['image']) - return logits.argmax(axis=1) - -test_batch = test_ds.as_numpy_iterator().next() -pred = pred_step(state, test_batch) -``` - -```{code-cell} -:outputId: 1db5a01c-9d70-4f7d-8c0d-0a3ad8252d3e - -fig, axs = plt.subplots(5, 5, figsize=(12, 12)) -for i, ax in enumerate(axs.flatten()): - ax.imshow(test_batch['image'][i, ..., 0], cmap='gray') - ax.set_title(f"label={pred[i]}") - ax.axis('off') -``` - -Congratulations! You made it to the end of the annotated MNIST example. 
You can revisit -the same example, but structured differently as a couple of Python modules, test -modules, config files, another Colab, and documentation in Flax's Git repo: - -[https://github.com/google/flax/tree/main/examples/mnist](https://github.com/google/flax/tree/main/examples/mnist) diff --git a/examples/nnx_toy_examples/10_fsdp_and_optimizer.py b/examples/nnx_toy_examples/10_fsdp_and_optimizer.py new file mode 100644 index 0000000000..f5cf8002b5 --- /dev/null +++ b/examples/nnx_toy_examples/10_fsdp_and_optimizer.py @@ -0,0 +1,171 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import dataclasses +import os +os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8' + +from matplotlib import pyplot as plt +from jax.experimental import mesh_utils +from jax.sharding import Mesh, PartitionSpec as P, NamedSharding +import jax +import jax.numpy as jnp +import numpy as np +from flax import nnx +import typing as tp + +mesh = jax.sharding.Mesh( + mesh_utils.create_device_mesh((2, 4)), + ('data', 'model'), +) + + +def named_sharding(*names: str | None) -> NamedSharding: + return NamedSharding(mesh, P(*names)) + + +@dataclasses.dataclass(unsafe_hash=True) +class MeshRules: + embed: str | None = None + mlp: str | None = None + data: str | None = None + + def __call__(self, *keys: str) -> tuple[str, ...]: + return tuple(getattr(self, key) for key in keys) + + +mesh_rules = MeshRules( + embed=None, + mlp='model', + data='data', +) + + +class MLP(nnx.Module): + def __init__(self, din, dmid, dout, rngs: nnx.Rngs): + self.w1 = nnx.Param( + nnx.initializers.lecun_normal()(rngs.params(), (din, dmid)), + sharding=mesh_rules('embed', 'mlp'), + ) + self.b1 = nnx.Param( + jnp.zeros((dmid,)), + sharding=mesh_rules('mlp'), + ) + self.w2 = nnx.Param( + nnx.initializers.lecun_normal()(rngs.params(), (dmid, dout)), + sharding=mesh_rules('embed', 'mlp'), + ) + + def __call__(self, x: jax.Array): + return nnx.relu(x @ self.w1 + self.b1) @ self.w2 + + +class SGDState(nnx.Variable): + pass + + +class SGD(nnx.Object): + def __init__(self, params: nnx.State, lr, decay=0.9): + def init_optimizer_state(variable: nnx.Variable): + return SGDState( + jnp.zeros_like(variable.value), **variable.get_metadata() + ) + + self.lr = lr + self.params = params + self.momentum = jax.tree.map(init_optimizer_state, self.params) + self.decay = decay + + def update(self, grads: nnx.State): + def update_fn( + params: nnx.Variable, momentum: SGDState, grad: nnx.VariableState + ): + # v_t = β * v_{t-1} + (1 - β) * ∇J(θ_t) + momentum.value = self.decay * momentum + (1 - self.decay) 
* grad.value + # θ_{t+1} = θ_t - α * v_t + params.value -= self.lr * momentum + + jax.tree.map(update_fn, self.params, self.momentum, grads) + + +@nnx.jit +def create_model(): + model = MLP(1, 32, 1, rngs=nnx.Rngs(0)) + optimizer = SGD(nnx.variables(model, nnx.Param), 0.01, decay=0.9) + state = nnx.state(optimizer) + sharded_state = jax.lax.with_sharding_constraint( + state, nnx.get_named_sharding(state, mesh) + ) + + def get_named_shardings(path: tuple, value: nnx.VariableState): + if path[0] == 'params': + return value.replace(NamedSharding(mesh, P(*value.sharding))) + elif path[0] == 'momentum': + # currently the same as above but in general it could be different + return value.replace(NamedSharding(mesh, P(*value.sharding))) + else: + raise ValueError(f'Unknown path: {path}') + + named_shardings = state.map(get_named_shardings) + sharded_state = jax.lax.with_sharding_constraint(state, named_shardings) + nnx.update(optimizer, sharded_state) + return model, optimizer + + +model, optimizer = create_model() + +jax.debug.visualize_array_sharding(model.w1.value) +jax.debug.visualize_array_sharding(optimizer.momentum.w1.value) + + +@nnx.jit +def train_step(model: MLP, optimizer: SGD, x, y): + def loss_fn(model): + y_pred = model(x) + loss = jnp.mean((y - y_pred) ** 2) + return loss + + loss, grad = nnx.value_and_grad(loss_fn)(model) + optimizer.update(grad) + return loss + + +X = np.linspace(-2, 2, 100)[:, None] +Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape) + + +def dataset(batch_size, num_steps): + for _ in range(num_steps): + idx = np.random.choice(len(X), size=batch_size) + yield X[idx], Y[idx] + + +losses = [] +for step, (x_batch, y_batch) in enumerate( + dataset(batch_size=32, num_steps=10_000) +): + x_batch, y_batch = jax.device_put((x_batch, y_batch), named_sharding('data')) + loss = train_step(model, optimizer, x_batch, y_batch) + losses.append(float(loss)) + if step % 1000 == 0: + print(f'Step {step}: Loss = {loss}') + +plt.figure() 
+plt.plot(losses[20:]) + +y_pred = model(X) +plt.figure() +plt.scatter(X, Y, color='blue') +plt.plot(X, y_pred, color='black') +plt.show() diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py index 04554ea7d4..367fe7e092 100644 --- a/flax/nnx/__init__.py +++ b/flax/nnx/__init__.py @@ -55,6 +55,7 @@ from .graph import split_context as split_context from .graph import MergeContext as MergeContext from .graph import merge_context as merge_context +from .graph import variables as variables from .nn import initializers as initializers from .nn.activations import celu as celu from .nn.activations import elu as elu @@ -116,7 +117,7 @@ from .spmd import with_sharding_constraint as with_sharding_constraint from .statelib import State as State from .training import metrics as metrics -from .variables import ( +from .variablelib import ( Param as Param, ) # this needs to be imported before optimizer to prevent circular import @@ -143,14 +144,14 @@ from .transforms.transforms import eval_shape as eval_shape from .transforms.transforms import cond as cond from .transforms.iteration import StateAxes as StateAxes -from .variables import A as A -from .variables import BatchStat as BatchStat -from .variables import Cache as Cache -from .variables import Intermediate as Intermediate -from .variables import Variable as Variable -from .variables import VariableState as VariableState -from .variables import VariableMetadata as VariableMetadata -from .variables import with_metadata as with_metadata +from .variablelib import A as A +from .variablelib import BatchStat as BatchStat +from .variablelib import Cache as Cache +from .variablelib import Intermediate as Intermediate +from .variablelib import Variable as Variable +from .variablelib import VariableState as VariableState +from .variablelib import VariableMetadata as VariableMetadata +from .variablelib import with_metadata as with_metadata from .visualization import display as display from .extract import to_tree as to_tree from 
.extract import from_tree as from_tree diff --git a/flax/nnx/bridge/variables.py b/flax/nnx/bridge/variables.py index 3e799bf4db..93531bb485 100644 --- a/flax/nnx/bridge/variables.py +++ b/flax/nnx/bridge/variables.py @@ -20,7 +20,7 @@ from flax.core import meta from flax.nnx import spmd from flax.nnx import traversals -from flax.nnx import variables as variableslib +from flax.nnx import variablelib as variableslib from flax.nnx.module import GraphDef import typing as tp diff --git a/flax/nnx/filterlib.py b/flax/nnx/filterlib.py index 2e4de1a178..9ad3419dfd 100644 --- a/flax/nnx/filterlib.py +++ b/flax/nnx/filterlib.py @@ -54,6 +54,16 @@ def to_predicate(filter: Filter) -> Predicate: else: raise TypeError(f'Invalid collection filter: {filter:!r}. ') +def filters_to_predicates(filters: tuple[Filter, ...]) -> tuple[Predicate, ...]: + for i, filter_ in enumerate(filters): + if filter_ in (..., True) and i != len(filters) - 1: + remaining_filters = filters[i + 1 :] + if not all(f in (..., True) for f in remaining_filters): + raise ValueError( + '`...` or `True` can only be used as the last filters, ' + f'got {filter_} it at index {i}.' 
+ ) + return tuple(map(to_predicate, filters)) @dataclasses.dataclass(frozen=True) class WithTag: diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py index 65eccfa906..97566dedd8 100644 --- a/flax/nnx/graph.py +++ b/flax/nnx/graph.py @@ -32,7 +32,8 @@ DelayedAccessor, ) from flax.nnx.statelib import FlatState, State -from flax.nnx.variables import Variable, VariableState +from flax.nnx import variablelib +from flax.nnx.variablelib import Variable, VariableState from flax.typing import Key, PathParts A = tp.TypeVar('A') @@ -1325,15 +1326,47 @@ def update(node, state: State, /, *states: State) -> None: _graph_update_dynamic(node, state.raw_mapping) +def _variables_generator(node) -> tp.Iterable[tuple[PathParts, Variable]]: + for path, value in iter_graph(node): + if isinstance(value, Variable): + yield path, value + @tp.overload -def state(node, /) -> GraphState: ... +def variables(node, /) -> State[Key, Variable]: ... +@tp.overload +def variables(node, first: filterlib.Filter, /) -> State[Key, Variable]: ... +@tp.overload +def variables( + node, + first: filterlib.Filter, + second: filterlib.Filter, + /, + *filters: filterlib.Filter, +) -> tuple[State[Key, Variable], ...]: ... +def variables( + node, + *filters: filterlib.Filter, +) -> tp.Union[State[Key, Variable], tuple[State[Key, Variable], ...]]: + num_filters = len(filters) + if num_filters == 0: + filters = (..., ...) + else: + filters = (*filters, ...) + variables_iterable = _variables_generator(node) + flat_states = variablelib.split_flat_state( + variables_iterable, (*filters, ...) + ) + states = tuple(State.from_flat_path(flat_state) for flat_state in flat_states) + if num_filters < 2: + return states[0] + return states +@tp.overload +def state(node, /) -> GraphState: ... @tp.overload def state(node, first: filterlib.Filter, /) -> GraphState: ... - - @tp.overload def state( node, @@ -1342,8 +1375,6 @@ def state( /, *filters: filterlib.Filter, ) -> tuple[GraphState, ...]: ... 
- - def state( node, *filters: filterlib.Filter, diff --git a/flax/nnx/module.py b/flax/nnx/module.py index efada835a7..795bb9a088 100644 --- a/flax/nnx/module.py +++ b/flax/nnx/module.py @@ -23,7 +23,7 @@ filterlib, graph, ) -from flax.nnx import variables as variableslib +from flax.nnx import variablelib as variableslib from flax.nnx.graph import GraphDef from flax.nnx.object import Object, ObjectMeta from flax.nnx.graph import GraphState, StateLeaf diff --git a/flax/nnx/nn/linear.py b/flax/nnx/nn/linear.py index dd6a18a56b..364b5dac1e 100644 --- a/flax/nnx/nn/linear.py +++ b/flax/nnx/nn/linear.py @@ -23,7 +23,7 @@ from flax.core.frozen_dict import FrozenDict from flax import nnx -from flax.nnx import rnglib, variables +from flax.nnx import rnglib, variablelib from flax.nnx.module import Module, first_from from flax.nnx.nn import dtypes, initializers from flax.typing import ( @@ -193,7 +193,7 @@ def kernel_init_wrap(rng, shape, dtype): ) flat_shape = jax.tree.map(int, flat_shape) kernel = self.kernel_init(rng, flat_shape, dtype) - if isinstance(kernel, variables.VariableMetadata): + if isinstance(kernel, variablelib.VariableMetadata): kernel.raw_value = jnp.reshape(kernel.raw_value, shape) else: kernel = jnp.reshape(kernel, shape) @@ -215,7 +215,7 @@ def kernel_init_wrap(rng, shape, dtype): def bias_init_wrap(rng, shape, dtype): flat_shape = (int(np.prod(shape)),) bias = self.bias_init(rng, flat_shape, dtype) - if isinstance(bias, variables.VariableMetadata): + if isinstance(bias, variablelib.VariableMetadata): bias.raw_value = jnp.reshape(bias.raw_value, shape) else: bias = jnp.reshape(bias, shape) diff --git a/flax/nnx/nn/lora.py b/flax/nnx/nn/lora.py index 6fe5984e7e..dbba23fd1d 100644 --- a/flax/nnx/nn/lora.py +++ b/flax/nnx/nn/lora.py @@ -18,7 +18,7 @@ import jax import jax.numpy as jnp -from flax.nnx import rnglib, variables +from flax.nnx import rnglib, variablelib from flax.nnx.module import Module from flax.nnx.nn import initializers from 
flax.nnx.nn.linear import Linear @@ -32,7 +32,7 @@ default_kernel_init = initializers.lecun_normal() -class LoRAParam(variables.Param[A]): +class LoRAParam(variablelib.Param[A]): pass @@ -84,7 +84,7 @@ def __init__( dtype: tp.Optional[Dtype] = None, param_dtype: Dtype = jnp.float32, kernel_init: Initializer = default_kernel_init, - lora_param_type: tp.Type[variables.Variable] = LoRAParam, + lora_param_type: tp.Type[variablelib.Variable] = LoRAParam, rngs: rnglib.Rngs, ): self.in_features = in_features @@ -155,7 +155,7 @@ def __init__( lora_dtype: tp.Optional[Dtype] = None, lora_param_dtype: Dtype = jnp.float32, lora_kernel_init: Initializer = default_kernel_init, - lora_param_type: tp.Type[variables.Variable] = LoRAParam, + lora_param_type: tp.Type[variablelib.Variable] = LoRAParam, rngs: rnglib.Rngs, **kwargs, ): diff --git a/flax/nnx/object.py b/flax/nnx/object.py index f2714ff7fd..c63506fc48 100644 --- a/flax/nnx/object.py +++ b/flax/nnx/object.py @@ -29,7 +29,7 @@ tracers, ) from flax.nnx import graph -from flax.nnx.variables import Variable, VariableState +from flax.nnx.variablelib import Variable, VariableState from flax.typing import Key from flax import errors diff --git a/flax/nnx/rnglib.py b/flax/nnx/rnglib.py index 25b2eea4fb..17bbaf37c8 100644 --- a/flax/nnx/rnglib.py +++ b/flax/nnx/rnglib.py @@ -23,7 +23,7 @@ from flax import struct from flax.nnx import graph from flax.nnx.statelib import State -from flax.nnx.variables import Variable +from flax.nnx.variablelib import Variable from flax.nnx import filterlib from flax.nnx.filterlib import All from flax.nnx.object import Object diff --git a/flax/nnx/spmd.py b/flax/nnx/spmd.py index 822e24c49e..fd9deb89f8 100644 --- a/flax/nnx/spmd.py +++ b/flax/nnx/spmd.py @@ -19,7 +19,7 @@ from jax.interpreters import pxla from jax.sharding import PartitionSpec -from flax.nnx import variables +from flax.nnx import variablelib from flax.typing import ( Array, ArrayPytree, # pylint: disable=invalid-name @@ -36,7 +36,7 @@ 
def add_axis(tree: A, index: int, params: tp.Mapping[tp.Any, tp.Any]) -> A: axis_name = _get_partition_name(params) def _add_axis(x: tp.Any): - if isinstance(x, variables.VariableState): + if isinstance(x, variablelib.VariableState): if hasattr(x, 'sharding') and x.sharding is not None: sharding: list[str | None] = list(x.sharding) while len(sharding) < index: @@ -48,7 +48,7 @@ def _add_axis(x: tp.Any): return x return jax.tree.map( - _add_axis, tree, is_leaf=lambda x: isinstance(x, variables.VariableState) + _add_axis, tree, is_leaf=lambda x: isinstance(x, variablelib.VariableState) ) @@ -56,7 +56,7 @@ def remove_axis(tree: A, index: int, params: tp.Mapping[tp.Any, tp.Any]) -> A: axis_name = _get_partition_name(params) def _remove_axis(x: tp.Any): - if isinstance(x, variables.VariableState): + if isinstance(x, variablelib.VariableState): if hasattr(x, 'sharding') and x.sharding is not None: sharding = list(x.sharding) assert sharding.pop(index) == axis_name @@ -67,7 +67,7 @@ def _remove_axis(x: tp.Any): return jax.tree.map( _remove_axis, tree, - is_leaf=lambda x: isinstance(x, variables.VariableState), + is_leaf=lambda x: isinstance(x, variablelib.VariableState), ) @@ -94,7 +94,7 @@ def from_rules(sharding, sharding_rules): return (rules[s] if s in rules else None for s in sharding) def f(x): - if isinstance(x, (variables.VariableState, variables.Variable)): + if isinstance(x, (variablelib.VariableState, variablelib.Variable)): if hasattr(x, 'sharding') and x.sharding: if hasattr(x, 'sharding_rules') and x.sharding_rules: return x.replace(PartitionSpec(*from_rules(x.sharding, x.sharding_rules))) @@ -105,7 +105,7 @@ def f(x): return _maybe_replicate(x) return jax.tree.map( - f, tree, is_leaf=lambda x: isinstance(x, variables.VariableState) + f, tree, is_leaf=lambda x: isinstance(x, variablelib.VariableState) ) @@ -171,7 +171,7 @@ def with_partitioning( mesh: tp.Optional[jax.sharding.Mesh] = None, **metadata: tp.Any, ) -> F: - return variables.with_metadata( + 
return variablelib.with_metadata( initializer, sharding=sharding, mesh=mesh, diff --git a/flax/nnx/statelib.py b/flax/nnx/statelib.py index 9063bc8196..4b3b1e387e 100644 --- a/flax/nnx/statelib.py +++ b/flax/nnx/statelib.py @@ -146,6 +146,12 @@ def __treescope_repr__(self, path, subtree_renderer): subtree_renderer=subtree_renderer, ) + def map(self, f: tp.Callable[[tuple, V], V]) -> State[K, V]: + flat_state = self.flat_state() + for path, variable_state in flat_state.items(): + flat_state[path] = f(path, variable_state) + return State.from_flat_path(flat_state) + def flat_state(self) -> FlatState[V]: return traversals.flatten_mapping(self._mapping) @@ -418,4 +424,4 @@ def _split_state( # if we didn't break, set leaf to last state flat_states[-1][path] = value # type: ignore[index] # mypy is wrong here? - return tuple(State.from_flat_path(flat_state) for flat_state in flat_states) + return tuple(State.from_flat_path(flat_state) for flat_state in flat_states) \ No newline at end of file diff --git a/flax/nnx/training/metrics.py b/flax/nnx/training/metrics.py index 492691349f..2073787b0d 100644 --- a/flax/nnx/training/metrics.py +++ b/flax/nnx/training/metrics.py @@ -20,7 +20,7 @@ from flax import struct from flax.nnx import filterlib, graph from flax.nnx.object import Object -from flax.nnx.variables import Variable +from flax.nnx.variablelib import Variable import jax, jax.numpy as jnp # TODO: add tests and docstrings diff --git a/flax/nnx/training/optimizer.py b/flax/nnx/training/optimizer.py index 281066ea42..fc3b4eeb15 100644 --- a/flax/nnx/training/optimizer.py +++ b/flax/nnx/training/optimizer.py @@ -19,9 +19,9 @@ from flax import nnx from flax.nnx import filterlib -from flax.nnx import variables +from flax.nnx import variablelib from flax.nnx.object import Object -from flax.nnx.variables import Variable, VariableState +from flax.nnx.variablelib import Variable, VariableState # TODO: add tests and docstrings @@ -47,7 +47,7 @@ class OptVariable(OptState): def 
_wrap_optimizer_state(opt_state): def wrap_optimizer_state_fn(x): - if isinstance(x, variables.VariableState): + if isinstance(x, variablelib.VariableState): new_state = x.copy() new_state.source_type = x.type new_state.type = OptVariable @@ -58,7 +58,7 @@ def wrap_optimizer_state_fn(x): return jax.tree.map( wrap_optimizer_state_fn, opt_state, - is_leaf=lambda x: isinstance(x, variables.VariableState), + is_leaf=lambda x: isinstance(x, variablelib.VariableState), ) diff --git a/flax/nnx/transforms/autodiff.py b/flax/nnx/transforms/autodiff.py index 9e55f70906..663b9a8ef6 100644 --- a/flax/nnx/transforms/autodiff.py +++ b/flax/nnx/transforms/autodiff.py @@ -23,7 +23,7 @@ extract, filterlib, graph, - variables, + variablelib, ) from flax.nnx.statelib import State import jax @@ -126,7 +126,7 @@ def _grad_general( index_filter[index] = ( dataclasses.replace(argnum, argnum=-1) if isinstance(argnum, DiffState) - else DiffState(-1, variables.Param) + else DiffState(-1, variablelib.Param) ) gradded_fn = transform( diff --git a/flax/nnx/transforms/deprecated.py b/flax/nnx/transforms/deprecated.py index f0191fc020..844cea4858 100644 --- a/flax/nnx/transforms/deprecated.py +++ b/flax/nnx/transforms/deprecated.py @@ -20,7 +20,7 @@ from flax import struct from flax.core.frozen_dict import FrozenDict -from flax.nnx import extract, filterlib, graph, rnglib, spmd, variables +from flax.nnx import extract, filterlib, graph, rnglib, spmd, variablelib from flax.nnx.module import GraphDef, Module from flax.nnx.proxy_caller import DelayedAccessor from flax.nnx.statelib import State @@ -1685,7 +1685,7 @@ def grad( allow_int: bool = False, reduce_axes: tp.Sequence[AxisName] = (), *, - wrt: filterlib.Filter = variables.Param, + wrt: filterlib.Filter = variablelib.Param, ) -> tp.Callable[..., tp.Any]: """Lifted version of ``jax.grad`` that can handle Modules / graph nodes as arguments. 
@@ -1770,7 +1770,7 @@ def value_and_grad( allow_int: bool = False, reduce_axes: tp.Sequence[AxisName] = (), *, - wrt: filterlib.Filter = variables.Param, + wrt: filterlib.Filter = variablelib.Param, ) -> tp.Callable[..., tp.Any]: return _grad_general( f, @@ -1794,7 +1794,7 @@ def constructor( reduce_axes: tp.Sequence[AxisName] = (), return_value: bool = False, *, - wrt: filterlib.Filter = variables.Param, + wrt: filterlib.Filter = variablelib.Param, ) -> tp.Callable[..., Grad[MA]]: def _create_grad(*args, **kwargs): return Grad( @@ -1821,7 +1821,7 @@ def __init__( allow_int: bool = False, reduce_axes: tp.Sequence[AxisName] = (), *, - wrt: filterlib.Filter = variables.Param, + wrt: filterlib.Filter = variablelib.Param, # submodule args module_init_args: tuple[tp.Any, ...], module_init_kwargs: dict[str, tp.Any], diff --git a/flax/nnx/variables.py b/flax/nnx/variablelib.py similarity index 96% rename from flax/nnx/variables.py rename to flax/nnx/variablelib.py index 882eeb4a6c..26ef67745c 100644 --- a/flax/nnx/variables.py +++ b/flax/nnx/variablelib.py @@ -23,8 +23,8 @@ import jax from flax import errors -from flax.nnx import reprlib, tracers -from flax.typing import Missing +from flax.nnx import filterlib, reprlib, tracers +from flax.typing import Missing, PathParts import jax.tree_util as jtu A = tp.TypeVar('A') @@ -245,6 +245,12 @@ def _setattr(self, name: str, value: tp.Any): def state(cls, value: A, **metadata) -> VariableState[A]: return cls(value, **metadata).to_state() + def get_metadata(self): + metadata = vars(self).copy() + del metadata['raw_value'] + del metadata['_trace_state'] + return metadata + def copy_from(self, other: Variable[A]) -> None: if type(self) is not type(other): raise ValueError( @@ -960,3 +966,29 @@ def wrapper(*args): ) return wrapper # type: ignore + + +def split_flat_state( + flat_state: tp.Iterable[tuple[PathParts, Variable | VariableState]], + filters: tuple[filterlib.Filter, ...], +) -> tuple[list[tuple[PathParts, Variable | 
VariableState]], ...]: + predicates = filterlib.filters_to_predicates(filters) + # we have n + 1 states, where n is the number of predicates + # the last state is for values that don't match any predicate + flat_states: tuple[list[tuple[PathParts, Variable | VariableState]], ...] = ( + tuple([] for _ in predicates) + ) + + for path, value in flat_state: + for i, predicate in enumerate(predicates): + if predicate(path, value): + flat_states[i].append((path, value)) + break + else: + raise ValueError( + 'Non-exhaustive filters, got a non-empty remainder: ' + f'{path} -> {value}.' + '\nUse `...` to match all remaining elements.' + ) + + return flat_states