
Add a guide for nnx.bridge #4171

Merged
merged 1 commit into google:main on Sep 9, 2024

Conversation

@IvyZX (Collaborator) commented Sep 5, 2024

Todo for future guides:

  • When we have a guide for NNX's JAX-style transform, cross-reference it here.
  • When we have a standalone Linen-to-NNX guide (refactored from the current Haiku/Linen/NNX guide), cross-reference it here.

@IvyZX requested review from levskaya and cgarciae on September 5, 2024 00:42

* Want to migrate your codebase to NNX gradually, one module at a time;
* Have external dependency that already moved to NNX but you haven't, or is still in Linen while you've moved to NNX.

We hope this allows you to move and try out NNX at your own pace, and leverage the best of both worlds. We will also talk about how to resolve the coveats of interoping the two APIs, on a few aspects that they are fundamentally different.
Collaborator

Suggested change
We hope this allows you to move and try out NNX at your own pace, and leverage the best of both worlds. We will also talk about how to resolve the coveats of interoping the two APIs, on a few aspects that they are fundamentally different.
We hope this allows you to move and try out NNX at your own pace, and leverage the best of both worlds. We will also talk about how to resolve the caveats of interoperating the two APIs, on a few aspects that they are fundamentally different.

Collaborator Author

Done!
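
To make the interop described in the quoted intro concrete, here is a minimal hedged sketch of calling a Linen layer from NNX code via `nnx.bridge`. The choice of `nn.Dense` and the `lazy_init` call are illustrative assumptions based on snippets quoted later in this thread, not the guide's exact example:

```python
import jax
import flax.linen as nn
from flax import nnx
from flax.nnx import bridge

x = jax.random.normal(jax.random.key(0), (4, 32))

# Wrap an existing Linen layer so it can be used from NNX code.
# `lazy_init` runs the Linen-style variable creation using a sample input.
linen_as_nnx = bridge.lazy_init(bridge.ToNNX(nn.Dense(64), rngs=nnx.Rngs(0)), x)
y = linen_as_nnx(x)  # afterwards it is callable like any other NNX module
```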

@8bitmp3 (Collaborator) left a comment

First observations, LMKWYT


**Note**:

This guide is about glueing Linen and NNX modules. To migrate an existing Linen module to NNX, check out the [Migrate from Linen to NNX](https://flax.readthedocs.io/en/latest/nnx/haiku_linen_vs_nnx.html) guide.
@8bitmp3 (Collaborator), Sep 5, 2024

Nit: "Flax Linen" and "Flax NNX" if that makes sense


This guide is about glueing Linen and NNX modules. To migrate an existing Linen module to NNX, check out the [Migrate from Linen to NNX](https://flax.readthedocs.io/en/latest/nnx/haiku_linen_vs_nnx.html) guide.

And all built-in Linen layers should have equivalent NNX versions! Check out the list of [Built-in NNX layers](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/index.html).
Collaborator

Nit: To help new Flax users - write "Flax Linen"

Collaborator Author

Done.
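
Related to the quoted note above about built-in layer equivalents, a hedged side-by-side sketch (the exact initializer defaults may differ between the two APIs, and the resulting parameters generally won't match even with the same seed):

```python
import jax
import flax.linen as nn
from flax import nnx

x = jax.random.normal(jax.random.key(0), (4, 32))

# Flax Linen: stateless module, variables created via init/apply.
linen_layer = nn.Dense(features=64)
variables = linen_layer.init(jax.random.key(1), x)
y_linen = linen_layer.apply(variables, x)

# Flax NNX: the rough equivalent, with parameters created eagerly at construction.
nnx_layer = nnx.Linear(in_features=32, out_features=64, rngs=nnx.Rngs(1))
y_nnx = nnx_layer(x)
```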

@@ -0,0 +1,544 @@
# Use NNX along with Linen
Collaborator

Nit: To emphasize that this is still Flax

Suggested change
# Use NNX along with Linen
# Use Flax NNX alongside Flax Linen

Collaborator Author

Done.

@@ -0,0 +1,544 @@
# Use NNX along with Linen

This guide is for existing Flax users who want to make their codebase a mixture of Linen and NNX modules, which is made possible by the `nnx.bridge` API.
Collaborator

Maybe

Suggested change
This guide is for existing Flax users who want to make their codebase a mixture of Linen and NNX modules, which is made possible by the `nnx.bridge` API.
This guide is for existing Flax users who want to make their codebase a mixture of Linen and NNX `Module`s, which is made possible thanks to the `nnx.bridge` API.

Collaborator Author

Done.

return x @ w

x = jax.random.normal(jax.random.key(42), (4, 32))
model = bridge.ToNNX(LinenDot(64), rngs=nnx.Rngs(0)) # => `model = LinenDot(64)` in Linen
Collaborator

It would be interesting to do an nnx.display here after construction to visualize the uninitialized structure, and in a new cell call the rest and visualize the module again after it's initialized.

Collaborator Author

Good idea! Done.
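
For illustration, a hedged sketch of what such a cell could look like. The `LinenDot` definition is an assumed reconstruction from the quoted snippet, and the exact `nnx.display` output depends on the Flax version:

```python
import jax
import flax.linen as nn
from flax import nnx
from flax.nnx import bridge

# Assumed reconstruction of the LinenDot module quoted above.
class LinenDot(nn.Module):
  out_dim: int
  @nn.compact
  def __call__(self, x):
    w = self.param('w', nn.initializers.lecun_normal(), (x.shape[-1], self.out_dim))
    return x @ w

x = jax.random.normal(jax.random.key(42), (4, 32))
model = bridge.ToNNX(LinenDot(64), rngs=nnx.Rngs(0))
nnx.display(model)              # uninitialized: Linen variables not created yet

model = bridge.lazy_init(model, x)   # triggers Linen-style parameter creation
nnx.display(model)              # now shows the created `w` parameter
```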


We don't let you create an NNX module of your own because new variables are created eagerly as soon as a NNX module is created. Instead, `bridge.ToLinen` will handle it for you.

An NNX module instance initializes all its variables eagerly when it is created, which consumes memory and compute. On the other hand, Linen modules are stateless, and the typical `init` and `apply` process involves multiple creation of them. To bypass this issue, you should send your creation arguments (instead of a created NNX module) to `bridge.to_linen`, and let it handle the rest.
@cgarciae (Collaborator), Sep 6, 2024

There is a bit of repetition here on the first phrase and the previous paragraph.

Collaborator Author

Done.

return x @ self.w

x = jax.random.normal(jax.random.key(42), (4, 32))
model = bridge.ToLinen(NNXDot, args=(32, 64)) # <- Pass in the arguments, not an actual module
Collaborator

We should use `to_linen` as it's both more ergonomic and does the FrozenDict conversion automatically. Most users will copy the first example they see.

Collaborator Author

Done.
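
For reference, a hedged sketch of the `to_linen` flow the reviewer is asking for. The `NNXDot` definition is an assumed reconstruction from the quoted snippet, and my understanding is that `to_linen` supplies the `rngs` argument from the Linen RNGs during `init`:

```python
import jax
from flax import nnx
from flax.nnx import bridge

# Assumed reconstruction of NNXDot: an NNX module that creates its param eagerly.
class NNXDot(nnx.Module):
  def __init__(self, in_dim: int, out_dim: int, rngs: nnx.Rngs):
    self.w = nnx.Param(nnx.initializers.lecun_normal()(rngs.params(), (in_dim, out_dim)))
  def __call__(self, x):
    return x @ self.w

x = jax.random.normal(jax.random.key(42), (4, 32))
# Pass the class and its constructor arguments; to_linen builds the module for you.
model = bridge.to_linen(NNXDot, 32, 64)
variables = model.init(jax.random.key(0), x)
y = model.apply(variables, x)
```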



```python
bridge.to_linen(NNXDot, 32, out_dim=64) == bridge.ToLinen(NNXDot, args=(32,), kwargs={'out_dim': 64})
Collaborator

I feel that to_linen should be the main API. Let's not show 2 ways of doing this.

Collaborator Author

Done. Explained a use case where you must use ToLinen.

<class 'flax.nnx.nnx.graph.NodeDef'>


You can use a shortcut `bridge.to_linen` to avoid explicitly grouping args and kwargs separately. You only need to use the underlying `bridge.ToLinen()` in some rare cases.
Collaborator

Maybe we can also forward the other arguments such that you can always use to_linen?


All Flax modules, Linen or NNX, automatically handle the RNG keys for variable creation and random layers like dropouts. However, the specific logics of RNG key splitting are different, so you cannot generate the same params between Linen and NNX modules, even if you pass in same keys.

Another difference is that NNX modules are stateful, so they can track and update the RNG keys within themselves. This means a bit difference in both `ToNNX` and `ToLinen`:
Collaborator

This means a bit difference in both `ToNNX` and `ToLinen`

This part is a bit unclear.

Collaborator Author

Removed.
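
A hedged sketch of the stateful-RNG point from the quoted paragraph, using `nnx.Dropout` purely as an illustration (the surrounding guide uses a bridge-wrapped model instead):

```python
import jax.numpy as jnp
from flax import nnx

# The NNX module owns its RNG state, so repeated calls draw fresh dropout masks
# without new keys being passed in from the outside.
model = nnx.Dropout(rate=0.5, rngs=nnx.Rngs(dropout=0))
x = jnp.ones((4, 32))
y1 = model(x)
y2 = model(x)
assert not jnp.allclose(y1, y2)  # the internal RNG key advanced between calls
```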

# Reset the dropout RNG seed, so that next model run will be the same as the first.
nnx.reseed(model, dropout=0)
assert jnp.allclose(y1, model(x))
```
Collaborator

Love this example!

@cgarciae (Collaborator) commented Sep 6, 2024

One general thing I'd add is the use of nnx.display on most cells so users can visualize the structure; it makes it a bit more intuitive to follow, e.g.:

[Screenshot of example nnx.display output omitted]

Comment on lines 259 to 386
print('How many times we have created new RNG keys from it:', variables['RngCount']['rngs']['dropout']['count'].value)

# NNX style: Must set `RngCount` as mutable and update the variables after every `apply`
y1, updates = model.apply(variables, x, mutable=['RngCount'])
variables |= updates
y2, updates = model.apply(variables, x, mutable=['RngCount'])
assert not jnp.allclose(y1, y2) # Every call yields different output!
@cgarciae (Collaborator), Sep 6, 2024

Maybe we could add another split count print to show that it's incrementing

Suggested change
print('How many times we have created new RNG keys from it:', variables['RngCount']['rngs']['dropout']['count'].value)
# NNX style: Must set `RngCount` as mutable and update the variables after every `apply`
y1, updates = model.apply(variables, x, mutable=['RngCount'])
variables |= updates
y2, updates = model.apply(variables, x, mutable=['RngCount'])
assert not jnp.allclose(y1, y2) # Every call yields different output!
print('Key split count:', variables['RngCount']['rngs']['dropout']['count'].value)
# NNX style: Must set `RngCount` as mutable and update the variables after every `apply`
y1, updates = model.apply(variables, x, mutable=['RngCount'])
variables |= updates
y2, updates = model.apply(variables, x, mutable=['RngCount'])
variables |= updates
print('Key split count:', variables['RngCount']['rngs']['dropout']['count'].value)
assert not jnp.allclose(y1, y2) # Every call yields different output!

Collaborator Author

Sounds great! Done


# Linen style: Just pass different RNG keys for every `apply()` call.
y3 = model.apply(variables, x, rngs={'dropout': jax.random.key(1)})
y4 = model.apply(variables, x, rngs={'dropout': jax.random.key(2)})
@cgarciae (Collaborator), Sep 6, 2024

could be interesting to do the opposite and use the same key on both calls to show that this overrides the keys and yields the same output

Collaborator Author

Good idea! done.
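
A hedged sketch of the same-key point. The quoted snippet uses a bridge-wrapped NNX model; here a tiny plain-Linen stand-in with dropout is used instead, just to show that passing the same RNG key on both `apply` calls yields identical outputs:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

# A tiny stand-in model with dropout, only to demonstrate the point.
class DropoutModel(nn.Module):
  @nn.compact
  def __call__(self, x):
    return nn.Dropout(rate=0.5, deterministic=False)(x)

x = jnp.ones((4, 32))
model = DropoutModel()
variables = model.init({'dropout': jax.random.key(0)}, x)

# Same explicit key on both calls -> same dropout mask -> same output.
y3 = model.apply(variables, x, rngs={'dropout': jax.random.key(1)})
y4 = model.apply(variables, x, rngs={'dropout': jax.random.key(1)})
assert jnp.allclose(y3, y4)
```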


```python
class Count(nnx.Variable): pass
nnx.register_variable_name_type_pair('Count', Count, overwrite=True)
@cgarciae (Collaborator), Sep 6, 2024

Count would be the automatic name anyway, maybe we can show registering a Linen-style name (lower case + plural)?

Suggested change
nnx.register_variable_name_type_pair('Count', Count, overwrite=True)
nnx.register_variable_name_type_pair('counts', Count, overwrite=True)

Collaborator Author

Done.
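
A hedged sketch of what the registered Linen-style name might buy you when converting to Linen. The `Counter` module and the exact collection layout are my assumptions about `nnx.bridge` behavior based on this discussion, not the guide's exact example:

```python
import jax
import jax.numpy as jnp
from flax import nnx
from flax.nnx import bridge

class Count(nnx.Variable): pass
# Register a Linen-style collection name (lower case + plural) for this variable type.
nnx.register_variable_name_type_pair('counts', Count, overwrite=True)

# Hypothetical NNX module holding a Count variable.
class Counter(nnx.Module):
  def __init__(self):
    self.count = Count(jnp.array(0))
  def __call__(self, x):
    self.count.value += 1
    return x

x = jnp.ones((2,))
model = bridge.to_linen(Counter)
variables = model.init(jax.random.key(0), x)
# The collections should now include the registered 'counts' name
# (alongside the bridge's internal collections).
print(list(variables.keys()))
```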

Comment on lines 418 to 423
@jax.jit
def create_sharded_nnx_module(x):
model = bridge.lazy_init(bridge.ToNNX(LinenDotWithPartitioning(64), rngs=nnx.Rngs(0)), x)
static, state = nnx.split(model)
sharded_state = nnx.with_sharding_constraint(state, nnx.get_partition_spec(state))
return static, sharded_state
Collaborator

Consider using nnx.jit here so you don't have to use merge later.

Suggested change
@jax.jit
def create_sharded_nnx_module(x):
model = bridge.lazy_init(bridge.ToNNX(LinenDotWithPartitioning(64), rngs=nnx.Rngs(0)), x)
static, state = nnx.split(model)
sharded_state = nnx.with_sharding_constraint(state, nnx.get_partition_spec(state))
return static, sharded_state
@nnx.jit
def create_sharded_nnx_module(x):
model = bridge.lazy_init(bridge.ToNNX(LinenDotWithPartitioning(64), rngs=nnx.Rngs(0)), x)
state = nnx.state(model)
sharded_state = nnx.with_sharding_constraint(state, nnx.get_partition_spec(state))
nnx.update(model, sharded_state)
return model

Collaborator Author

Done. I should get more used to it!

mesh = jax.sharding.Mesh(devices=mesh_utils.create_device_mesh((2, 4)), axis_names=('in', 'out'))
x = jax.random.normal(jax.random.key(42), (4, 32))
with mesh:
model = nnx.merge(*create_sharded_nnx_module(x))
Collaborator

(If the previous is implemented)

Suggested change
model = nnx.merge(*create_sharded_nnx_module(x))
model = create_sharded_nnx_module(x)

Collaborator Author

Done.

@@ -0,0 +1,669 @@
# Use Flax NNX along with Flax Linen

This guide is for existing Flax users who want to make their codebase a mixture of Flax Linen and Flax NNX `Module`s, which is made possible thanks to the `flax.nnx.bridge` API.
@ayaka14732 (Member), Sep 7, 2024

Nit: There are redundant trailing spaces on several lines.

Suggested change
This guide is for existing Flax users who want to make their codebase a mixture of Flax Linen and Flax NNX `Module`s, which is made possible thanks to the `flax.nnx.bridge` API.
This guide is for existing Flax users who want to make their codebase a mixture of Flax Linen and Flax NNX `Module`s, which is made possible thanks to the `flax.nnx.bridge` API.

Collaborator Author

Thanks Ayaka! It seems that our internal & external code checkers don't flag trailing spaces in notebook markdown, so I got lazy and didn't fix them... Will keep an eye out on the next PR!

@copybara-service bot merged commit d373856 into google:main on Sep 9, 2024
16 checks passed

4 participants