diff --git a/docs_nnx/guides/haiku_linen_vs_nnx.rst b/docs_nnx/guides/haiku_linen_vs_nnx.rst index 79d318edc0..7ef815d3a1 100644 --- a/docs_nnx/guides/haiku_linen_vs_nnx.rst +++ b/docs_nnx/guides/haiku_linen_vs_nnx.rst @@ -1,11 +1,11 @@ -Migrating from Haiku/Linen to NNX -================================= +Migrate to Flax NNX from Haiku/Flax Linen +========================================= -This guide will showcase the differences between Haiku, Flax Linen and Flax NNX. -Both Haiku and Linen enforce a functional paradigm with stateless modules, -while NNX is a new, next-generation API that embraces the python language to -provide a more intuitive development experience. +This guide demonstrates the differences between the Flax NNX API (the next-generation Flax API), +Haiku, and Flax Linen. Both Haiku and Flax Linen enforce a functional paradigm with stateless +`Module`s, while Flax NNX embraces the Python language to provide a more intuitive development +experience. .. testsetup:: Haiku, Linen, NNX @@ -21,18 +21,19 @@ provide a more intuitive development experience. # TODO: make sure code lines are not too long # TODO: make sure all code diffs are aligned -Basic Example ------------------ +A basic example +--------------- -To create custom Modules you subclass from a ``Module`` base class in -both Haiku and Flax. Modules can be defined inline in Haiku and Flax -Linen (using the ``@nn.compact`` decorator), whereas modules can't be -defined inline in NNX and must be defined in ``__init__``. +To create custom `Module`s, in both Haiku and Flax you subclass from the ``Module`` +base class. Note that: -Linen requires a ``deterministic`` argument to control whether or -not dropout is used. NNX also uses a ``deterministic`` argument -but the value can be set later using ``.eval()`` and ``.train()`` methods -that will be shown in a later code snippet. +- In Haiku and Flax Linen, `Module`s can be defined inline using the ``@nn.compact`` decorator). +- In Flax NNX, `Module`s can't be defined inline in NNX but instead must be defined in ``__init__``. + +In addition: + +- Flax Linen requires a ``deterministic`` argument to control whether or not dropout is used. +- Flax NNX also uses a ``deterministic`` argument but the value can be set later using ``.eval()`` and ``.train()`` methods that will be shown in a later code snippet. .. codediff:: :title: Haiku, Linen, NNX @@ -114,11 +115,8 @@ that will be shown in a later code snippet. x = self.linear(x) return x -Since modules are defined inline in Haiku and Linen, the parameters -are lazily initialized, by inferring the shape of a sample input. In Flax -NNX, the module is stateful and is initialized eagerly. This means that the -input shape must be explicitly passed during module instantiation since there -is no shape inference in NNX. +- In Haiku and Flax Linen, since `Module`s are defined inline, the parameters are lazily initialized by inferring the shape of a sample input. +- In Flax NNX, the `Module` is stateful and is initialized eagerly. This means that the input shape must be explicitly passed during module instantiation since there is no shape inference in NNX. .. codediff:: :title: Haiku, Linen, NNX @@ -146,11 +144,11 @@ is no shape inference in NNX. To get the model parameters in both Haiku and Linen, you use the ``init`` method with a ``random.key`` plus some inputs to run the model. -In NNX, the model parameters are automatically initialized when the user +In Flax NNX, the model parameters are automatically initialized when the user instantiates the model because the input shapes are already explicitly passed at instantiation time. -Since NNX is eager and the module is bound upon instantiation, the user can access +Since Flax NNX is eager and the module is bound upon instantiation, the user can access the parameters (and other fields defined in ``__init__`` via dot-access). On the other hand, Haiku and Linen use lazy initialization and so the parameters can only be accessed once the module is initialized with a sample input and both frameworks do not support @@ -193,16 +191,9 @@ dot-access of their attributes. assert model.linear.bias.value.shape == (10,) assert model.block.linear.kernel.value.shape == (784, 256) -Let's take a look at the parameter structure. In Haiku and Linen, we can -simply inspect the ``params`` object returned from ``.init()``. - -To see the parameter structure in NNX, the user can call ``nnx.split`` to -generate ``Graphdef`` and ``State`` objects. The ``Graphdef`` is a static pytree -denoting the structure of the model (for example usages, see -`NNX Basics `__). -``State`` objects contains all the module variables (i.e. any class that sub-classes -``nnx.Variable``). If we filter for ``nnx.Param``, we will generate a ``State`` object -of all the learnable module parameters. +Let's take a look at the parameter structure: +- In Haiku and Flax Linen, you would simply inspect the ``params`` object returned from ``.init()``. +- In Flax NNX, to view the parameter structure, you can call ``nnx.split`` to generate ``Graphdef`` and ``State`` objects. The ``Graphdef`` is a static pytree denoting the structure of the model (for example usages, check the `NNX basics guide `__). ``State`` objects contains all the module variables (i.e. any class that subclasses ``nnx.Variable``). If we filter for ``nnx.Param``, we will generate a ``State`` object of all the learnable module parameters. .. tab-set:: @@ -271,22 +262,13 @@ of all the learnable module parameters. } }) -During training in Haiku and Linen, you pass the parameters structure to the -``apply`` method to run the forward pass. To use dropout, we must pass in -``training=True`` and provide a ``key`` to ``apply`` in order to generate the -random dropout masks. To use dropout in NNX, we first call ``model.train()``, -which will set the dropout layer's ``deterministic`` attribute to ``False`` -(conversely, calling ``model.eval()`` would set ``deterministic`` to ``True``). -Since the stateful NNX module already contains both the parameters and RNG key -(used for dropout), we simply need to call the module to run the forward pass. We -use ``nnx.split`` to extract the learnable parameters (all learnable parameters -subclass the NNX class ``nnx.Param``) and then apply the gradients and statefully -update the model using ``nnx.update``. - -To compile ``train_step``, we decorate the function using ``@jax.jit`` for Haiku -and Linen, and ``@nnx.jit`` for NNX. Similar to ``@jax.jit``, ``@nnx.jit`` also -compiles functions, with the additional feature of allowing the user to compile -functions that take in NNX modules as arguments. +During training: +- in Haiku and Flax Linen, you pass the parameters structure to the ``apply`` method to run the forward pass. To use dropout, we must pass in ``training=True`` and provide a ``key`` to ``apply`` to generate the random dropout masks. To use dropout in NNX, we first call ``model.train()``, which will set the dropout layer's ``deterministic`` attribute to ``False`` (conversely, calling ``model.eval()`` would set ``deterministic`` to ``True``). +- In Flax NNX, since the stateful NNX `Module` already contains both the parameters and RNG key (used for dropout), you simply need to call the module to run the forward pass. You use ``nnx.split`` to extract the learnable parameters (all learnable parameters subclass the NNX class ``nnx.Param``) and then apply the gradients and statefully update the model using ``nnx.update``. + +To compile the ``train_step``: +- Haiku and Flax Linen, you decorate the function using ``@jax.jit``. +- In Flax Linen, you decorate it with ``@nnx.jit``. Similar to ``@jax.jit``, ``@nnx.jit`` also compiles functions, with the additional feature of allowing the user to compile functions that take in NNX modules as arguments. .. codediff:: :title: Haiku, Linen, NNX @@ -379,7 +361,7 @@ One thing to note is that ``GraphDef.apply`` will take in ``State``'s as argumen return a callable function. This function can be called on the inputs to output the model's logits, as well as updated ``GraphDef`` and ``State`` objects. This isn't needed for our current example with dropout, but in the next section, you will see that using -these updated objects are relevant with layers like batch norm. Notice we also use +these updated objects are relevant with layers like batch normalization. Notice the use of ``@jax.jit`` since we aren't passing in NNX modules into ``train_step``. .. codediff:: @@ -1402,13 +1384,13 @@ carry to get the forward pass output. y, _ = self.blocks(x, None) return y -Notice how in Flax we pass ``None`` as the second argument to ``ScanBlock`` and ignore -its second output. These represent the inputs/outputs per-step but they are ``None`` -because in this case we don't have any. +Notice how in Flax, you pass ``None`` as the second argument to ``ScanBlock`` and ignore +its second output. These normally represent the inputs/outputs per-step but here they are +``None`` because in this case you don't have any. -Initializing each model is the same as in previous examples. In this case, -we will be specifying that we want to use ``5`` layers each with ``64`` features. -As before, we also pass in the input shape for NNX. +Initializing each model is the same as in previous examples. In this example, +you will specify that you want to use ``5`` layers each with ``64`` features. +As before, you also pass in the input shape for NNX. .. codediff:: :title: Haiku, Linen, NNX @@ -1594,4 +1576,6 @@ be set and accessed as normal using regular Python class semantics. model = FooModule(rngs=nnx.Rngs(0)) - _, params, counter = nnx.split(model, nnx.Param, Counter) \ No newline at end of file + _, params, counter = nnx.split(model, nnx.Param, Counter) + +