-
Notifications
You must be signed in to change notification settings - Fork 652
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[nnx] use jax-style transforms API in nnx_basics #4155
Conversation
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
@@ -225,22 +225,24 @@ input (scan over layers). | |||
|
|||
Notice the following: | |||
1. The `create_model` function creates a (single) `MLP` object that is lifted by | |||
`nnx.vmap` to have an additional dimension of size `axis_size`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you say more explicitly that 5 MLP layers are created because keys
has 5 dimensions, and vmap inferred the axis_size
from it? This now looks a bit confusing b/c the user would have to guess how do we customize the number of layers here.
I actually still prefer if we can explicitly customize the number of layers in nnx.vmap
line, instead of implicitly at the keys
line.
docs/nnx/nnx_basics.md
Outdated
2. The `forward` function indexes the `MLP` object's state to get a different set of | ||
parameters at each step. | ||
3. `nnx.scan` automatically propagates the state updates for the `BatchNorm` and | ||
`Dropout` layers from within `forward` to the `model` reference outside. | ||
3. The `nnx.scan` transform consciously deviates from its JAX equivalent in order to mimick |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we add links for the nnx.scan
, nnx.vmap
, jax.scan
and jax.vmap
API docs? Most people reading this will be new to JAX/Flax and don't know these subjects.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good idea. I will add add the reference on their first appearance above.
def create_model(rngs: nnx.Rngs): | ||
return MLP(10, 32, 10, rngs=rngs) | ||
keys = jax.random.split(jax.random.key(0), 5) | ||
model = create_model(keys) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we add a line that print out the shape of model
's params? To show this additional axis.
@@ -225,22 +225,24 @@ input (scan over layers). | |||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we'd want to explain a bit about scan here. What about changing this line to:
Next lets take a look at a different example, which uses
nnx.vmap
to create a stack of multipleMLP
layers andnnx.scan
to iteratively apply each layer of the stack to the input.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
0d982c5
to
a1ab608
Compare
a1ab608
to
8ecd6e9
Compare
8ecd6e9
to
1a78acb
Compare
What does this PR do?
Updates
nnx_basics
to use the new JAX-like transforms syntax in the Transforms section. Also adds some additional notes aboutnnx.scan
.