[nnx] add performance guide notebook #4384
Conversation
Force-pushed from c99c64d to dbc5667
Force-pushed from 656b7bc to fa4edf9
@@ -12,27 +12,12 @@ jupytext:

Flax NNX is a new simplified API that is designed to make it easier to create, inspect, debug, and analyze neural networks in [JAX](https://jax.readthedocs.io/). It achieves this by adding first class support for Python reference semantics. This allows users to express their models using regular Python objects, which are modeled as PyGraphs (instead of pytrees), enabling reference sharing and mutability. Such API design should make PyTorch or Keras users feel at home.

In this guide you will learn about:
Why remove these bullet points?
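As a concrete illustration of the reference-semantics paragraph in the diff above, here is a minimal hedged sketch (assuming the public `flax.nnx` API; `Block` and `Model` are illustrative names, not part of the guide):

```python
# Minimal sketch: NNX modules are regular Python objects (PyGraphs),
# so they support reference sharing and in-place mutation.
from flax import nnx

class Block(nnx.Module):
    def __init__(self, din, dout, *, rngs: nnx.Rngs):
        self.linear = nnx.Linear(din, dout, rngs=rngs)

class Model(nnx.Module):
    def __init__(self, *, rngs: nnx.Rngs):
        shared = Block(4, 4, rngs=rngs)
        self.encoder = shared   # both attributes reference the
        self.decoder = shared   # same Block object

model = Model(rngs=nnx.Rngs(0))
assert model.encoder is model.decoder       # reference sharing, not a copy
model.decoder.linear.bias.value += 1.0      # state can be mutated in place
```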
@@ -43,17 +28,9 @@ import jax.numpy as jnp

## The Flax NNX Module system

The main difference between the Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) and other `Module` systems in [Flax Linen](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/module.html) or [Haiku](https://dm-haiku.readthedocs.io/en/latest/notebooks/basics.html#Built-in-Haiku-nets-and-nested-modules) is that in NNX everything is **explicit**. This means, among other things, that:

1) The [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) itself holds the state (such as parameters) directly.
Why also remove bullet points here?
I've converted them to a single paragraph. I don't like excessive use of bullet points.
Is it OK if I add back the URLs to certain API ref docs like `nnx.Param`? Also, we are mixing JAX and Flax APIs here, so an external link to `jax.Array` can help less experienced users.
This is a Flax vs Haiku comparison, so for ease of reading and to aid "scannability", I recommend using a list (without sequential numbering). A wall of text here may be missed by some users.
https://developers.google.com/style/accessibility#ease-of-reading
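To make the "everything is explicit" point concrete for readers of this thread, here is a small hedged sketch of a module that holds its own parameters (assuming `nnx.Param`, `nnx.Rngs`, and standard JAX; `TinyLinear` is an illustrative name, not from the guide):

```python
# Illustrative only: the nnx.Module instance holds its parameters directly
# as attributes, rather than keeping them in an external params dict.
import jax
import jax.numpy as jnp
from flax import nnx

class TinyLinear(nnx.Module):
    def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
        key = rngs.params()
        self.w = nnx.Param(jax.random.normal(key, (din, dout)))
        self.b = nnx.Param(jnp.zeros((dout,)))

    def __call__(self, x):
        return x @ self.w + self.b

layer = TinyLinear(2, 3, rngs=nnx.Rngs(params=0))
y = layer(jnp.ones((1, 2)))   # parameters are read straight off the object
```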
docs_nnx/guides/performance.md (Outdated)
    loss = train_step(model, optimizer, metrics, x, y)
```

To speed it up we can use `nnx.split` before starting the training loop to create a `graphdef` and `state` pytrees for the Flax NNX objects as a group since `graphdef` and `state` are fast to traverse. Then at the beggining and end of a `jax.jit`-decorated function we can call `nnx.merge` and `nnx.split` to switch back and forth between the object and pytree representations. The important thing here is that `split` and `merge` will only run once during tracing.
beggining -> beginning
docs_nnx/guides/performance.md (Outdated)
I think we should add a line here to explicitly say that you need to run nnx.merge and nnx.split on all the NNX objects that are part of the train step inputs.
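Following up on that suggestion, here is a hedged sketch of the split/merge pattern applied to all NNX objects used as train-step inputs (the model, optimizer, loss function, and data below are stand-ins, not the guide's exact code; `nnx.Optimizer` and `optimizer.update` signatures may differ across Flax versions):

```python
import jax
import jax.numpy as jnp
import optax
from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))           # stand-in for the guide's model
optimizer = nnx.Optimizer(model, optax.adam(1e-3))    # another NNX object used in the step

# Split *all* NNX train-step inputs once, as a single group, before the loop.
graphdef, state = nnx.split((model, optimizer))

@jax.jit
def train_step(graphdef, state, x, y):
    # merge/split only run during tracing, not on every step
    model, optimizer = nnx.merge(graphdef, state)

    def loss_fn(model):
        return ((model(x) - y) ** 2).mean()

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(grads)

    _, state = nnx.split((model, optimizer))           # return the updated state pytree
    return loss, state

x, y = jnp.ones((8, 2)), jnp.ones((8, 3))
for _ in range(3):
    loss, state = train_step(graphdef, state, x, y)
```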
Force-pushed from fa4edf9 to 94793f9
What does this PR do?
- Adds a performance guide notebook discussing `nnx.jit`. Also reverts some formatting style changes made recently.
- `treescope` as a direct dependency of Flax.