From 899a8e7c82f2c39bb9c339449c36f7c7734dd9c5 Mon Sep 17 00:00:00 2001 From: 8bitmp3 <19637339+8bitmp3@users.noreply.github.com> Date: Sun, 6 Oct 2024 20:43:32 +0000 Subject: [PATCH] Upgrade Migrating to Flax NNX guide --- docs_nnx/guides/haiku_to_flax.rst | 343 +++++++++++------------------- 1 file changed, 120 insertions(+), 223 deletions(-) diff --git a/docs_nnx/guides/haiku_to_flax.rst b/docs_nnx/guides/haiku_to_flax.rst index 16dcd86570..73feea9e33 100644 --- a/docs_nnx/guides/haiku_to_flax.rst +++ b/docs_nnx/guides/haiku_to_flax.rst @@ -1,11 +1,17 @@ - Migrating from Haiku to Flax -================================= +============================ + +This guide compares and contrasts Haiku with Flax Linen and Flax NNX, which should help you migrate to Flax from Haiku. You will learn the differences between the frameworks through various examples, such as: +- Simple model creation with ``Module`` and dropout, model instantiation and parameter initialization, and setting up the ``train_step`` for training. +- Handling mutable states (using ``BatchNorm`` instead of dropout from model creation to training). +- Using multiple methods (using an auto-encoder model). +- Lifted transformations (using ``scan`` and a recurrent neural network). +- ``scan`` over layers. +- Top-level Haiku functions vs top-level Flax ``Module``s. + +Both Haiku and Flax Linen enforce a functional paradigm with stateless ``Module``s, while Flax NNX embraces the Python language to provide a more intuitive development experience. -This guide will showcase the differences between Haiku, Flax Linen and Flax NNX. -Both Haiku and Linen enforce a functional paradigm with stateless modules, -while NNX is a new, next-generation API that embraces the python language to -provide a more intuitive development experience. +First, some necessary imports: .. testsetup:: Haiku, Linen, NNX @@ -21,18 +27,18 @@ provide a more intuitive development experience. # TODO: make sure code lines are not too long # TODO: make sure all code diffs are aligned -Basic Example ------------------ +Basic example ------------- -To create custom Modules you subclass from a ``Module`` base class in -both Haiku and Flax. Modules can be defined inline in Haiku and Flax -Linen (using the ``@nn.compact`` decorator), whereas modules can't be -defined inline in NNX and must be defined in ``__init__``. +Let’s begin with a basic example: a small network with one hidden layer, dropout, and a ReLU activation. In Haiku and Flax (Linen and NNX), you create a custom ``Module`` by subclassing the ``Module`` base class. -Linen requires a ``deterministic`` argument to control whether or -not dropout is used. NNX also uses a ``deterministic`` argument -but the value can be set later using ``.eval()`` and ``.train()`` methods -that will be shown in a later code snippet. +Note that: +- In Haiku and Flax Linen, ``Module``s can be defined inline (using the ``@nn.compact`` decorator). +- In Flax NNX, ``nnx.Module``s can't be defined inline and instead must be defined in ``__init__``. + +In addition: +- Flax Linen requires a ``deterministic`` argument to control whether or not dropout is used. +- Flax NNX also uses a ``deterministic`` argument, but the value can be set later using the ``.eval()`` and ``.train()`` methods that will be shown in a later code snippet. .. codediff:: :title: Haiku, Linen, NNX @@ -114,11 +120,9 @@ that will be shown in a later code snippet.
x = self.linear(x) return x -Since modules are defined inline in Haiku and Linen, the parameters -are lazily initialized, by inferring the shape of a sample input. In Flax -NNX, the module is stateful and is initialized eagerly. This means that the -input shape must be explicitly passed during module instantiation since there -is no shape inference in NNX. +Initializing the parameters: +- In Haiku and Flax Linen, since ``Module``s are defined inline, the parameters are lazily initialized by inferring the shape of a sample input. +- In Flax NNX, the ``nnx.Module`` is stateful and is initialized eagerly. This means that the input shape must be explicitly passed during ``nnx.Module`` instantiation since there is no shape inference in NNX. .. codediff:: :title: Haiku, Linen, NNX @@ -143,18 +147,13 @@ is no shape inference in NNX. model = Model(784, 256, 10, rngs=nnx.Rngs(0)) -To get the model parameters in both Haiku and Linen, you use the ``init`` method -with a ``random.key`` plus some inputs to run the model. - -In NNX, the model parameters are automatically initialized when the user -instantiates the model because the input shapes are already explicitly passed at -instantiation time. +To get the model parameters: +- In both Haiku and Flax Linen, use the ``init`` method with a ``random.key`` plus some inputs to run the model. +- In Flax NNX, the model parameters are automatically initialized when you instantiate the model because the input shapes are already explicitly passed at instantiation time. -Since NNX is eager and the module is bound upon instantiation, the user can access -the parameters (and other fields defined in ``__init__`` via dot-access). On the other -hand, Haiku and Linen use lazy initialization and so the parameters can only be accessed -once the module is initialized with a sample input and both frameworks do not support -dot-access of their attributes. +Also: +- Since Flax NNX is eager and the ``Module`` is bound upon instantiation, you can access the parameters (and other fields defined in ``__init__`` via dot-access). +- On the other hand, Haiku and Flax Linen use lazy initialization. Therefore the parameters can only be accessed once the ``Module`` is initialized with a sample input, and both frameworks do not support dot-access of their attributes. .. codediff:: :title: Haiku, Linen, NNX @@ -188,21 +187,14 @@ dot-access of their attributes. - # parameters were already initialized during model instantiation + # Parameters were already initialized during model instantiation. assert model.linear.bias.value.shape == (10,) assert model.block.linear.kernel.value.shape == (784, 256) -Let's take a look at the parameter structure. In Haiku and Linen, we can -simply inspect the ``params`` object returned from ``.init()``. - -To see the parameter structure in NNX, the user can call ``nnx.split`` to -generate ``Graphdef`` and ``State`` objects. The ``Graphdef`` is a static pytree -denoting the structure of the model (for example usages, see -`NNX Basics `__). -``State`` objects contains all the module variables (i.e. any class that sub-classes -``nnx.Variable``). If we filter for ``nnx.Param``, we will generate a ``State`` object -of all the learnable module parameters. +Let's review the parameter structure: +- In Haiku and Flax Linen, you would simply inspect the ``params`` object returned from ``.init()``. +- In Flax NNX, to view the parameter structure, you can call ``nnx.split`` to generate ``nnx.Graphdef`` and ``nnx.State`` objects. 
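For example, here is a minimal sketch of that split, using the ``model`` instantiated above (the variable names are illustrative, and the trailing ``...`` filter simply collects every remaining variable, such as the dropout RNG state):

.. code-block:: python

  # Split the model into its static structure and its variable states.
  graphdef, params, rest = nnx.split(model, nnx.Param, ...)

  # `params` holds only the `nnx.Param` leaves, while `rest` holds the remaining
  # variables. The pieces can be recombined into an equivalent model with `nnx.merge`.
  restored_model = nnx.merge(graphdef, params, rest)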
The ``nnx.Graphdef`` is a static pytree denoting the structure of the model (for examples, check out the `Flax basics`__). ``nnx.State`` objects contain all ``Module`` variables (i.e. any class that subclasses ``nnx.Variable``). If you filter for ``nnx.Param``, you will generate a ``nnx.State`` object of all the learnable ``Module`` parameters (you can learn more in `Using Filters`__). .. tab-set:: @@ -271,22 +263,13 @@ of all the learnable module parameters. } }) -During training in Haiku and Linen, you pass the parameters structure to the -``apply`` method to run the forward pass. To use dropout, we must pass in -``training=True`` and provide a ``key`` to ``apply`` in order to generate the -random dropout masks. To use dropout in NNX, we first call ``model.train()``, -which will set the dropout layer's ``deterministic`` attribute to ``False`` -(conversely, calling ``model.eval()`` would set ``deterministic`` to ``True``). -Since the stateful NNX module already contains both the parameters and RNG key -(used for dropout), we simply need to call the module to run the forward pass. We -use ``nnx.split`` to extract the learnable parameters (all learnable parameters -subclass the NNX class ``nnx.Param``) and then apply the gradients and statefully -update the model using ``nnx.update``. -To compile ``train_step``, we decorate the function using ``@jax.jit`` for Haiku -and Linen, and ``@nnx.jit`` for NNX. Similar to ``@jax.jit``, ``@nnx.jit`` also -compiles functions, with the additional feature of allowing the user to compile -functions that take in NNX modules as arguments. +During training: +- In Haiku and Flax Linen, you pass the parameters structure to the ``apply`` method to run the forward pass. To use dropout, you must pass in ``training=True`` and provide a ``key`` to ``apply`` to generate the random dropout masks. +- In Flax NNX, to use dropout you first call ``model.train()``, which will set the dropout layer's ``deterministic`` attribute to ``False`` (conversely, calling ``model.eval()`` would set ``deterministic`` to ``True``). Since the stateful ``nnx.Module`` already contains both the parameters and the PRNG key (used for dropout), you simply need to call the ``nnx.Module`` to run the forward pass. Use ``nnx.split`` to extract the learnable parameters (all learnable parameters subclass the ``nnx.Param`` class), and then apply the gradients and statefully update the model using ``nnx.update``. + +To compile the ``train_step``: +- In Haiku and Flax Linen, you decorate the function using ``@jax.jit``. +- In Flax NNX, you decorate it with ``@nnx.jit``. Similar to ``@jax.jit``, ``@nnx.jit`` also compiles functions, with the additional feature of allowing you to compile functions that take NNX modules as arguments. .. codediff:: :title: Haiku, Linen, NNX @@ -347,7 +330,7 @@ functions that take in NNX modules as arguments. return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean() grads = nnx.grad(loss_fn)(model) - # we can use Ellipsis to filter out the rest of the variables + # You can use Ellipsis to filter out the rest of the variables. _, params, _ = nnx.split(model, nnx.Param, ...) params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads) @@ -364,23 +347,13 @@ functions that take in NNX modules as arguments.
sample_x = jnp.ones((1, 784)) train_step(model, sample_x, jnp.ones((1,), dtype=jnp.int32)) -Flax also offers a convenient ``TrainState`` dataclass to bundle the model, -parameters and optimizer, to simplify training and updating the model. In Haiku -and Linen, we simply pass in the ``model.apply`` function, initialized parameters -and optimizer as arguments to the ``TrainState`` constructor. -In NNX, we must first call ``nnx.split`` on the model to get the -separated ``GraphDef`` and ``State`` objects. We can pass in ``nnx.Param`` to filter -all trainable parameters into a single ``State``, and pass in ``...`` for the remaining -variables. We also need to subclass ``TrainState`` to add a field for the other variables. -We can then pass in ``GraphDef.apply`` as the apply function, ``State`` as the parameters -and other variables and an optimizer as arguments to the ``TrainState`` constructor. -One thing to note is that ``GraphDef.apply`` will take in ``State``'s as arguments and -return a callable function. This function can be called on the inputs to output the -model's logits, as well as updated ``GraphDef`` and ``State`` objects. This isn't needed -for our current example with dropout, but in the next section, you will see that using -these updated objects are relevant with layers like batch norm. Notice we also use -``@jax.jit`` since we aren't passing in NNX modules into ``train_step``. +In addition, Flax NNX offers a convenient ``nnx.TrainState`` dataclass to bundle the model, parameters, and optimizer to simplify training and updating the model (you can learn more in `Using TrainState in NNX`__). + +- In Haiku and Flax Linen, you simply pass in the ``model.apply`` function, initialized parameters and optimizer as arguments to the ``TrainState`` constructor. +- In Flax NNX, you must first call ``nnx.split`` on the model to get the separated ``nnx.GraphDef`` and ``nnx.State`` objects. You can pass in ``nnx.Param`` to filter +all trainable parameters into a single ``nnx.State``, and pass in ``...`` for the remaining variables. You also need to subclass ``nnx.TrainState`` to add a field for the other variables. Then, you can pass in ``nnx.GraphDef.apply`` as the apply function, ``nnx.State`` as the parameters and other variables, and an optimizer as arguments to the ``nnx.TrainState`` constructor. + +**Note:** ``nnx.GraphDef.apply`` will take in ``nnx.State``s as arguments and return a callable function. This function can be called on the inputs to output the model's logits, as well as updated ``nnx.GraphDef`` and ``nnx.State`` objects. This isn't needed for our current example with dropout, but in the next section, you will see that these updated objects are relevant with layers like batch normalization. Notice the use of ``@jax.jit`` since you aren't passing NNX modules into ``train_step``. .. codediff:: :title: Haiku, Linen, NNX @@ -496,11 +469,10 @@ these updated objects are relevant with layers like batch norm. Notice we also u train_step(state, sample_x, jnp.ones((1,), dtype=jnp.int32)) -Handling State ------------------ +Handling ``State`` +------------------ -Now let's see how mutable state is handled in all three frameworks. We will take -the same model as before, but now we will replace Dropout with BatchNorm. +Now let's review how mutable state is handled in all three frameworks. You will use the same model as before, but this time you will replace ``Dropout`` with ``BatchNorm``: ..
codediff:: :title: Haiku, Linen, NNX @@ -555,14 +527,9 @@ the same model as before, but now we will replace Dropout with BatchNorm. x = jax.nn.relu(x) return x -Haiku requires an ``is_training`` argument and Linen requires a -``use_running_average`` argument to control whether or not to update the -running statistics. NNX also uses a ``use_running_average`` argument -but the value can be set later using ``.eval()`` and ``.train()`` methods -that will be shown in later code snippets. -As before, you need to pass in the input shape to construct the Module -eagerly in NNX. +To control whether or not to update the running statistics: +- Haiku requires the ``is_training`` argument, while Flax Linen requires the ``use_running_average`` argument. +- Flax NNX also uses the ``use_running_average`` argument, but the value can be set later using the ``.eval()`` and ``.train()`` methods that will be shown in later code snippets. As before, you need to pass in the input shape to construct the ``nnx.Module`` eagerly. .. codediff:: :title: Haiku, Linen, NNX @@ -588,20 +555,12 @@ eagerly in NNX. model = Model(784, 256, 10, rngs=nnx.Rngs(0)) -To initialize both the parameters and state in Haiku and Linen, you just -call the ``init`` method as before. However, in Haiku you now get ``batch_stats`` -as a second return value, and in Linen you get a new ``batch_stats`` collection -in the ``variables`` dictionary. -Note that since ``hk.BatchNorm`` only initializes batch statistics when -``is_training=True``, we must set ``training=True`` when initializing parameters -of a Haiku model with an ``hk.BatchNorm`` layer. In Linen, we can set -``training=False`` as usual. -In NNX, the parameters and state are already initialized upon module -instantiation. The batch statistics are of class ``nnx.BatchStat`` which -subclasses the ``nnx.Variable`` class (not ``nnx.Param`` since they aren't -learnable parameters). Calling ``nnx.split`` with no additional filter arguments -will return a state containing all ``nnx.Variable``'s by default. +To initialize both the parameters and state: +- In Haiku and Flax Linen, you just call the ``init`` method as before. However: + - In Haiku, you now get ``batch_stats`` as a second return value. + - In Linen, you get a new ``batch_stats`` collection in the ``variables`` dictionary. + - **Note:** In Haiku, since ``hk.BatchNorm`` only initializes batch statistics when ``is_training=True``, you must set ``training=True`` when initializing parameters of a Haiku model with an ``hk.BatchNorm`` layer. In Linen, you can set ``training=False`` as usual. +- In Flax NNX, the parameters and state are already initialized upon ``nnx.Module`` instantiation. The batch statistics are of class ``nnx.BatchStat``, which subclasses the ``nnx.Variable`` class (not ``nnx.Param`` since they aren't learnable parameters). Calling ``nnx.split`` with no additional filter arguments will return a state containing all ``nnx.Variable``'s by default. .. codediff:: :title: Haiku, Linen, NNX @@ -633,24 +592,12 @@ will return a state containing all ``nnx.Variable``'s by default. -Now, training looks very similar in Haiku and Linen as you use the same -``apply`` method to run the forward pass. In Haiku, now pass the ``batch_stats`` -as the second argument to ``apply``, and get the newly updated ``batch_stats`` -as the second return value.
In Linen, you instead add ``batch_stats`` as a new -key to the input dictionary, and get the ``updates`` variables dictionary as the -second return value. To update the batch statistics, we must pass in -``training=True`` to ``apply``. - -In NNX, the training code is identical to the earlier example as the -batch statistics (which are bounded to the stateful NNX module) are updated -statefully. To update batch statistics in NNX, we first call ``model.train()``, -which will set the batchnorm layer's ``use_running_average`` attribute to ``False`` -(conversely, calling ``model.eval()`` would set ``use_running_average`` to ``True``). -Since the stateful NNX module already contains the parameters and batch statistics, -we simply need to call the module to run the forward pass. We use ``nnx.split`` to -extract the learnable parameters (all learnable parameters subclass the NNX class -``nnx.Param``) and then apply the gradients and statefully update the model using -``nnx.update``. +Now, onto training: +- In Haiku and Flax Linen, training is very similar, as you use the same ``apply`` method to run the forward pass. + - In Haiku, you pass the ``batch_stats`` as the second argument to ``apply``, and get the newly updated ``batch_stats`` as the second return value. + - In Flax Linen, you add ``batch_stats`` as a new key to the input dictionary, and get the ``updates`` variable dictionary as the second return value. + - To update the batch statistics, you must pass in ``training=True`` to ``apply``. +- In Flax NNX, the training code is identical to the earlier example as the batch statistics (which are bound to the stateful ``nnx.Module``) are updated statefully. To update batch statistics in NNX, you first call ``model.train()``, which will set the batchnorm layer's ``use_running_average`` attribute to ``False`` (conversely, calling ``model.eval()`` would set ``use_running_average`` to ``True``). Since the stateful ``nnx.Module`` already contains the parameters and batch statistics, you simply need to call the ``nnx.Module`` to run the forward pass. Use ``nnx.split`` to extract the learnable parameters (all learnable parameters subclass the Flax ``nnx.Param`` class), and then apply the gradients and statefully update the model using ``nnx.update``. .. codediff:: :title: Haiku, Linen, NNX @@ -726,8 +673,7 @@ extract the learnable parameters (all learnable parameters subclass the NNX clas train_step(model, sample_x, jnp.ones((1,), dtype=jnp.int32)) -To use ``TrainState``, we subclass to add an additional field that can store -the batch statistics: +To use ``TrainState``, you subclass to add an additional field that can store the batch statistics: .. codediff:: :title: Haiku, Linen, NNX @@ -844,17 +790,13 @@ the batch statistics: train_step(state, sample_x, jnp.ones((1,), dtype=jnp.int32)) -Using Multiple Methods ------------------------ +Using multiple methods +---------------------- -In this section we will take a look at how to use multiple methods in all three -frameworks. As an example, we will implement an auto-encoder model with three methods: -``encode``, ``decode``, and ``__call__``. +This section examines how to use multiple methods in all three frameworks using an implementation of an auto-encoder model with three methods - ``encode``, ``decode``, and ``__call__`` - as an example. -As before, we define the encoder and decoder layers without having to pass in the -input shape, since the module parameters will be initialized lazily using shape -inference in Haiku and Linen. 
In NNX, we must pass in the input shape -since the module parameters will be initialized eagerly without shape inference. +- In Haiku and Flax Linen, as before, you define the encoder and decoder layers without having to pass in the input shape, since the ``Module`` parameters will be initialized lazily using shape inference. +- In Flax NNX, you must pass in the input shape since the ``nnx.Module`` parameters will be initialized eagerly without shape inference. .. codediff:: :title: Haiku, Linen, NNX @@ -921,7 +863,7 @@ since the module parameters will be initialized eagerly without shape inference. x = self.decode(x) return x -As before, we pass in the input shape when instantiating the NNX module. +As before, in Flax NNX you pass in the input shape when instantiating the ``nnx.Module``. .. codediff:: :title: Haiku, Linen, NNX @@ -953,11 +895,9 @@ As before, we pass in the input shape when instantiating the NNX module. model = AutoEncoder(784, 256, 784, rngs=nnx.Rngs(0)) -For Haiku and Linen, ``init`` can be used to trigger the -``__call__`` method to initialize the parameters of our model, -which uses both the ``encode`` and ``decode`` method. This will -create all the necessary parameters for the model. In NNX, -the parameters are already initialized upon module instantiation. +Initializing the parameters: +- For Haiku and Flax Linen, ``init`` can be used to trigger the ``__call__`` method to initialize the parameters of your model, which uses both the ``encode`` and ``decode`` method. This will create all necessary parameters for the model. +- In Flax NNX, the parameters are already initialized upon model instantiation. .. codediff:: :title: Haiku, Linen, NNX @@ -977,7 +917,7 @@ the parameters are already initialized upon module instantiation. --- - # parameters were already initialized during model instantiation + # Parameters were already initialized during model instantiation. ... @@ -1044,10 +984,9 @@ The parameter structure is as follows: }) -Finally, let's explore how we can employ the forward pass. In Haiku -and Linen, we use the ``apply`` function to invoke the ``encode`` -method. In NNX, we simply can simply call the ``encode`` method -directly. +Finally, let's explore how to employ the forward pass: +- In Haiku and Flax Linen, use the ``apply`` function to invoke the ``encode`` method. +- In Flax NNX, you can simply call the ``encode`` method directly. .. codediff:: :title: Haiku, Linen, NNX @@ -1082,19 +1021,14 @@ directly. ... -Lifted Transforms ------------------ +Lifted transformations +---------------------- -Both Flax and Haiku provide a set of transforms, which we will refer to as lifted transforms, -that wrap JAX transformations in such a way that they can be used with Modules and sometimes -provide additional functionality. In this section we will take a look at how to use the -lifted version of ``scan`` in both Flax and Haiku to implement a simple RNN layer. +Flax (`Linen`__ and `NNX`__) and `Haiku`__ provide a set of transforms, which are referred to as lifted transformations. These transforms wrap `JAX transformations`__ in such a way that they can be used with ``Module``s, and sometimes provide additional functionality. -To begin, we will first define a ``RNNCell`` module that will contain the logic for a single -step of the RNN. We will also define a ``initial_state`` method that will be used to initialize -the state (a.k.a. ``carry``) of the RNN. 
Like with ``jax.lax.scan``, the ``RNNCell.__call__`` -method will be a function that takes the carry and input, and returns the new -carry and output. In this case, the carry and the output are the same. +This section examines how to use the lifted version of ``scan`` in Flax (Linen and NNX) and Haiku to implement a simple recurrent neural network (RNN) layer. + +Start with defining the ``RNNCell`` ``Module`` that will contain the logic for a single step of the RNN. In addition, define the ``initial_state`` method that will be used to initialize the state (a.k.a. ``carry``) of the RNN. Similar to ``jax.lax.scan`` (`API`__), the ``RNNCell.__call__`` method will be a function that takes the carry and input, and returns the new carry and output. In this case, the carry and the output are the same. .. codediff:: :title: Haiku, Linen, NNX @@ -1146,21 +1080,10 @@ carry and output. In this case, the carry and the output are the same. def initial_state(self, batch_size: int): return jnp.zeros((batch_size, self.hidden_size)) -Next, we will define a ``RNN`` Module that will contain the logic for the entire RNN. -In Haiku, we will first initialze the ``RNNCell``, then use it to construct the ``carry``, -and finally use ``hk.scan`` to run the ``RNNCell`` over the input sequence. -In Linen, we will use ``nn.scan`` to define a new temporary type that wraps -``RNNCell``. During this process we will also specify instruct ``nn.scan`` to broadcast -the ``params`` collection (all steps share the same parameters) and to not split the -``params`` rng stream (so all steps intialize with the same parameters), and finally -we will specify that we want scan to run over the second axis of the input and stack -the outputs along the second axis as well. We will then use this temporary type immediately -to create an instance of the lifted ``RNNCell`` and use it to create the ``carry`` and -the run the ``__call__`` method which will ``scan`` over the sequence. -In NNX, we define a scan function ``scan_fn`` that will use the ``RNNCell`` defined -in ``__init__`` to scan over the sequence. +Next, define the ``RNN`` ``Module`` that will contain the logic for the entire RNN: +- In Haiku, you first initialize the ``RNNCell``, then use it to construct the ``carry``, and then use ``hk.scan`` to run the ``RNNCell`` over the input sequence. +- In Flax Linen, you use ``flax.linen.scan`` (``nn.scan``) to define a new temporary type that wraps ``RNNCell``. During this process, instruct ``nn.scan`` to broadcast the ``params`` collection (all steps share the same parameters) and to not split the ``params`` PRNG stream (so that all steps initialize with the same parameters). Finally, specify that you want ``scan`` to run over the second axis of the input and stack the outputs along the second axis as well. You will then use this temporary type immediately to create an instance of the lifted ``RNNCell``, and use it to create the ``carry`` and then run the ``__call__`` method, which will ``scan`` over the sequence. +- In Flax NNX, you define a custom scan function ``scan_fn`` that will use the ``RNNCell`` defined in ``__init__`` to scan over the sequence. .. codediff:: :title: Haiku, Linen, NNX @@ -1214,17 +1137,13 @@ in ``__init__`` to scan over the sequence.
return y -In general, the main difference between lifted transforms between Flax and Haiku is that -in Haiku the lifted transforms don't operate over the state, that is, Haiku will handle the -``params`` and ``state`` in such a way that it keeps the same shape inside and outside of the -transform. In Flax, the lifted transforms can operate over both variable collections and rng -streams, the user must define how different collections are treated by each transform -according to the transform's semantics. +In general, the main difference between the lifted transforms in Flax (Linen and NNX) and Haiku is as follows: +- In Haiku, the lifted transforms don't operate over the state. That is, Haiku will handle the ``params`` and ``state`` in such a way that it keeps the same shape inside and outside of the transform. +- In Flax, the lifted transforms can operate over both variable collections and RNG streams; you must define how different collections are treated by each transform according to the transform's semantics. -As before, the parameters must be initialized via ``.init()`` and passed into ``.apply()`` -to conduct a forward pass in Haiku and Linen. In NNX, the parameters are already -eagerly initialized and bound to the stateful module, and the module can be simply called -on the input to conduct a forward pass. +Initializing the parameters: +- In Haiku and Flax Linen, as before, the parameters must be initialized via ``.init()`` and passed into ``.apply()`` to conduct a forward pass. +- In Flax NNX, the parameters are already eagerly initialized and bound to the stateful ``Module``, and the ``Module`` can be simply called on the input to conduct a forward pass. .. codediff:: :title: Haiku, Linen, NNX @@ -1285,33 +1204,18 @@ on the input to conduct a forward pass. ... -The only notable change with respect to the examples in the previous sections is that -this time around we used ``hk.without_apply_rng`` in Haiku so we didn't have to -pass the ``rng`` argument as ``None`` to the ``apply`` method. +The Haiku example contains the only notable change compared with the examples in the previous sections: +- Here, you used ``hk.without_apply_rng`` so that you didn't have to pass the ``rng`` argument as ``None`` to the ``apply`` method. Scan over layers ---------------- -One very important application of ``scan`` is apply a sequence of layers iteratively -over an input, passing the output of each layer as the input to the next layer. This -is very useful to reduce compilation time for big models. As an example we will create -a simple ``Block`` Module, and then use it inside an ``MLP`` Module that will apply -the ``Block`` Module ``num_layers`` times. -In Haiku, we define the ``Block`` Module as usual, and then inside ``MLP`` we will -use ``hk.experimental.layer_stack`` over a ``stack_block`` function to create a stack -of ``Block`` Modules. -In Linen, the definition of ``Block`` is a little different, -``__call__`` will accept and return a second dummy input/output that in both cases will -be ``None``. In ``MLP``, we will use ``nn.scan`` as in the previous example, but -by setting ``split_rngs={'params': True}`` and ``variable_axes={'params': 0}`` -we are telling ``nn.scan`` create different parameters for each step and slice the -``params`` collection along the first axis, effectively implementing a stack of -``Block`` Modules as in Haiku. -In NNX, we use ``nnx.Scan.constructor()`` to define a stack of ``Block`` modules. -We can then simply call the stack of ``Block``'s, ``self.blocks``, on the input and -carry to get the forward pass output. + +One very important application of ``scan`` is to apply a sequence of layers iteratively over an input, passing the output of each layer as the input to the next layer. This is very useful to reduce compilation time for large models. + +As an example, let’s create a simple ``Block`` ``Module``, and then use it inside an ``MLP`` ``Module`` that will apply the ``Block`` ``Module`` ``num_layers`` times: +- In Haiku, define the ``Block`` ``Module`` as usual, and then inside the ``MLP`` you use ``hk.experimental.layer_stack`` over a ``stack_block`` function to create a stack of ``Block`` Modules. +- In Flax Linen, the definition of ``Block`` is a little different. ``__call__`` will accept and return a second dummy input/output that in both cases will be ``None``. In the ``MLP``, use ``nn.scan`` similar to the previous example, but by setting ``split_rngs={'params': True}`` and ``variable_axes={'params': 0}`` you instruct ``nn.scan`` to create different parameters for each step, and slice the ``params`` collection along the first axis, effectively implementing a stack of ``Block`` ``Module``s, similar to Haiku. +- In Flax NNX, use ``nnx.Scan.constructor()`` to define a stack of ``Block`` ``Module``s. You can then simply call the stack of ``Block``'s, ``self.blocks``, on the input and carry to get the forward pass output. .. codediff:: :title: Haiku, Linen, NNX @@ -1402,13 +1306,10 @@ carry to get the forward pass output. y, _ = self.blocks(x, None) return y -Notice how in Flax we pass ``None`` as the second argument to ``ScanBlock`` and ignore -its second output. These represent the inputs/outputs per-step but they are ``None`` -because in this case we don't have any. +**Note:** In Flax Linen and NNX, you pass ``None`` as the second argument to ``ScanBlock`` and ignore its second output. These normally represent the per-step inputs/outputs, but they are ``None`` here because in this case you don't have any. -Initializing each model is the same as in previous examples. In this case, -we will be specifying that we want to use ``5`` layers each with ``64`` features. -As before, we also pass in the input shape for NNX. +Next, initializing each model is the same as in previous examples. In this example, you will specify that you want to use ``5`` layers, each with ``64`` features: +- In Flax NNX, as before, you also pass in the input shape. .. codediff:: :title: Haiku, Linen, NNX @@ -1451,10 +1352,7 @@ As before, we also pass in the input shape for NNX. ... -When using scan over layers the one thing you should notice is that all layers -are fused into a single layer whose parameters have an extra "layer" dimension on -the first axis. In this case, the shape of all parameters will start with ``(5, ...)`` -as we are using ``5`` layers. +When using ``scan`` over layers, one thing you should notice is that all layers are fused into a single layer whose parameters have an extra "layer" dimension on the first axis. In this case, the shape of all parameters will start with ``(5, ...)`` since you are using ``5`` layers: .. tab-set:: @@ -1518,15 +1416,14 @@ as we are using ``5`` layers. Top-level Haiku functions vs top-level Flax modules ----------------------------------- -In Haiku, it is possible to write the entire model as a single function by using -the raw ``hk.{get,set}_{parameter,state}`` to define/access model parameters and -states.
It very common to write the top-level "Module" as a function instead. +Top-level Haiku functions vs top-level Flax ``Module``s +------------------------------------------------------- + +In Haiku, it is possible to write the entire model as a single function by using the raw ``hk.{get,set}_{parameter,state}`` to define/access model parameters and states. It is very common to write the top-level "Module" as a function instead. -The Flax team recommends a more Module-centric approach that uses ``__call__`` to -define the forward function. In Linen, the corresponding accessor will be -``Module.param`` and ``Module.variable`` (go to `Handling State <#handling-state>`__ -for an explanation on collections). In NNX, the parameters and variables can -be set and accessed as normal using regular Python class semantics. +The Flax team recommends a more ``Module``-centric approach that uses ``__call__`` to define the forward function: +- In Flax Linen, the corresponding accessor will be ``Module.param`` and ``Module.variable`` (go to `Handling State <#handling-state>`__ for an explanation on collections). +- In Flax NNX, the parameters and variables can be set and accessed as normal using regular Python class semantics. .. codediff:: :title: Haiku, Linen, NNX @@ -1566,7 +1463,7 @@ be set and accessed as normal using regular Python class semantics. ) output = x + multiplier * counter.value - if not self.is_initializing(): # otherwise model.init() also increases it + if not self.is_initializing(): # Otherwise model.init() also increases it counter.value += 1 return output @@ -1594,4 +1491,4 @@ be set and accessed as normal using regular Python class semantics. model = FooModule(rngs=nnx.Rngs(0)) - _, params, counter = nnx.split(model, nnx.Param, Counter) \ No newline at end of file + _, params, counter = nnx.split(model, nnx.Param, Counter)
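To close, here is a short, self-contained sketch that restates the final section's Flax NNX pattern: parameters and other variables are plain attributes, and ``nnx.split`` can filter them by type. The ``FooModule`` and ``Counter`` names mirror the example above, but the initializer, shapes, and call logic below are illustrative assumptions rather than the guide's verbatim code.

.. code-block:: python

  import jax
  import jax.numpy as jnp
  from flax import nnx

  class Counter(nnx.Variable):
    """A custom variable type for non-learnable state, as in the guide."""

  class FooModule(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
      # Illustrative initialization; the guide's exact initializer may differ.
      self.multiplier = nnx.Param(jax.random.uniform(rngs.params(), ()))
      self.counter = Counter(jnp.array(0))

    def __call__(self, x):
      output = x + self.multiplier.value * self.counter.value
      self.counter.value += 1  # Stateful update through plain attribute access.
      return output

  model = FooModule(rngs=nnx.Rngs(0))
  y = model(jnp.ones((2,)))

  # Filter the state by variable type, as in the final example above.
  graphdef, params, counter = nnx.split(model, nnx.Param, Counter)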