
[nnx] add performance guide notebook #4384

Merged — 1 commit, Nov 26, 2024
Conversation

@cgarciae (Collaborator) commented Nov 15, 2024

What does this PR do?

  • Adds the Performance Considerations guide.
  • Updates the Flax Basics guide with a note on `nnx.jit`, and reverts some recent formatting changes.
  • Adds treescope as a direct dependency of Flax.


@cgarciae cgarciae force-pushed the nnx-performance-guide branch 2 times, most recently from c99c64d to dbc5667 Compare November 19, 2024 14:01
@cgarciae cgarciae force-pushed the nnx-performance-guide branch 5 times, most recently from 656b7bc to fa4edf9 Compare November 25, 2024 14:02
@@ -12,27 +12,12 @@ jupytext:

Flax NNX is a new simplified API that is designed to make it easier to create, inspect, debug, and analyze neural networks in [JAX](https://jax.readthedocs.io/). It achieves this by adding first-class support for Python reference semantics. This allows users to express their models using regular Python objects, which are modeled as PyGraphs (instead of pytrees), enabling reference sharing and mutability. Such API design should make PyTorch or Keras users feel at home.

In this guide you will learn about:
Collaborator:
Why remove these bullet points?

@8bitmp3 (Collaborator) commented Nov 26, 2024:

This guide is an introduction to a lot of new concepts and covers a lot of material. I will add a table of contents to help new users. @IvyZX @cgarciae

@@ -43,17 +28,9 @@ import jax.numpy as jnp

## The Flax NNX Module system

The main difference between the Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) and other `Module` systems in [Flax Linen](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/module.html) or [Haiku](https://dm-haiku.readthedocs.io/en/latest/notebooks/basics.html#Built-in-Haiku-nets-and-nested-modules) is that in NNX everything is **explicit**. This means, among other things, that:

1) The [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) itself holds the state (such as parameters) directly.
Collaborator:

Why also remove bullet points here?

Collaborator (author):

I've converted them to a single paragraph. I don't like excessive use of bullet points.

Collaborator:

Is it OK if I add back the URLs to certain API ref docs like `nnx.Param`? Also, we are mixing JAX and Flax APIs here, so an external link to `jax.Array` can help less experienced users.

Collaborator:

This is a Flax vs Haiku comparison, so for ease of reading and to aid "scannability", I recommend using a list (without sequential numbering). A wall of text here may be missed by some users.

https://developers.google.com/style/accessibility#ease-of-reading

loss = train_step(model, optimizer, metrics, x, y)
```

To speed it up, we can use `nnx.split` before starting the training loop to create `graphdef` and `state` pytrees for the Flax NNX objects as a group, since `graphdef` and `state` are fast to traverse. Then, at the beginning and end of a `jax.jit`-decorated function, we can call `nnx.merge` and `nnx.split` to switch back and forth between the object and pytree representations. The important thing here is that `split` and `merge` will only run once during tracing.
Collaborator:
beggining -> beginning

Collaborator:
I think we should add a line here to explicitly say that you need to run nnx.merge and nnx.split on all the NNX objects that are part of the train step inputs.

@cgarciae cgarciae force-pushed the nnx-performance-guide branch from fa4edf9 to 94793f9 Compare November 25, 2024 21:19
@copybara-service copybara-service bot merged commit abc1155 into main Nov 26, 2024
19 checks passed
@copybara-service copybara-service bot deleted the nnx-performance-guide branch November 26, 2024 08:30
3 participants