Skip to content

Commit

Permalink
[nnx] use jax-like transforms API in nnx_basics
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Sep 2, 2024
1 parent 839db8c commit 8ecd6e9
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 32 deletions.
35 changes: 19 additions & 16 deletions docs/nnx/nnx_basics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -365,17 +365,21 @@
" using the optimizer alone.\n",
"\n",
"#### Scan over layers\n",
"Next lets take a look at a different example using `nnx.vmap` to create an\n",
"`MLP` stack and `nnx.scan` to iteratively apply each layer in the stack to the\n",
"input (scan over layers). \n",
"Next lets take a look at a different example, which uses\n",
"[nnx.vmap](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.vmap)\n",
"to create a stack of multiple MLP layers and\n",
"[nnx.scan](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.scan)\n",
"to iteratively apply each layer of the stack to the input.\n",
"\n",
"Notice the following:\n",
"1. The `create_model` function creates a (single) `MLP` object that is lifted by\n",
" `nnx.vmap` to have an additional dimension of size `axis_size`.\n",
"2. The `forward` function indexes the `MLP` object's state to get a different set of\n",
" parameters at each step.\n",
"3. `nnx.scan` automatically propagates the state updates for the `BatchNorm` and \n",
"`Dropout` layers from within `forward` to the `model` reference outside."
"1. The `create_model` function takes in a key and returns an `MLP` object, since we create 5 keys\n",
" and use `nnx.vmap` over `create_model` a stack of 5 `MLP` objects is created.\n",
"2. We use `nnx.scan` to iteratively apply each `MLP` in the stack to the input `x`.\n",
"3. The `nnx.scan` API (consciously) deviates from `jax.lax.scan` and instead mimicks `vmap` which is\n",
" more expressive. `nnx.scan` allows specifying multiple inputs, the scan axes of each input/output,\n",
" and the position of the carry.\n",
"4. State updates for the `BatchNorm` and `Dropout` layers are automatically propagated \n",
" by `nnx.scan`."
]
},
{
Expand Down Expand Up @@ -404,15 +408,14 @@
}
],
"source": [
"from functools import partial\n",
"@nnx.vmap(in_axes=0, out_axes=0)\n",
"def create_model(key: jax.Array):\n",
" return MLP(10, 32, 10, rngs=nnx.Rngs(key))\n",
"\n",
"@partial(nnx.vmap, axis_size=5)\n",
"def create_model(rngs: nnx.Rngs):\n",
" return MLP(10, 32, 10, rngs=rngs)\n",
"keys = jax.random.split(jax.random.key(0), 5)\n",
"model = create_model(keys)\n",
"\n",
"model = create_model(nnx.Rngs(0))\n",
"\n",
"@nnx.scan\n",
"@nnx.scan(in_axes=(nnx.Carry, 0), out_axes=0)\n",
"def forward(x, model: MLP):\n",
" x = model(x)\n",
" return x, None\n",
Expand Down
35 changes: 19 additions & 16 deletions docs/nnx/nnx_basics.md
Original file line number Diff line number Diff line change
Expand Up @@ -219,28 +219,31 @@ Theres a couple of things happening in this example that are worth mentioning:
using the optimizer alone.

#### Scan over layers
Next lets take a look at a different example using `nnx.vmap` to create an
`MLP` stack and `nnx.scan` to iteratively apply each layer in the stack to the
input (scan over layers).
Next lets take a look at a different example, which uses
[nnx.vmap](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.vmap)
to create a stack of multiple MLP layers and
[nnx.scan](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.scan)
to iteratively apply each layer of the stack to the input.

Notice the following:
1. The `create_model` function creates a (single) `MLP` object that is lifted by
`nnx.vmap` to have an additional dimension of size `axis_size`.
2. The `forward` function indexes the `MLP` object's state to get a different set of
parameters at each step.
3. `nnx.scan` automatically propagates the state updates for the `BatchNorm` and
`Dropout` layers from within `forward` to the `model` reference outside.
1. The `create_model` function takes in a key and returns an `MLP` object, since we create 5 keys
and use `nnx.vmap` over `create_model` a stack of 5 `MLP` objects is created.
2. We use `nnx.scan` to iteratively apply each `MLP` in the stack to the input `x`.
3. The `nnx.scan` API (consciously) deviates from `jax.lax.scan` and instead mimicks `vmap` which is
more expressive. `nnx.scan` allows specifying multiple inputs, the scan axes of each input/output,
and the position of the carry.
4. State updates for the `BatchNorm` and `Dropout` layers are automatically propagated
by `nnx.scan`.

```{code-cell} ipython3
from functools import partial
@nnx.vmap(in_axes=0, out_axes=0)
def create_model(key: jax.Array):
return MLP(10, 32, 10, rngs=nnx.Rngs(key))
@partial(nnx.vmap, axis_size=5)
def create_model(rngs: nnx.Rngs):
return MLP(10, 32, 10, rngs=rngs)
keys = jax.random.split(jax.random.key(0), 5)
model = create_model(keys)
model = create_model(nnx.Rngs(0))
@nnx.scan
@nnx.scan(in_axes=(nnx.Carry, 0), out_axes=0)
def forward(x, model: MLP):
x = model(x)
return x, None
Expand Down

0 comments on commit 8ecd6e9

Please sign in to comment.