From 1a78acb5ea7ba850e8e8aa0e0f395d0badbe134a Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Mon, 2 Sep 2024 15:20:22 +0100 Subject: [PATCH] [nnx] use jax-like transforms API in nnx_basics --- docs/nnx/nnx_basics.ipynb | 35 +++++++++++++++++++---------------- docs/nnx/nnx_basics.md | 35 +++++++++++++++++++---------------- 2 files changed, 38 insertions(+), 32 deletions(-) diff --git a/docs/nnx/nnx_basics.ipynb b/docs/nnx/nnx_basics.ipynb index 2cc10b0fdf..809ef5e66f 100644 --- a/docs/nnx/nnx_basics.ipynb +++ b/docs/nnx/nnx_basics.ipynb @@ -365,17 +365,21 @@ " using the optimizer alone.\n", "\n", "#### Scan over layers\n", - "Next lets take a look at a different example using `nnx.vmap` to create an\n", - "`MLP` stack and `nnx.scan` to iteratively apply each layer in the stack to the\n", - "input (scan over layers). \n", + "Next lets take a look at a different example, which uses\n", + "[nnx.vmap](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.vmap)\n", + "to create a stack of multiple MLP layers and\n", + "[nnx.scan](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.scan)\n", + "to iteratively apply each layer of the stack to the input.\n", "\n", "Notice the following:\n", - "1. The `create_model` function creates a (single) `MLP` object that is lifted by\n", - " `nnx.vmap` to have an additional dimension of size `axis_size`.\n", - "2. The `forward` function indexes the `MLP` object's state to get a different set of\n", - " parameters at each step.\n", - "3. `nnx.scan` automatically propagates the state updates for the `BatchNorm` and \n", - "`Dropout` layers from within `forward` to the `model` reference outside." + "1. The `create_model` function takes in a key and returns an `MLP` object, since we create 5 keys\n", + " and use `nnx.vmap` over `create_model` a stack of 5 `MLP` objects is created.\n", + "2. We use `nnx.scan` to iteratively apply each `MLP` in the stack to the input `x`.\n", + "3. The `nnx.scan` API (consciously) deviates from `jax.lax.scan` and instead mimicks `vmap` which is\n", + " more expressive. `nnx.scan` allows specifying multiple inputs, the scan axes of each input/output,\n", + " and the position of the carry.\n", + "4. State updates for the `BatchNorm` and `Dropout` layers are automatically propagated\n", + " by `nnx.scan`." ] }, { @@ -404,15 +408,14 @@ } ], "source": [ - "from functools import partial\n", + "@nnx.vmap(in_axes=0, out_axes=0)\n", + "def create_model(key: jax.Array):\n", + " return MLP(10, 32, 10, rngs=nnx.Rngs(key))\n", "\n", - "@partial(nnx.vmap, axis_size=5)\n", - "def create_model(rngs: nnx.Rngs):\n", - " return MLP(10, 32, 10, rngs=rngs)\n", + "keys = jax.random.split(jax.random.key(0), 5)\n", + "model = create_model(keys)\n", "\n", - "model = create_model(nnx.Rngs(0))\n", - "\n", - "@nnx.scan\n", + "@nnx.scan(in_axes=(nnx.Carry, 0), out_axes=0)\n", "def forward(x, model: MLP):\n", " x = model(x)\n", " return x, None\n", diff --git a/docs/nnx/nnx_basics.md b/docs/nnx/nnx_basics.md index 70b3ff7540..0dfc37e837 100644 --- a/docs/nnx/nnx_basics.md +++ b/docs/nnx/nnx_basics.md @@ -219,28 +219,31 @@ Theres a couple of things happening in this example that are worth mentioning: using the optimizer alone. #### Scan over layers -Next lets take a look at a different example using `nnx.vmap` to create an -`MLP` stack and `nnx.scan` to iteratively apply each layer in the stack to the -input (scan over layers). +Next lets take a look at a different example, which uses +[nnx.vmap](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.vmap) +to create a stack of multiple MLP layers and +[nnx.scan](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.scan) +to iteratively apply each layer of the stack to the input. Notice the following: -1. The `create_model` function creates a (single) `MLP` object that is lifted by - `nnx.vmap` to have an additional dimension of size `axis_size`. -2. The `forward` function indexes the `MLP` object's state to get a different set of - parameters at each step. -3. `nnx.scan` automatically propagates the state updates for the `BatchNorm` and -`Dropout` layers from within `forward` to the `model` reference outside. +1. The `create_model` function takes in a key and returns an `MLP` object, since we create 5 keys + and use `nnx.vmap` over `create_model` a stack of 5 `MLP` objects is created. +2. We use `nnx.scan` to iteratively apply each `MLP` in the stack to the input `x`. +3. The `nnx.scan` API (consciously) deviates from `jax.lax.scan` and instead mimicks `vmap` which is + more expressive. `nnx.scan` allows specifying multiple inputs, the scan axes of each input/output, + and the position of the carry. +4. State updates for the `BatchNorm` and `Dropout` layers are automatically propagated + by `nnx.scan`. ```{code-cell} ipython3 -from functools import partial +@nnx.vmap(in_axes=0, out_axes=0) +def create_model(key: jax.Array): + return MLP(10, 32, 10, rngs=nnx.Rngs(key)) -@partial(nnx.vmap, axis_size=5) -def create_model(rngs: nnx.Rngs): - return MLP(10, 32, 10, rngs=rngs) +keys = jax.random.split(jax.random.key(0), 5) +model = create_model(keys) -model = create_model(nnx.Rngs(0)) - -@nnx.scan +@nnx.scan(in_axes=(nnx.Carry, 0), out_axes=0) def forward(x, model: MLP): x = model(x) return x, None