Merge pull request #4384 from google:nnx-performance-guide
PiperOrigin-RevId: 700239850
Flax Authors committed Nov 26, 2024
2 parents be26138 + 94793f9 commit abc1155
Showing 9 changed files with 365 additions and 85 deletions.
3 changes: 1 addition & 2 deletions .gitignore
@@ -17,8 +17,7 @@ flaxlib_src/build
flaxlib_src/builddir
flaxlib_src/dist
flaxlib_src/subprojects
target/
flaxlib.cpython-*

# used by direnv
.envrc

Binary file added docs_nnx/guides/images/performance-graph.png
1 change: 1 addition & 0 deletions docs_nnx/guides/index.rst
@@ -8,6 +8,7 @@ Guides
flax_gspmd
filters_guide
randomness
performance
linen_to_nnx
bridge_guide
surgery
151 changes: 151 additions & 0 deletions docs_nnx/guides/performance.ipynb
@@ -0,0 +1,151 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Performance Considerations\n",
"Currently `nnx.jit` traverses the object graph in pure Python, this is slow and adds overhead. To solve this in general we will be developing a Rust extension called `flaxlib` (see first steps in #4196) to speedup some of the traversal logic in `graph.py`, similar to how JAX solved the same issue with `jaxlib` for standard pytrees. However, there's two things to consider:\n",
"\n",
"* The overhead is only relevant for small models. See [Asynchronous dispatch](#asynchronous-dispatch).\n",
"* You can remove the overhead by using `jax.jit` + `nnx.split` / `nnx.merge` to stage out the traversal logic. See [Lowering the Python Overhead](#lowering-the-python-overhead).\n",
"\n",
"\n",
"## Asynchronous dispatch\n",
"In [benchmarks/nnx_simple_training.py](https://github.com/google/flax/blob/main/benchmarks/nnx_simple_training.py) we are increasing the layer width (features per layer) and measuring the total training time for the same model trained both with `nnx.jit` and `jax.jit`. As you can see in the graph below, after a certain model size the time spent in the traversal is completely absorbed by async dispatch. This happens when Python is able to finish the current for loop step, and reach the next `train_step` and JAX is still not done with the previous `train_step`. \n",
"\n",
"![performance-graph](images/performance-graph.png)\n",
"\n",
"This means that you only need to worry about the `nnx.jit` overhead for small models. If you are working with a small model, check out the next section to see how you can remove the overhead.\n",
"\n",
"## Lowering the Python Overhead\n",
"To remove the python overhead you can use regular `jax.jit` in combination with `nnx.split` and `nnx.merge` to stage out the traversal logic. To learn how to do this, lets first create this simple model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from flax import nnx\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import optax\n",
"\n",
"class Model(nnx.Module):\n",
" def __init__(self, din, dmid, dout, rngs: nnx.Rngs):\n",
" self.linear = nnx.Linear(din, dmid, rngs=rngs)\n",
" self.bn = nnx.BatchNorm(dmid, rngs=rngs)\n",
" self.dropout = nnx.Dropout(0.2, rngs=rngs)\n",
" self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)\n",
"\n",
" def __call__(self, x):\n",
" x = nnx.relu(self.dropout(self.bn(self.linear(x))))\n",
" return self.linear_out(x)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Lets say we have this `train_step` function that is using `nnx.jit` and takes in a `model`, `optimizer`, and `metrics`, all of which are Flax NNX objects:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # eager initialization\n",
"optimizer = nnx.Optimizer(model, optax.adam(1e-3)) # reference sharing\n",
"metrics = nnx.MultiMetric(\n",
" loss=nnx.metrics.Average('loss'),\n",
")\n",
"\n",
"@nnx.jit # <== currently slow\n",
"def train_step(model, optimizer, metrics, x, y):\n",
" def loss_fn(model):\n",
" y_pred = model(x) # call methods directly\n",
" return ((y_pred - y) ** 2).mean()\n",
"\n",
" loss, grads = nnx.value_and_grad(loss_fn)(model)\n",
" optimizer.update(grads) # in-place updates\n",
" metrics.update(loss=loss)\n",
"\n",
" return loss\n",
" \n",
"for _ in range(10):\n",
" x, y = jnp.ones((32, 2)), jnp.zeros((32, 3))\n",
" loss = train_step(model, optimizer, metrics, x, y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To speed it up, before starting the training loop we can use `nnx.split` over the all the Flax NNX objects that are inputs to `train_step` to create a `graphdef` and `state` pytrees that are fast to traverse. Next we change `train_step` so accept `graphdef` and `state` and use `nnx.merge` and `nnx.split` at the beginning and end of `train_step` to switch back and forth between the objects and their pytree representations. Even though `nnx.split` and `nnx.merge` are slow it doesn't matter because they will only run once during tracing. With this in place, we can change the `train_step` function to use `jax.jit` instead of `nnx.jit`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # eager initialization\n",
"optimizer = nnx.Optimizer(model, optax.adamw(1e-3)) # reference sharing\n",
"metrics = nnx.MultiMetric(\n",
" loss=nnx.metrics.Average('loss'),\n",
")\n",
"# split before training loop\n",
"graphdef, state = nnx.split((model, optimizer, metrics))\n",
"\n",
"@jax.jit # regular JAX\n",
"def train_step(graphdef, state, x, y):\n",
" # merge at the beginning of the function\n",
" model, optimizer, metrics = nnx.merge(graphdef, state)\n",
"\n",
" def loss_fn(model):\n",
" y_pred = model(x) # call methods directly\n",
" return ((y_pred - y) ** 2).mean()\n",
"\n",
" loss, grads = nnx.value_and_grad(loss_fn)(model)\n",
" optimizer.update(grads)\n",
" metrics.update(loss=loss)\n",
"\n",
" # split at the end of the function\n",
" _, state = nnx.split((model, optimizer, metrics))\n",
"\n",
" # return new state\n",
" return state, loss\n",
"\n",
"for _ in range(10):\n",
" x, y = jnp.ones((32, 2)), jnp.zeros((32, 3))\n",
" state, loss = train_step(graphdef, state, x, y)\n",
"\n",
"# update objects after training\n",
"nnx.update((model, optimizer, metrics), state)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Notice that we only do this for `jit`, you can still use other transforms like `nnx.value_and_grad` shown in the example since their overhead is already absorbed by the outer `jit`. Also, after the training loop is done (or whenever need) `nnx.update` can be used to update Flax NNX objects like `model`, `optimizer`, and `metrics` to a new `state`."
]
}
],
"metadata": {
"jupytext": {
"formats": "ipynb,md:myst"
},
"language_info": {
"name": "python",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
110 changes: 110 additions & 0 deletions docs_nnx/guides/performance.md
@@ -0,0 +1,110 @@
---
jupytext:
formats: ipynb,md:myst
text_representation:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.13.8
---

# Performance Considerations
Currently `nnx.jit` traverses the object graph in pure Python, which is slow and adds overhead. To solve this in general, we will be developing a Rust extension called `flaxlib` (see the first steps in #4196) to speed up some of the traversal logic in `graph.py`, similar to how JAX solved the same issue with `jaxlib` for standard pytrees. However, there are two things to consider:

* The overhead is only relevant for small models. See [Asynchronous dispatch](#asynchronous-dispatch).
* You can remove the overhead by using `jax.jit` + `nnx.split` / `nnx.merge` to stage out the traversal logic. See [Lowering the Python Overhead](#lowering-the-python-overhead).


## Asynchronous dispatch
In [benchmarks/nnx_simple_training.py](https://github.com/google/flax/blob/main/benchmarks/nnx_simple_training.py) we increase the layer width (features per layer) and measure the total training time for the same model trained with both `nnx.jit` and `jax.jit`. As the graph below shows, after a certain model size the time spent in the traversal is completely absorbed by asynchronous dispatch: Python finishes the current loop iteration and reaches the next `train_step` call while JAX is still executing the previous `train_step` on the device.

![performance-graph](images/performance-graph.png)

This means that you only need to worry about the `nnx.jit` overhead for small models. If you are working with a small model, check out the next section to see how you can remove the overhead.
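
To see what this looks like in practice, here is a minimal, self-contained sketch (not part of the benchmark script above; the function and array sizes are made up for illustration) of how asynchronous dispatch hides work: the jitted call returns almost immediately, and only `block_until_ready` waits for the device to finish.

```{code-cell}
import time

import jax
import jax.numpy as jnp

@jax.jit
def step(x):
  return jnp.sin(x) @ x.T  # stand-in for a real train_step

x = jnp.ones((1024, 1024))
step(x).block_until_ready()  # compile and warm up first

t0 = time.perf_counter()
y = step(x)  # returns almost immediately: the work is dispatched asynchronously
dispatch_time = time.perf_counter() - t0

y.block_until_ready()  # block until the device has actually finished
total_time = time.perf_counter() - t0

print(f'dispatch: {dispatch_time * 1e3:.3f} ms, total: {total_time * 1e3:.3f} ms')
```

As long as the per-step Python overhead stays below the device time of the previous step, it is hidden in exactly this way.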

## Lowering the Python Overhead
To remove the Python overhead, you can use regular `jax.jit` in combination with `nnx.split` and `nnx.merge` to stage out the traversal logic. To learn how to do this, let's first create this simple model:

```{code-cell}
from flax import nnx
import jax
import jax.numpy as jnp
import optax

class Model(nnx.Module):
  def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
    self.linear = nnx.Linear(din, dmid, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.dropout = nnx.Dropout(0.2, rngs=rngs)
    self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x):
    x = nnx.relu(self.dropout(self.bn(self.linear(x))))
    return self.linear_out(x)
```

Let's say we have a `train_step` function that uses `nnx.jit` and takes in a `model`, `optimizer`, and `metrics`, all of which are Flax NNX objects:

```{code-cell}
model = Model(2, 64, 3, rngs=nnx.Rngs(0))  # eager initialization
optimizer = nnx.Optimizer(model, optax.adam(1e-3))  # reference sharing
metrics = nnx.MultiMetric(
  loss=nnx.metrics.Average('loss'),
)

@nnx.jit  # <== currently slow
def train_step(model, optimizer, metrics, x, y):
  def loss_fn(model):
    y_pred = model(x)  # call methods directly
    return ((y_pred - y) ** 2).mean()

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grads)  # in-place updates
  metrics.update(loss=loss)

  return loss

for _ in range(10):
  x, y = jnp.ones((32, 2)), jnp.zeros((32, 3))
  loss = train_step(model, optimizer, metrics, x, y)
```

To speed it up, before starting the training loop we can use `nnx.split` over all the Flax NNX objects that are inputs to `train_step` to create `graphdef` and `state` pytrees, which are fast to traverse. Next, we change `train_step` to accept `graphdef` and `state`, and use `nnx.merge` and `nnx.split` at the beginning and end of `train_step` to switch back and forth between the objects and their pytree representations. Even though `nnx.split` and `nnx.merge` are slow, it doesn't matter because they only run once during tracing. With this in place, we can change the `train_step` function to use `jax.jit` instead of `nnx.jit`:

```{code-cell}
model = Model(2, 64, 3, rngs=nnx.Rngs(0))  # eager initialization
optimizer = nnx.Optimizer(model, optax.adam(1e-3))  # reference sharing
metrics = nnx.MultiMetric(
  loss=nnx.metrics.Average('loss'),
)
# split before training loop
graphdef, state = nnx.split((model, optimizer, metrics))

@jax.jit  # regular JAX
def train_step(graphdef, state, x, y):
  # merge at the beginning of the function
  model, optimizer, metrics = nnx.merge(graphdef, state)

  def loss_fn(model):
    y_pred = model(x)  # call methods directly
    return ((y_pred - y) ** 2).mean()

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grads)
  metrics.update(loss=loss)

  # split at the end of the function
  _, state = nnx.split((model, optimizer, metrics))

  # return new state
  return state, loss

for _ in range(10):
  x, y = jnp.ones((32, 2)), jnp.zeros((32, 3))
  state, loss = train_step(graphdef, state, x, y)

# update objects after training
nnx.update((model, optimizer, metrics), state)
```

Notice that we only do this for `jit`; you can still use other transforms like `nnx.value_and_grad` shown in the example, since their overhead is already absorbed by the outer `jit`. Also, after the training loop is done (or whenever needed), `nnx.update` can be used to update Flax NNX objects like `model`, `optimizer`, and `metrics` with a new `state`.
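
For example, here is a hypothetical sketch (reusing `model`, `train_step`, `graphdef`, and `state` from above; the evaluation batch and interval are made up) of calling `nnx.update` mid-training to sync the live objects before running an eager evaluation pass:

```{code-cell}
x_eval, y_eval = jnp.ones((32, 2)), jnp.zeros((32, 3))  # stand-in eval batch

for step_idx in range(10):
  x, y = jnp.ones((32, 2)), jnp.zeros((32, 3))
  state, loss = train_step(graphdef, state, x, y)

  if (step_idx + 1) % 5 == 0:
    nnx.update((model, optimizer, metrics), state)  # sync objects with the latest state
    model.eval()   # deterministic mode: no dropout, use BatchNorm running stats
    eval_loss = ((model(x_eval) - y_eval) ** 2).mean()
    model.train()  # switch back to training mode
    print(f'step {step_idx + 1}: eval_loss={float(eval_loss):.4f}')
```

Flipping `model.eval()` / `model.train()` only affects the live objects; the jitted `train_step` keeps using the `graphdef` captured before the loop, so training behavior is unchanged.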