Merge pull request #4200 from 8bitmp3:upgrade-haiku-linen-to-nnx
PiperOrigin-RevId: 693412381
Flax Authors committed Nov 5, 2024
2 parents d8b1a92 + 84fa316 commit 0679702
Showing 1 changed file with 38 additions and 27 deletions.
docs_nnx/guides/haiku_to_flax.rst

Migrating from Haiku to Flax
============================

This guide demonstrates the differences between Haiku and Flax NNX models, providing side-by-side example code to help you migrate to the Flax NNX API from Haiku.

If you are new to Flax NNX, make sure you familiarize yourself with `Flax NNX basics <https://flax.readthedocs.io/en/latest/nnx_basics.html>`__, which covers the :class:`nnx.Module<flax.nnx.Module>` system, `Flax transformations <https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html>`__, and the `Functional API <https://flax.readthedocs.io/en/latest/nnx_basics.html#the-flax-functional-api>`__ with examples.

Let’s start with some imports.

.. testsetup:: Haiku, Flax NNX

  from typing import Any
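  # The remaining imports in this setup block are assumptions, inferred from
  # the code used throughout this guide rather than copied from it.
  import jax
  import jax.numpy as jnp
  import haiku
  import optax
  from flax import nnx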


Basic Module definition
=======================

Both Haiku and Flax use the ``Module`` class as the default unit for expressing a neural network layer. For example, to create a one-layer network with dropout and a ReLU activation function, you:

* First, create a ``Block`` (by subclassing ``Module``) composed of one linear layer with dropout and a ReLU activation function.
* Then, use ``Block`` as a sub-``Module`` when creating a ``Model`` (also by subclassing ``Module``), which is made up of ``Block`` and a linear layer.

There are two fundamental differences between Haiku and Flax ``Module`` objects:

* **Stateless vs. stateful**:

  * A ``haiku.Module`` instance is stateless: the variables are returned from a purely functional ``Module.init()`` call and managed separately.
  * A :class:`flax.nnx.Module`, however, owns its variables as attributes of this Python object.

* **Lazy vs. eager**:

  * A ``haiku.Module`` only allocates space for its variables the first time it sees an input, that is, when the user calls the model (lazy).
  * A ``flax.nnx.Module`` instance creates its variables the moment it is instantiated, before seeing a sample input (eager).
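
To make this concrete, below is a minimal sketch of both versions. The layer sizes (``din=784``, ``dmid=256``, ``dout=10``) and the argument names are illustrative, chosen to match the shapes used later in this guide.

.. code-block:: python

  # Haiku (a sketch): stateless and lazy - variables come into existence
  # only when the module is called under haiku.transform.
  class Block(haiku.Module):
    def __init__(self, features: int):
      super().__init__()
      self.features = features

    def __call__(self, x, training: bool):
      x = haiku.Linear(self.features)(x)  # parameters allocated on first call
      x = haiku.dropout(haiku.next_rng_key(), 0.5 if training else 0.0, x)
      return jax.nn.relu(x)

  class Model(haiku.Module):
    def __init__(self, dmid: int, dout: int):
      super().__init__()
      self.dmid = dmid
      self.dout = dout

    def __call__(self, x, training: bool):
      x = Block(self.dmid)(x, training)
      return haiku.Linear(self.dout)(x)

And the Flax NNX counterpart (also a sketch):

.. code-block:: python

  # Flax NNX (a sketch): stateful and eager - parameters are created here,
  # at construction time, and live on the Module as attributes.
  class Block(nnx.Module):
    def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
      self.linear = nnx.Linear(din, dout, rngs=rngs)
      self.dropout = nnx.Dropout(0.5, rngs=rngs)

    def __call__(self, x):
      return nnx.relu(self.dropout(self.linear(x)))

  class Model(nnx.Module):
    def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
      self.block = Block(din, dmid, rngs=rngs)
      self.linear = nnx.Linear(dmid, dout, rngs=rngs)

    def __call__(self, x):
      return self.linear(self.block(x))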




Variable creation
=================

This section covers instantiating the model and initializing its parameters.

* To generate model parameters for a Haiku model, you need to put it inside a forward function and use ``haiku.transform`` to make it purely functional. This results in a nested dictionary of `JAX Arrays <https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array>`__ (``jax.Array`` data types) to be carried around and maintained separately.

* In Flax NNX, the model parameters are automatically initialized when you instantiate the model, and the variables (:class:`nnx.Variable<flax.nnx.Variable>` objects) are stored inside the :class:`nnx.Module<flax.nnx.Module>` (or its sub-Module) as attributes. You still need to provide it with a `pseudorandom number generator (PRNG) <https://jax.readthedocs.io/en/latest/random-numbers.html>`__ key, but that key will be wrapped inside an :class:`nnx.Rngs<flax.nnx.Rngs>` object and stored internally, generating more PRNG keys when needed.

If you want to access Flax model parameters in the stateless, dictionary-like fashion for checkpoint saving or model surgery, check out the `Flax NNX split/merge API <https://flax.readthedocs.io/en/latest/nnx_basics.html#state-and-graphdef>`__ (:func:`nnx.split<flax.nnx.split>` / :func:`nnx.merge<flax.nnx.merge>`).
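
A minimal sketch of the two initialization styles, assuming the ``Block`` and ``Model`` definitions sketched above:

.. code-block:: python

  # Haiku: wrap the model in a forward function and transform it.
  # (`Model` here refers to the Haiku version sketched earlier.)
  def forward(x, training: bool):
    return Model(dmid=256, dout=10)(x, training)

  haiku_model = haiku.transform(forward)
  sample_x = jnp.ones((1, 784))
  params = haiku_model.init(jax.random.PRNGKey(0), sample_x, training=False)
  # `params` is a nested dict of jax.Array leaves, carried around separately.

  # Flax NNX: instantiation creates the parameters eagerly; the PRNG key is
  # wrapped in nnx.Rngs and stored on the model.
  # (`Model` here refers to the Flax NNX version sketched earlier.)
  model = Model(din=784, dmid=256, dout=10, rngs=nnx.Rngs(0))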

The eagerly created Flax NNX parameters can be inspected right away:

.. code-block:: python

  assert model.block.linear.kernel.value.shape == (784, 256)

Training step and compilation
=============================

This section covers writing a training step and compiling it using `JAX just-in-time compilation <https://jax.readthedocs.io/en/latest/jit-compilation.html>`__.

When compiling the training step:

* Haiku uses ``@jax.jit`` - a `JAX transformation <https://jax.readthedocs.io/en/latest/key-concepts.html#transformations>`__ - to compile a purely functional training step.
* Flax NNX uses :meth:`@nnx.jit<flax.nnx.jit>` - a `Flax NNX transformation <https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html>`__ (one of several transform APIs that behave similarly to JAX transforms, but also `work well with Flax objects <https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html>`__). While ``jax.jit`` only accepts functions with pure stateless arguments, ``flax.nnx.jit`` allows the arguments to be stateful Modules. This greatly reduces the number of lines needed for a train step.
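
For illustration, a sketch of a compiled training step on the Flax NNX side; the loss function and the plain-SGD update are placeholders rather than this guide's exact code:

.. code-block:: python

  @nnx.jit  # unlike jax.jit, the stateful `model` can be passed in directly
  def train_step(model, x, y):
    def loss_fn(model):
      return jnp.mean((model(x) - y) ** 2)  # placeholder loss

    grads = nnx.grad(loss_fn)(model)        # an nnx.State of gradients
    params = nnx.state(model, nnx.Param)
    new_params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)
    nnx.update(model, new_params)           # in-place update; nothing returned

  train_step(model, jnp.ones((1, 784)), jnp.ones((1, 10)))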

When taking gradients:

* Similarly, Haiku uses ``jax.grad`` (a JAX transformation for `automatic differentiation <https://jax.readthedocs.io/en/latest/automatic-differentiation.html#taking-gradients-with-jax-grad>`__) to return a raw dictionary of gradients.
* Meanwhile, Flax NNX uses :meth:`flax.nnx.grad<flax.nnx.grad>` (a Flax NNX transformation) to return the gradients of Flax NNX Modules as :class:`flax.nnx.State<flax.nnx.State>` dictionaries. If you want to use regular ``jax.grad`` with Flax NNX, you need to use the `split/merge API <https://flax.readthedocs.io/en/latest/nnx_basics.html#state-and-graphdef>`__, as sketched below.
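
A sketch of that split/merge route, if you prefer plain ``jax.grad`` (the loss is again a placeholder):

.. code-block:: python

  x = jnp.ones((1, 784))
  y = jnp.ones((1, 10))

  graphdef, state = nnx.split(model)      # stateless view: graph + State pytree

  def loss_fn(state):
    model = nnx.merge(graphdef, state)    # rebuild the stateful Module
    return jnp.mean((model(x) - y) ** 2)  # placeholder loss

  grads = jax.grad(loss_fn)(state)        # a plain JAX transform over a pytree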

For optimizers:

* If you are already using `Optax <https://optax.readthedocs.io/>`__ optimizers like ``optax.adamw`` (instead of the raw ``jax.tree.map`` computation shown here) with Haiku, check out the :class:`flax.nnx.Optimizer<flax.nnx.Optimizer>` example in the `Flax basics <https://flax.readthedocs.io/en/latest/nnx_basics.html#transforms>`__ guide for a much more concise way of training and updating your model.
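
A sketch of that pattern with :class:`nnx.Optimizer<flax.nnx.Optimizer>` (the hyperparameters and loss are illustrative):

.. code-block:: python

  optimizer = nnx.Optimizer(model, optax.adamw(1e-3))

  @nnx.jit
  def train_step(model, optimizer, x, y):
    def loss_fn(model):
      return jnp.mean((model(x) - y) ** 2)  # placeholder loss

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(grads)                 # Optax update applied in place
    return loss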

Model updates during each training step:

* The Haiku training step needs to return a `JAX pytree <https://jax.readthedocs.io/en/latest/working-with-pytrees.html>`__ of parameters as the input of the next step.
* The Flax NNX training step does not need to return anything, because the ``model`` was already updated in-place within :meth:`nnx.jit<flax.nnx.jit>`.
* In addition, :class:`nnx.Module<flax.nnx.Module>` objects are stateful, and ``Module`` automatically tracks several things within it, such as PRNG keys and ``flax.nnx.BatchNorm`` stats. That is why you don't need to explicitly pass a PRNG key in at every step. Also note that you can use :meth:`flax.nnx.reseed<flax.nnx.reseed>` to reset its underlying PRNG state.
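
For example, assuming the model's ``Rngs`` carries a ``dropout`` stream:

.. code-block:: python

  nnx.reseed(model, dropout=42)  # reset the `dropout` PRNG stream to seed 42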

The dropout behavior:

* In Haiku, you need to explicitly define and pass in the ``training`` argument to toggle ``haiku.dropout`` and make sure that random dropout only happens if ``training=True``.
* In Flax NNX, you can call ``model.train()`` (:meth:`flax.nnx.Module.train`) to automatically switch :class:`flax.nnx.Dropout<flax.nnx.Dropout>` to the training mode. Conversely, you can call ``model.eval()`` (:meth:`flax.nnx.Module.eval`) to turn off the training mode. You can learn more about what ``flax.nnx.Module.train`` does in its `API reference <https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module.train>`__.
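
For example:

.. code-block:: python

  x = jnp.ones((1, 784))

  model.train()          # nnx.Dropout now applies random masking
  logits = model(x)

  model.eval()           # deterministic: dropout is a no-op
  preds = model(x)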



Handling non-parameter states
=============================

Haiku makes a distinction between trainable parameters and all other data ("states") that the model tracks. For example, the batch statistics used in batch norm are considered a state. Models with states need to be transformed with ``haiku.transform_with_state`` so that their ``.init()`` returns both params and states.
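
A minimal Haiku sketch (the network shape and hyperparameters are illustrative):

.. code-block:: python

  def forward(x, training: bool):
    x = haiku.Linear(256)(x)
    x = haiku.BatchNorm(create_scale=True, create_offset=True,
                        decay_rate=0.99)(x, is_training=training)
    return haiku.Linear(10)(x)

  model = haiku.transform_with_state(forward)
  sample_x = jnp.ones((1, 784))
  params, state = model.init(jax.random.PRNGKey(0), sample_x, training=True)
  out, state = model.apply(params, state, jax.random.PRNGKey(1), sample_x,
                           training=True)  # batch stats come back as new state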

…be set and accessed as normal using regular Python class semantics.

.. code-block:: python

  _, params, counter = nnx.split(model, nnx.Param, Counter)
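
For context, ``Counter`` above is a custom variable type; in Flax NNX you typically get one by subclassing :class:`nnx.Variable<flax.nnx.Variable>`. A minimal sketch (the module and its names are illustrative):

.. code-block:: python

  class Counter(nnx.Variable):  # a custom variable collection
    pass

  class CountingModel(nnx.Module):
    def __init__(self, *, rngs: nnx.Rngs):
      self.linear = nnx.Linear(2, 2, rngs=rngs)
      self.count = Counter(jnp.array(0))

    def __call__(self, x):
      self.count.value += 1     # regular Python attribute access
      return self.linear(x)

  model = CountingModel(rngs=nnx.Rngs(0))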



