[nnx] use jax-style transforms API in nnx_basics #4155

Merged
merged 1 commit into from
Sep 2, 2024
35 changes: 19 additions & 16 deletions docs/nnx/nnx_basics.ipynb
@@ -365,17 +365,21 @@
" using the optimizer alone.\n",
"\n",
"#### Scan over layers\n",
"Next lets take a look at a different example using `nnx.vmap` to create an\n",
"`MLP` stack and `nnx.scan` to iteratively apply each layer in the stack to the\n",
"input (scan over layers). \n",
"Next lets take a look at a different example, which uses\n",
"[nnx.vmap](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.vmap)\n",
"to create a stack of multiple MLP layers and\n",
"[nnx.scan](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.scan)\n",
"to iteratively apply each layer of the stack to the input.\n",
"\n",
"Notice the following:\n",
"1. The `create_model` function creates a (single) `MLP` object that is lifted by\n",
" `nnx.vmap` to have an additional dimension of size `axis_size`.\n",
"2. The `forward` function indexes the `MLP` object's state to get a different set of\n",
" parameters at each step.\n",
"3. `nnx.scan` automatically propagates the state updates for the `BatchNorm` and \n",
"`Dropout` layers from within `forward` to the `model` reference outside."
"1. The `create_model` function takes in a key and returns an `MLP` object, since we create 5 keys\n",
" and use `nnx.vmap` over `create_model` a stack of 5 `MLP` objects is created.\n",
"2. We use `nnx.scan` to iteratively apply each `MLP` in the stack to the input `x`.\n",
"3. The `nnx.scan` API (consciously) deviates from `jax.lax.scan` and instead mimicks `vmap` which is\n",
" more expressive. `nnx.scan` allows specifying multiple inputs, the scan axes of each input/output,\n",
" and the position of the carry.\n",
"4. State updates for the `BatchNorm` and `Dropout` layers are automatically propagated\n",
" by `nnx.scan`."
]
},
{
@@ -404,15 +408,14 @@
}
],
"source": [
"from functools import partial\n",
"@nnx.vmap(in_axes=0, out_axes=0)\n",
"def create_model(key: jax.Array):\n",
" return MLP(10, 32, 10, rngs=nnx.Rngs(key))\n",
"\n",
"@partial(nnx.vmap, axis_size=5)\n",
"def create_model(rngs: nnx.Rngs):\n",
" return MLP(10, 32, 10, rngs=rngs)\n",
"keys = jax.random.split(jax.random.key(0), 5)\n",
"model = create_model(keys)\n",
"\n",
"model = create_model(nnx.Rngs(0))\n",
"\n",
"@nnx.scan\n",
"@nnx.scan(in_axes=(nnx.Carry, 0), out_axes=0)\n",
"def forward(x, model: MLP):\n",
" x = model(x)\n",
" return x, None\n",
35 changes: 19 additions & 16 deletions docs/nnx/nnx_basics.md
@@ -219,28 +219,31 @@ Theres a couple of things happening in this example that are worth mentioning:
using the optimizer alone.

#### Scan over layers
Next lets take a look at a different example using `nnx.vmap` to create an
`MLP` stack and `nnx.scan` to iteratively apply each layer in the stack to the
input (scan over layers).
Next lets take a look at a different example, which uses
[nnx.vmap](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.vmap)
to create a stack of multiple MLP layers and
[nnx.scan](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.scan)
to iteratively apply each layer of the stack to the input.

Collaborator

I think we'd want to explain a bit about scan here. What about changing this line to:

Next lets take a look at a different example, which uses nnx.vmap to create a stack of multiple MLP layers and nnx.scan to iteratively apply each layer of the stack to the input.

Collaborator Author

Thanks!

Notice the following:
1. The `create_model` function creates a (single) `MLP` object that is lifted by
`nnx.vmap` to have an additional dimension of size `axis_size`.
Collaborator

Can you say more explicitly that 5 MLP layers are created because `keys` has a leading axis of size 5, and that vmap infers the axis_size from it? As written this is a bit confusing, because the user has to guess how to customize the number of layers here.

I'd actually still prefer that we explicitly set the number of layers on the `nnx.vmap` line, rather than implicitly on the `keys` line.

2. The `forward` function indexes the `MLP` object's state to get a different set of
parameters at each step.
3. `nnx.scan` automatically propagates the state updates for the `BatchNorm` and
`Dropout` layers from within `forward` to the `model` reference outside.
1. The `create_model` function takes in a key and returns an `MLP` object, since we create 5 keys
and use `nnx.vmap` over `create_model` a stack of 5 `MLP` objects is created.
2. We use `nnx.scan` to iteratively apply each `MLP` in the stack to the input `x`.
3. The `nnx.scan` API (consciously) deviates from `jax.lax.scan` and instead mimicks `vmap` which is
more expressive. `nnx.scan` allows specifying multiple inputs, the scan axes of each input/output,
and the position of the carry.
4. State updates for the `BatchNorm` and `Dropout` layers are automatically propagated
by `nnx.scan`.
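
To make the two ideas in the list above concrete (and to show the extra leading axis the review asks about), here is a minimal sketch in plain JAX rather than the `nnx` API. It assumes a single dense layer in place of the `MLP` from the diff, and the helper names `init_layer` and `apply_layer` are hypothetical. It uses `jax.vmap` and `jax.lax.scan`, so the carry is fixed as the first argument; the diff states that `nnx.scan` makes that position configurable through `in_axes=(nnx.Carry, 0)`.

```python
import jax
import jax.numpy as jnp


def init_layer(key, din=10, dout=10):
    # Parameters of one dense layer (hypothetical stand-in for the MLP).
    return {
        "w": jax.random.normal(key, (din, dout)) * 0.1,
        "b": jnp.zeros((dout,)),
    }


# Five keys in, five layers out: vmap stacks the per-layer parameters
# along a new leading axis whose size is inferred from `keys`.
keys = jax.random.split(jax.random.key(0), 5)
stacked = jax.vmap(init_layer)(keys)

# Inspect the stacked shapes; every leaf gains a leading axis of size 5.
shapes = jax.tree_util.tree_map(lambda a: a.shape, stacked)
print(shapes)


def apply_layer(x, layer):
    # `x` is the carry; `layer` is one slice of the stacked parameters.
    x = jnp.tanh(x @ layer["w"] + layer["b"])
    return x, None  # no per-step output is collected


# Scan over layers: the carry `x` is threaded through all five layers.
x = jnp.ones((3, 10))
y, _ = jax.lax.scan(apply_layer, x, stacked)
print(y.shape)
```

Changing the number of keys passed to `jax.random.split` changes the number of stacked layers, which is the implicit coupling the review comment points out.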

```{code-cell} ipython3
from functools import partial
@nnx.vmap(in_axes=0, out_axes=0)
def create_model(key: jax.Array):
return MLP(10, 32, 10, rngs=nnx.Rngs(key))
@partial(nnx.vmap, axis_size=5)
def create_model(rngs: nnx.Rngs):
return MLP(10, 32, 10, rngs=rngs)
keys = jax.random.split(jax.random.key(0), 5)
model = create_model(keys)
Collaborator

Can we add a line that prints out the shape of the model's params, to show this additional axis?

model = create_model(nnx.Rngs(0))
@nnx.scan
@nnx.scan(in_axes=(nnx.Carry, 0), out_axes=0)
def forward(x, model: MLP):
x = model(x)
return x, None