[nnx] use jax-style transforms API in nnx_basics #4155
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -219,28 +219,31 @@ Theres a couple of things happening in this example that are worth mentioning: | |
using the optimizer alone. | ||
|
||
#### Scan over layers | ||
Next, let's take a look at a different example using `nnx.vmap` to create an | ||
`MLP` stack and `nnx.scan` to iteratively apply each layer in the stack to the | ||
input (scan over layers). | ||
Next, let's take a look at a different example, which uses | ||
[nnx.vmap](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.vmap) | ||
to create a stack of multiple MLP layers and | ||
[nnx.scan](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.scan) | ||
to iteratively apply each layer of the stack to the input. | ||
|
||
Notice the following: | ||
1. The `create_model` function creates a (single) `MLP` object that is lifted by | ||
`nnx.vmap` to have an additional dimension of size `axis_size`. | ||
Can you say more explicitly that 5 MLP layers are created because I actually still prefer if we can explicitly customize the number of layers in |
||
2. The `forward` function indexes the `MLP` object's state to get a different set of | ||
parameters at each step. | ||
3. `nnx.scan` automatically propagates the state updates for the `BatchNorm` and | ||
`Dropout` layers from within `forward` to the `model` reference outside. | ||
1. The `create_model` function takes in a key and returns an `MLP` object. Since we create 5 keys | ||
and use `nnx.vmap` over `create_model`, a stack of 5 `MLP` objects is created. | ||
2. We use `nnx.scan` to iteratively apply each `MLP` in the stack to the input `x`. | ||
3. The `nnx.scan` API (consciously) deviates from `jax.lax.scan` and instead mimics `vmap`, which is | ||
more expressive. `nnx.scan` allows specifying multiple inputs, the scan axes of each input/output, | ||
and the position of the carry. | ||
4. State updates for the `BatchNorm` and `Dropout` layers are automatically propagated | ||
by `nnx.scan`. | ||
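
For comparison, the first two notes above can be sketched in plain JAX, without Flax NNX. This is only an illustration under stated assumptions: `init_layer`, `apply_layer`, `NUM_LAYERS` and `DIM` are hypothetical names, a single `tanh` dense layer stands in for the `MLP`, and there is no `BatchNorm` or `Dropout` state to propagate. It shows that mapping an init function over 5 keys yields parameters with a leading axis of size 5, and that the scan body receives the carry `x` plus one slice of that stack per step.

```python
import jax
import jax.numpy as jnp

NUM_LAYERS, DIM = 5, 10  # hypothetical sizes, chosen to match the 5-layer example


def init_layer(key):
  """Create the parameters of one dense layer from a single PRNG key."""
  return {
      'w': jax.random.normal(key, (DIM, DIM)) / jnp.sqrt(DIM),
      'b': jnp.zeros((DIM,)),
  }


# One key per layer. Mapping init_layer over the keys stacks every
# parameter along a new leading axis of size NUM_LAYERS.
keys = jax.random.split(jax.random.key(0), NUM_LAYERS)
stacked = jax.vmap(init_layer)(keys)
print(stacked['w'].shape, stacked['b'].shape)  # (5, 10, 10) (5, 10)


def apply_layer(x, layer):
  """Scan body: `x` is the carry, `layer` is one slice of the stack."""
  x = jnp.tanh(x @ layer['w'] + layer['b'])
  return x, None  # new carry, no per-step output


x = jnp.ones((3, DIM))
# jax.lax.scan fixes the carry as the first argument and scans the
# stacked parameters along axis 0.
y, _ = jax.lax.scan(apply_layer, x, stacked)
print(y.shape)  # (3, 10)
```

The scan here computes the same result as a Python `for` loop that applies layer `i` to the output of layer `i - 1`. In `jax.lax.scan` the carry position and the scanned axis are fixed by convention, whereas the `in_axes=(nnx.Carry, 0)` argument in the NNX version states them explicitly.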
|
||
```{code-cell} ipython3 | ||
from functools import partial | ||
@nnx.vmap(in_axes=0, out_axes=0) | ||
def create_model(key: jax.Array): | ||
  return MLP(10, 32, 10, rngs=nnx.Rngs(key)) | ||
@partial(nnx.vmap, axis_size=5) | ||
def create_model(rngs: nnx.Rngs): | ||
  return MLP(10, 32, 10, rngs=rngs) | ||
keys = jax.random.split(jax.random.key(0), 5) | ||
model = create_model(keys) | ||
Can we add a line that print out the shape of |
||
model = create_model(nnx.Rngs(0)) | ||
@nnx.scan | ||
@nnx.scan(in_axes=(nnx.Carry, 0), out_axes=0) | ||
def forward(x, model: MLP): | ||
  x = model(x) | ||
  return x, None | ||
|
I think we'd want to explain a bit about scan here. What about changing this line to:
Thanks!