Add a guide for nnx.bridge
#4171
Conversation
docs/nnx/bridge_guide.md
Outdated
* Want to migrate your codebase to NNX gradually, one module at a time;
* Have external dependency that already moved to NNX but you haven't, or is still in Linen while you've moved to NNX.

We hope this allows you to move and try out NNX at your own pace, and leverage the best of both worlds. We will also talk about how to resolve the coveats of interoping the two APIs, on a few aspects that they are fundamentally different.
Suggested change:
```diff
- We hope this allows you to move and try out NNX at your own pace, and leverage the best of both worlds. We will also talk about how to resolve the coveats of interoping the two APIs, on a few aspects that they are fundamentally different.
+ We hope this allows you to move and try out NNX at your own pace, and leverage the best of both worlds. We will also talk about how to resolve the caveats of interoperating the two APIs, on a few aspects that they are fundamentally different.
```
Done!
First observations, LMKWYT
docs/nnx/bridge_guide.md
Outdated
**Note**:

This guide is about glueing Linen and NNX modules. To migrate an existing Linen module to NNX, check out the [Migrate from Linen to NNX](https://flax.readthedocs.io/en/latest/nnx/haiku_linen_vs_nnx.html) guide.
Nit: "Flax Linen" and "Flax NNX" if that makes sense
This guide is about glueing Linen and NNX modules. To migrate an existing Linen module to NNX, check out the [Migrate from Linen to NNX](https://flax.readthedocs.io/en/latest/nnx/haiku_linen_vs_nnx.html) guide.

And all built-in Linen layers should have equivalent NNX versions! Check out the list of [Built-in NNX layers](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/index.html).
Nit: To help new Flax users - write "Flax Linen"
Done.
docs/nnx/bridge_guide.md
Outdated
@@ -0,0 +1,544 @@
# Use NNX along with Linen
Nit: To emphasize that this is still Flax
Suggested change:
```diff
- # Use NNX along with Linen
+ # Use Flax NNX alongside Flax Linen
```
Done.
docs/nnx/bridge_guide.md
Outdated
@@ -0,0 +1,544 @@
# Use NNX along with Linen

This guide is for existing Flax users who want to make their codebase a mixture of Linen and NNX modules, which is made possible by the `nnx.bridge` API.
Maybe
Suggested change:
```diff
- This guide is for existing Flax users who want to make their codebase a mixture of Linen and NNX modules, which is made possible by the `nnx.bridge` API.
+ This guide is for existing Flax users who want to make their codebase a mixture of Linen and NNX `Module`s, which is made possible thanks to the `nnx.bridge` API.
```
Done.
  return x @ w

x = jax.random.normal(jax.random.key(42), (4, 32))
model = bridge.ToNNX(LinenDot(64), rngs=nnx.Rngs(0))  # => `model = LinenDot(64)` in Linen
It would be interesting to do an `nnx.display` here after construction to visualize the uninitialized structure, and in a new cell call the rest and visualize the module again after it's initialized.
Good idea! Done.
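For readers following along, here is a minimal sketch of what those two notebook cells might look like, assuming the `LinenDot` definition from the diff above and standard `flax.nnx` import paths (the exact `nnx.display` rendering depends on your Flax version):

```python
import jax
import flax.linen as nn
from flax import nnx
from flax.nnx import bridge

class LinenDot(nn.Module):
  out_dim: int
  @nn.compact
  def __call__(self, x):
    # Linen creates parameters lazily, at the first call.
    w = self.param('w', nn.initializers.lecun_normal(), (x.shape[-1], self.out_dim))
    return x @ w

x = jax.random.normal(jax.random.key(42), (4, 32))
model = bridge.ToNNX(LinenDot(64), rngs=nnx.Rngs(0))
nnx.display(model)          # Cell 1: no `w` yet -- the wrapped Linen module is uninitialized.

bridge.lazy_init(model, x)  # Initializes the Linen variables with a sample input.
nnx.display(model)          # Cell 2: the structure now contains the initialized `w`.
```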
docs/nnx/bridge_guide.md
Outdated
We don't let you create an NNX module of your own because new variables are created eagerly as soon as a NNX module is created. Instead, `bridge.ToLinen` will handle it for you.

An NNX module instance initializes all its variables eagerly when it is created, which consumes memory and compute. On the other hand, Linen modules are stateless, and the typical `init` and `apply` process involves multiple creation of them. To bypass this issue, you should send your creation arguments (instead of a created NNX module) to `bridge.to_linen`, and let it handle the rest.
There is a bit of repetition between the first phrase here and the previous paragraph.
Done.
docs/nnx/bridge_guide.md
Outdated
  return x @ self.w

x = jax.random.normal(jax.random.key(42), (4, 32))
model = bridge.ToLinen(NNXDot, args=(32, 64))  # <- Pass in the arguments, not an actual module
We should use `to_linen` as it's both more ergonomic and does the `FrozenDict` conversion automatically. Most users will copy the first example they see.
Done.
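As a rough sketch of the `to_linen`-first example the reviewer is asking for (assuming the `NNXDot` definition from the diff above; names and import paths follow the guide's conventions):

```python
import jax
from flax import nnx
from flax.nnx import bridge

class NNXDot(nnx.Module):
  def __init__(self, in_dim: int, out_dim: int, rngs: nnx.Rngs):
    self.w = nnx.Param(nnx.initializers.lecun_normal()(rngs.params(), (in_dim, out_dim)))
  def __call__(self, x):
    return x @ self.w

x = jax.random.normal(jax.random.key(42), (4, 32))
# Pass the constructor arguments, not a constructed module; `to_linen`
# creates the NNX module internally during Linen's `init`/`apply`.
model = bridge.to_linen(NNXDot, 32, out_dim=64)
variables = model.init(jax.random.key(0), x)
y = model.apply(variables, x)
```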
docs/nnx/bridge_guide.md
Outdated
```python
bridge.to_linen(NNXDot, 32, out_dim=64) == bridge.ToLinen(NNXDot, args=(32,), kwargs={'out_dim': 64})
```
I feel that `to_linen` should be the main API. Let's not show 2 ways of doing this.
Done. Explained a use case where you must use `ToLinen`.
docs/nnx/bridge_guide.md
Outdated
<class 'flax.nnx.nnx.graph.NodeDef'>

You can use a shortcut `bridge.to_linen` to avoid explicitly grouping args and kwargs separately. You only need to use the underlying `bridge.ToLinen()` in some rare cases.
Maybe we can also forward the other arguments such that you can always use `to_linen`?
docs/nnx/bridge_guide.md
Outdated
All Flax modules, Linen or NNX, automatically handle the RNG keys for variable creation and random layers like dropouts. However, the specific logics of RNG key splitting are different, so you cannot generate the same params between Linen and NNX modules, even if you pass in same keys.

Another difference is that NNX modules are stateful, so they can track and update the RNG keys within themselves. This means a bit difference in both `ToNNX` and `ToLinen`:
> This means a bit difference in both `ToNNX` and `ToLinen`

This part is a bit unclear.
Removed.
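To make the first paragraph's claim concrete, here is a hedged sketch (using the built-in `nn.Dense` and `nnx.Linear` layers; exact parameter values depend on your Flax version) showing that the same integer seed yields different parameters under the two APIs:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax import nnx

x = jax.random.normal(jax.random.key(42), (4, 32))

linen_model = nn.Dense(features=64)
linen_vars = linen_model.init(jax.random.key(0), x)

nnx_model = nnx.Linear(32, 64, rngs=nnx.Rngs(params=0))

# Same seed on both sides, but the RNG key-splitting logic differs,
# so the initialized kernels do not match.
assert not jnp.allclose(linen_vars['params']['kernel'], nnx_model.kernel.value)
```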
```python
# Reset the dropout RNG seed, so that next model run will be the same as the first.
nnx.reseed(model, dropout=0)
assert jnp.allclose(y1, model(x))
```
Love this example!
docs/nnx/bridge_guide.md
Outdated
print('How many times we have created new RNG keys from it:', variables['RngCount']['rngs']['dropout']['count'].value)

# NNX style: Must set `RngCount` as mutable and update the variables after every `apply`
y1, updates = model.apply(variables, x, mutable=['RngCount'])
variables |= updates
y2, updates = model.apply(variables, x, mutable=['RngCount'])
assert not jnp.allclose(y1, y2)  # Every call yields different output!
Maybe we could add another split count print to show that it's incrementing.
Suggested change:
```diff
- print('How many times we have created new RNG keys from it:', variables['RngCount']['rngs']['dropout']['count'].value)
-
- # NNX style: Must set `RngCount` as mutable and update the variables after every `apply`
- y1, updates = model.apply(variables, x, mutable=['RngCount'])
- variables |= updates
- y2, updates = model.apply(variables, x, mutable=['RngCount'])
- assert not jnp.allclose(y1, y2)  # Every call yields different output!
+ print('Key split count:', variables['RngCount']['rngs']['dropout']['count'].value)
+
+ # NNX style: Must set `RngCount` as mutable and update the variables after every `apply`
+ y1, updates = model.apply(variables, x, mutable=['RngCount'])
+ variables |= updates
+ y2, updates = model.apply(variables, x, mutable=['RngCount'])
+ variables |= updates
+ print('Key split count:', variables['RngCount']['rngs']['dropout']['count'].value)
+ assert not jnp.allclose(y1, y2)  # Every call yields different output!
```
Sounds great! Done
# Linen style: Just pass different RNG keys for every `apply()` call.
y3 = model.apply(variables, x, rngs={'dropout': jax.random.key(1)})
y4 = model.apply(variables, x, rngs={'dropout': jax.random.key(2)})
could be interesting to do the opposite and use the same key on both calls to show that this overrides the keys and yields the same output
Good idea! Done.
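For reference, a minimal sketch of that opposite case, reusing `model`, `variables`, and `x` from the thread above: an explicit `rngs` key overrides the stored RNG state, so repeating the same key repeats the output.

```python
# Same explicit key on both calls => identical dropout masks => identical outputs.
y3 = model.apply(variables, x, rngs={'dropout': jax.random.key(1)})
y4 = model.apply(variables, x, rngs={'dropout': jax.random.key(1)})
assert jnp.allclose(y3, y4)
```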
docs/nnx/bridge_guide.md
Outdated
```python
class Count(nnx.Variable): pass
nnx.register_variable_name_type_pair('Count', Count, overwrite=True)
```
`Count` would be the automatic name anyway, maybe we can show registering a Linen-style name (lower case + plural)?

Suggested change:
```diff
- nnx.register_variable_name_type_pair('Count', Count, overwrite=True)
+ nnx.register_variable_name_type_pair('counts', Count, overwrite=True)
```
Done.
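For context, a hedged sketch of how the registered Linen-style name surfaces once the variable crosses the bridge (the `Counter` module is a hypothetical example; the exact collection layout may vary by Flax version):

```python
import jax
import jax.numpy as jnp
from flax import nnx
from flax.nnx import bridge

class Count(nnx.Variable): pass
nnx.register_variable_name_type_pair('counts', Count, overwrite=True)

class Counter(nnx.Module):
  def __init__(self):
    self.count = Count(jnp.array(0))
  def __call__(self):
    self.count.value += 1

model = bridge.to_linen(Counter)
variables = model.init(jax.random.key(0))
# The registered name 'counts' becomes the Linen collection name.
print(variables['counts'])
```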
docs/nnx/bridge_guide.md
Outdated
@jax.jit
def create_sharded_nnx_module(x):
  model = bridge.lazy_init(bridge.ToNNX(LinenDotWithPartitioning(64), rngs=nnx.Rngs(0)), x)
  static, state = nnx.split(model)
  sharded_state = nnx.with_sharding_constraint(state, nnx.get_partition_spec(state))
  return static, sharded_state
Consider using `nnx.jit` here so you don't have to use `merge` later.
Suggested change:
```diff
- @jax.jit
- def create_sharded_nnx_module(x):
-   model = bridge.lazy_init(bridge.ToNNX(LinenDotWithPartitioning(64), rngs=nnx.Rngs(0)), x)
-   static, state = nnx.split(model)
-   sharded_state = nnx.with_sharding_constraint(state, nnx.get_partition_spec(state))
-   return static, sharded_state
+ @nnx.jit
+ def create_sharded_nnx_module(x):
+   model = bridge.lazy_init(bridge.ToNNX(LinenDotWithPartitioning(64), rngs=nnx.Rngs(0)), x)
+   state = nnx.state(model)
+   sharded_state = nnx.with_sharding_constraint(state, nnx.get_partition_spec(state))
+   nnx.update(model, sharded_state)
+   return model
```
Done. I should get more used to it!
docs/nnx/bridge_guide.md
Outdated
mesh = jax.sharding.Mesh(devices=mesh_utils.create_device_mesh((2, 4)), axis_names=('in', 'out'))
x = jax.random.normal(jax.random.key(42), (4, 32))
with mesh:
  model = nnx.merge(*create_sharded_nnx_module(x))
(If the previous is implemented)

Suggested change:
```diff
- model = nnx.merge(*create_sharded_nnx_module(x))
+ model = create_sharded_nnx_module(x)
```
Done.
@@ -0,0 +1,669 @@
# Use Flax NNX along with Flax Linen

This guide is for existing Flax users who want to make their codebase a mixture of Flax Linen and Flax NNX `Module`s, which is made possible thanks to the `flax.nnx.bridge` API.
Nit: There are redundant trailing spaces on several lines.

Suggested change (whitespace only; the suggestion strips a trailing space, so the visible text is unchanged):
```diff
- This guide is for existing Flax users who want to make their codebase a mixture of Flax Linen and Flax NNX `Module`s, which is made possible thanks to the `flax.nnx.bridge` API. 
+ This guide is for existing Flax users who want to make their codebase a mixture of Flax Linen and Flax NNX `Module`s, which is made possible thanks to the `flax.nnx.bridge` API.
```
Thanks Ayaka! It seems that our internal & external code checkers don't flag trailing spaces in notebook markdown, so I got lazy and didn't fix them... Will keep an eye out on the next PR!
Todo for future guides: