Merge pull request #4262 from 8bitmp3:fixes-nnx-1
PiperOrigin-RevId: 684237478
Flax Authors committed Oct 10, 2024
2 parents 8a4d6d6 + 50df809 commit 53ba626
Showing 1 changed file with 100 additions and 82 deletions: docs_nnx/why.rst
Why Flax NNX?
=============

In 2020, the Flax team released the Flax Linen API to support modeling research on JAX, with a focus on scaling
and performance. We have learned a lot from users since then. The team introduced certain ideas that have proven to be beneficial to users, such as:

* Organizing variables into `collections <https://flax.readthedocs.io/en/latest/glossary.html#term-Variable-collections>`_.
* Automatic and efficient `pseudorandom number generator (PRNG) management <https://flax.readthedocs.io/en/latest/glossary.html#term-RNG-sequences>`_.
* `Variable metadata <https://flax.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.with_partitioning>`_
for `Single Program Multi Data (SPMD) <https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD>`_ annotations, optimizer metadata, and other use cases.

One of the choices the Flax team made was to use functional (``compact``) semantics for neural network programming via lazy initialization of parameters.
This made for concise implementation code and aligned the Flax Linen API with Haiku.

However, this also meant that the semantics of Modules and variables in Flax were non-Pythonic and often surprising. It also led to implementation
complexity and obscured the core ideas of `transformations (transforms) <https://jax.readthedocs.io/en/latest/glossary.html#term-transformation>`_ on neural networks.
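
For context, here is a minimal sketch of the Linen ``compact`` style described above (the ``MLP`` below is illustrative and not part of the original page):

.. code-block:: python

   import jax
   import jax.numpy as jnp
   import flax.linen as nn

   class MLP(nn.Module):
     features: int

     @nn.compact
     def __call__(self, x):
       # Parameters are created lazily on first use; input sizes are inferred from `x`.
       x = nn.Dense(self.features)(x)
       return nn.relu(x)

   module = MLP(features=4)
   # Nothing concrete exists at construction time; variables only appear once `init` (or `apply`) runs.
   variables = module.init(jax.random.key(0), jnp.ones((1, 8)))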

.. testsetup:: Linen, NNX


Introducing Flax NNX
--------------------

Fast-forward to 2024: the Flax team developed Flax NNX, an attempt to retain the features that made Flax Linen useful while introducing some new principles.
The central idea behind Flax NNX is to introduce reference semantics into JAX. The following are its main features:

- **NNX is Pythonic**: Regular Python semantics for Modules, including support for mutability and shared references.
- **NNX is simple**: Many of the complex APIs in Flax Linen are either simplified using Python idioms or completely removed.
- **Better JAX integration**: Custom NNX transforms adopt the same APIs as JAX transforms, and with NNX
it is easier to use `JAX transforms (higher-order functions) <https://jax.readthedocs.io/en/latest/key-concepts.html#transformations>`_ directly.

Here is an example of a simple Flax NNX program that illustrates many of the points from above:

.. testcode:: NNX

x = nnx.relu(self.dropout(self.bn(self.linear(x))))
return self.linear_out(x)

model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # Eager initialization
optimizer = nnx.Optimizer(model, optax.adam(1e-3)) # Reference sharing.

@nnx.jit # Automatic state management for JAX transforms.
def train_step(model, optimizer, x, y):
def loss_fn(model):
y_pred = model(x) # call methods directly

return loss

Flax NNX's improvements on Linen
--------------------------------

The rest of this document uses various examples that demonstrate how Flax NNX improves on Flax Linen.

Inspection
^^^^^^^^^^

The first improvement is that Flax NNX Modules are regular Python objects. This means that you can easily
construct and inspect ``Module`` objects.

On the other hand, Flax Linen Modules are not easy to inspect and debug because they are lazy, which means some attributes are not available upon construction and are only accessible at runtime.

.. codediff::
:title: Linen, NNX
# ),
# ...

Notice that in the Flax NNX example above, there is no shape inference, so both the input and output shapes must be provided
to the ``nnx.Linear`` layer. This is a tradeoff that allows for more explicit and predictable behavior.
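
As a small sketch of eager construction and inspection (the ``kernel``/``bias`` attribute names are assumed to follow the usual Flax NNX convention):

.. code-block:: python

   from flax import nnx

   # Both sizes are given explicitly, so the parameters exist as soon as the object is created.
   layer = nnx.Linear(in_features=2, out_features=3, rngs=nnx.Rngs(0))
   print(layer)                     # A regular Python object that can be printed and inspected.
   print(layer.kernel.value.shape)  # (2, 3)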

Running computation
^^^^^^^^^^^^^^^^^^^

In Flax Linen, all top-level computation must be done through the ``flax.linen.Module.init`` or ``flax.linen.Module.apply`` methods, and the
parameters or any other type of state are handled as a separate structure. This creates an asymmetry between: 1) code that runs inside
``apply`` that can run methods and other ``Module`` objects directly; and 2) code that runs outside of ``apply`` that must use the ``apply`` method.

In Flax NNX, there's no special context because parameters are held as attributes and methods can be called directly. That means your NNX Module's ``__init__`` and ``__call__`` methods are not treated differently from other class methods, whereas Flax Linen Module's ``setup()`` and ``__call__`` methods are special.
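
To make this asymmetry concrete, here is a minimal sketch (the layers and shapes are illustrative, not from the original page):

.. code-block:: python

   import jax
   import jax.numpy as jnp
   import flax.linen as nn
   from flax import nnx

   x = jnp.ones((1, 2))

   # Flax Linen: top-level calls must go through `apply`, with the variables passed in explicitly.
   linen_layer = nn.Dense(features=3)
   variables = linen_layer.init(jax.random.key(0), x)
   y_linen = linen_layer.apply(variables, x)

   # Flax NNX: parameters live on the object, so the layer is called like any other Python object.
   nnx_layer = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
   y_nnx = nnx_layer(x)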

.. codediff::
:title: Linen, NNX
z = model.encode(x)
y = model.decoder(z)

In Flax Linen, calling sub-Modules directly is not possible because they are not initialized.
Therefore, you must construct a new instance and provide a proper parameter structure.

But in Flax NNX you can call sub-Modules directly without any issues.

State handling
^^^^^^^^^^^^^^

One of the areas where Flax Linen is notoriously complex is in state handling. When you use either a
``Dropout`` layer, a ``BatchNorm`` layer, or both, you suddenly have to handle the new state and use it to
configure the ``flax.linen.Module.apply`` method.

In Flax NNX, state is kept inside an ``nnx.Module`` and is mutable, which means it can just be called directly.
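
As a quick illustration before the side-by-side comparison, here is a minimal sketch that reuses the ``Model`` class from the first example on this page (the ``model.eval()`` call and its effect on ``Dropout``/``BatchNorm`` behavior are assumptions of this sketch):

.. code-block:: python

   import jax.numpy as jnp
   from flax import nnx

   model = Model(2, 64, 3, rngs=nnx.Rngs(0))  # The Model class defined earlier.
   x = jnp.ones((4, 2))

   y = model(x)  # Dropout draws from the stored RNG state and BatchNorm updates its
                 # running statistics in place; no separate state pytree is threaded through.

   model.eval()  # Switch Dropout/BatchNorm to inference behavior.
   y_eval = model(x)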

.. codediff::
:title: Linen, NNX

...

The main benefit of Flax NNX's state handling is that you don't have to change the training code when you add a new stateful layer.

In addition, layers that handle state are very easy to implement in Flax NNX. Below
is a simplified version of a ``BatchNorm`` layer that updates the mean and variance every time it is called.

.. testcode:: NNX

self.bias = nnx.Param(jax.numpy.zeros((features,)))
self.mean = nnx.BatchStat(jax.numpy.zeros((features,)))
self.var = nnx.BatchStat(jax.numpy.ones((features,)))
self.mu = mu # Static

def __call__(self, x):
mean = jax.numpy.mean(x, axis=-1)
return x * self.scale + self.bias


Model surgery
^^^^^^^^^^^^^

In Flax Linen, model surgery has historically been challenging because of two reasons:

1. Due to lazy initialization, it is not guaranteed that you can replace a sub-``Module`` with a new one.
2. The parameter structure is separated from the ``flax.linen.Module`` structure, which means you have to manually keep them in sync.

In Flax NNX, you can replace sub-Modules directly, following regular Python semantics. Since parameters are
part of the ``nnx.Module`` structure, they are never out of sync. Below is an example of how you can
implement a LoRA layer, and then use it to replace a ``Linear`` layer in an existing model.
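
Before the side-by-side comparison below, here is a rough sketch of what such a layer could look like in Flax NNX (the ``LoraLinear`` name follows the text further down this page, but its signature and initialization here are assumptions, not the page's exact code):

.. code-block:: python

   import jax
   import jax.numpy as jnp
   from flax import nnx

   class LoraLinear(nnx.Module):
     def __init__(self, in_features: int, out_features: int, rank: int, rngs: nnx.Rngs):
       self.base = nnx.Linear(in_features, out_features, rngs=rngs)
       self.A = nnx.Param(jax.random.normal(rngs.params(), (in_features, rank)) * 0.01)
       self.B = nnx.Param(jnp.zeros((rank, out_features)))

     def __call__(self, x):
       return self.base(x) + x @ self.A.value @ self.B.value

   # Replacing a sub-Module is plain attribute assignment; `linear` is the sub-layer
   # name used by the Model class from the first example on this page.
   model = Model(2, 64, 3, rngs=nnx.Rngs(0))
   model.linear = LoraLinear(2, 64, rank=4, rngs=nnx.Rngs(1))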

.. codediff::
:title: Linen, NNX

...

As shown above, in Flax Linen this doesn't really work in this case because the ``linear`` sub-``Module``
is not available. However, the rest of the code provides an idea of how the ``params`` structure must be manually updated.

Performing arbitrary model surgery is not easy in Flax Linen, and currently the
`intercept_methods <https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.intercept_methods>`_
API is the only way to do generic patching of methods. But this API is not very ergonomic.

In Flax NNX, you can use ``nnx.iter_graph`` to do generic model surgery, which is much simpler than in Flax Linen. Below is an example of replacing all ``nnx.Linear`` layers in a model with custom-made ``LoraLinear`` NNX layers.

.. testcode:: NNX

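   # NOTE: The original example code is collapsed in this diff view. The following is a
   # rough reconstruction sketch, not the page's exact code. It assumes `nnx.iter_graph`
   # yields (path, node) pairs over the model graph and reuses the hypothetical
   # `LoraLinear` layer sketched earlier on this page.
   rngs = nnx.Rngs(0)
   model = Model(2, 64, 3, rngs=rngs)

   for path, module in nnx.iter_graph(model):
     if isinstance(module, nnx.Module):
       for name, attr in vars(module).items():
         if isinstance(attr, nnx.Linear):
           # Swap every nnx.Linear for a LoRA-augmented layer with the same sizes.
           setattr(module, name, LoraLinear(attr.in_features, attr.out_features, rank=4, rngs=rngs))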

Transforms
^^^^^^^^^^

Flax Linen transforms are very powerful in that they enable fine-grained control over the model's state.
However, Flax Linen transforms have drawbacks, such as:

1. They expose additional APIs that are not part of JAX, making their behavior confusing and sometimes divergent from their JAX counterparts. This also constrains how you can interact with `JAX transforms <https://jax.readthedocs.io/en/latest/key-concepts.html#transformations>`_ and keep up with JAX API changes.
2. They work on functions with very specific signatures, namely:
- A ``flax.linen.Module`` must be the first argument.
- They accept other ``Module`` objects as arguments but not as return values.
3. They can only be used inside ``flax.linen.Module.apply``.

On the other hand, `Flax NNX transforms <https://flax.readthedocs.io/en/latest/guides/transforms.html>`_
are intended to be equivalent to their corresponding `JAX transforms <https://jax.readthedocs.io/en/latest/key-concepts.html#transformations>`_
with one exception: they can be used on Flax NNX Modules. This means that Flax NNX transforms:

1) Have the same API as JAX transforms.
2) Can accept Flax NNX Modules in any argument position, and ``nnx.Module`` objects can be returned from them.
3) Can be used anywhere, including the training loop.

Below is an example of using ``vmap`` with Flax NNX to both create a stack of weights by transforming the
``create_weights`` function, which returns some ``Weights``, and to apply that stack of weights to a batch
of inputs individually by transforming the ``vector_dot`` function, which takes ``Weights`` as the first
argument and a batch of inputs as the second argument.

.. testcode:: NNX
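
# NOTE: The definitions that belong here are collapsed in this diff view. The following is a
# rough reconstruction consistent with the surrounding prose, not the page's exact code:
# `Weights` holds a kernel/bias pair, `create_weights` builds one set of weights from a seed,
# and `vector_dot` applies one set of weights to a single input vector.
class Weights(nnx.Module):
  def __init__(self, kernel: jax.Array, bias: jax.Array):
    self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)

def create_weights(seed: jax.Array):
  return Weights(
    kernel=random.uniform(random.key(seed), (2, 3)),
    bias=jnp.zeros((3,)),
  )

def vector_dot(weights: Weights, x: jax.Array):
  return x @ weights.kernel.value + weights.bias.value

# Create a stack of 10 weight sets by vmapping `create_weights` over 10 seeds.
weights = nnx.vmap(create_weights, in_axes=0, out_axes=0)(jnp.arange(10))
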
x = jax.random.normal(random.key(1), (10, 2))
y = nnx.vmap(vector_dot, in_axes=(0, 0), out_axes=1)(weights, x)

Contrary to Flax Linen transforms, the ``in_axes`` argument and other APIs do affect how the ``nnx.Module`` state is transformed.

In addition, Flax NNX transforms can be used as method decorators, because ``nnx.Module`` methods are simply
functions that take a ``Module`` as the first argument. This means that the previous example can be
rewritten as follows:

.. testcode:: NNX
weights = WeightStack(jnp.arange(10))

x = jax.random.normal(random.key(1), (10, 2))
y = weights(x)

