diff --git a/docs_nnx/guides/transforms.ipynb b/docs_nnx/guides/transforms.ipynb
index edf8ef8e0f..35ed4d6f52 100644
--- a/docs_nnx/guides/transforms.ipynb
+++ b/docs_nnx/guides/transforms.ipynb
@@ -5,16 +5,18 @@
"id": "962be290",
"metadata": {},
"source": [
- "# Transforms\n",
- "JAX transformations in general operate on [Pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) of arrays\n",
- "and abide by value semantics, this presents a challenge for Flax NNX which represents Modules as regular Python objects\n",
- "that follow reference semantics. To address this, Flax NNX introduces its own set of transformations that extend JAX\n",
- "transformations to allow Modules and other Flax NNX objects to be passed in and out of transformations while preserving\n",
+ "# Transformations\n",
+ "\n",
+ "In general, JAX transformations (transforms) operate on [pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) of `jax.Array`s\n",
+ "and abide by value semantics. This presents a challenge for Flax NNX, which represents `nnx.Module`s as regular Python objects\n",
+ "that follow reference semantics. To address this, Flax NNX introduced its own set of transforms that extend JAX\n",
+ "transforms to allow `nnx.Module`s and other Flax NNX objects to be passed in and out of transforms while preserving\n",
"reference semantics.\n",
"\n",
- "Flax NNX transformations should feel quite familar to those who have used JAX transformations before as they use the\n",
- "same APIs and behave like the JAX transformations when only working with Pytrees of arrays. However, when working with\n",
+ "Flax NNX transforms should feel quite familiar if you have used JAX transforms before. They use the\n",
+ "same APIs and behave like the JAX transforms when only working with pytrees of `jax.Array`s. However, when working with\n",
"Flax NNX objects, they allow Python's reference semantics to be preserved for these objects, this includes:\n",
+ "\n",
"* Preserving shared references across multiple objects in the inputs and outputs of the transformation.\n",
"* Propagating any state changes made to the objects inside the transformation to the objects outside the transformation.\n",
"* Enforcing consistency of how objects are transformed when aliases are present across multiple inputs and outputs."
@@ -37,11 +39,12 @@
"id": "b44fb248",
"metadata": {},
"source": [
- "Throughout this guide we will use `nnx.vmap` as a case study to demonstrate how Flax NNX transformations work but the principles\n",
- "outlined here extend to all transformations.\n",
+ "Throughout this guide, `nnx.vmap` is used as a case study to demonstrate how Flax NNX transforms work. However, the principles\n",
+ "outlined in this document extends to all transforms.\n",
"\n",
- "## Basic Example\n",
- "To begin, let's look at a simple example of using `nnx.vmap` to extend an elementwise `vector_dot` function to work on\n",
+ "## Basic example\n",
+ "\n",
+ "To begin, let's look at a simple example of using `nnx.vmap` to extend an element wise `vector_dot` function to work on\n",
"batched inputs. We will define a `Weights` Module with no methods to hold some parameters, these weights will be passed\n",
"as an input to the `vector_dot` function along with some data. Both the weights and data will be batched on axis `0` and we will use\n",
"`nnx.vmap` to apply `vector_dot` to each batch element, and the result will be a batched on axis `1`:"
@@ -112,10 +115,10 @@
"id": "d2b222eb",
"metadata": {},
"source": [
- "Notice that `in_axes` interacts naturally with the `Weights` Module, treating it as if it where a Pytree of arrays. Prefix patterns are also allowed, `in_axes=(0, 0)` would've also worked in this case.\n",
+ "Notice that `in_axes` interacts naturally with the `Weights` Module, treating it as if it were a pytree of `jax.Array`s. Prefix patterns are also allowed, so `in_axes=(0, 0)` would have also worked in this case.\n",
"\n",
- "Objects are also allowed as outputs of Flax NNX transformations, this can be useful to transform initializers. For example,\n",
- "we can define a `create_weights` function to create an single `Weights` Module and use `nnx.vmap` to create a stack of\n",
+ "Objects are also allowed as outputs of Flax NNX transforms, which can be useful to transform initializers. For example,\n",
+ "you can define a `create_weights` function to create an single `Weights` `nnx.Module`, and use `nnx.vmap` to create a stack of\n",
"`Weights` with the same shapes as before:"
]
},
@@ -167,8 +170,9 @@
"id": "fac3dca9",
"metadata": {},
"source": [
- "## Transforming Methods\n",
- "Methods in Python are just functions that take the instance as the first argument, this means that you can decorate methods from `Module` and other Flax NNX subtypes. For example, we can refactor `Weights` from the previous example and decorate `__init__` with `vmap` to do the work of `create_weights`, and add a `__call__` method and decorate it with `vmap` to do the work of `vector_dot`:"
+ "## Transforming methods\n",
+ "\n",
+ "Methods in Python are just functions that take the instance as the first argument, this means that you can decorate methods from `Module` and other Flax NNX subtypes. For example, we can refactor `Weights` from the previous example and decorate `__init__` with `vmap` to do the work of `create_weights`, and add a `__call__` method and decorate it with `@nnx.vmap` to do the work of `vector_dot`:"
]
},
{
@@ -236,7 +240,7 @@
"id": "13b52d61",
"metadata": {},
"source": [
- "Throughout the rest of the guide we will focus on transforming individual functions, however, note all examples can easily be written in this method style."
+ "The rest of the guide will focus on transforming individual functions. But do note that all examples can be written in this method style."
]
},
{
@@ -245,8 +249,9 @@
"metadata": {},
"source": [
"## State propagation\n",
- "So far our functions have been stateless. However, the real power of Flax NNX transformations comes when we have stateful functions since one of their main features is to propagate state changes to preserve reference semantics. Let's update our example by adding\n",
- "a `count` attribute to `Weights` and incrementing it in the new `stateful_vector_dot` function."
+ "\n",
+ "So far our functions have been stateless. However, the real power of Flax NNX transforms comes when you have stateful functions, because one of their main features is to propagate state changes to preserve reference semantics. Let's update the previous example by adding\n",
+ "a `count` attribute to `Weights` and incrementing it in the new `stateful_vector_dot` function:"
]
},
{
@@ -300,7 +305,7 @@
"id": "322312ee",
"metadata": {},
"source": [
- "After running `stateful_vector_dot` once we verify that the `count` attribute was correctly updated. Because Weights is vectorized, `count` was initialized as an `arange(10)`, and all of its elements were incremented by 1 inside the transformation. The most important part is that updates were propagated to the original `Weights` object outside the transformation. Nice!"
+ "After running `stateful_vector_dot` once, you verified that the `count` attribute was correctly updated. Because `Weights` was vectorized, `count` was initialized as an `arange(10)`, and all of its elements were incremented by `1` inside the transformation. The most important part is that updates were propagated to the original `Weights` object outside the transformation. Nice!"
]
},
{
@@ -309,9 +314,12 @@
"metadata": {},
"source": [
"### Graph updates propagation\n",
- "JAX transformations see inputs as pytrees of arrays, and Flax NNX see inputs pytrees of arrays and Python references, where references form a graph. Flax NNX's state propagation machinery can track arbitrary updates to the objects as long as they're local to the inputs (updates to globals inside transforms are not supported). This means that you can modify graph structure as needed, including updating existing attributes, adding/deleting attributes, swapping attributes, sharing (new) references between objects, sharing Variables between objects, etc. The sky is the limit!\n",
"\n",
- "The following example demonstrates performing some arbitrary updates to the `Weights` object inside `nnx.vmap` and verifying that the updates are correctly propagated to the original `Weights` object outside the transformation."
+ "JAX transforms see inputs as pytrees of `jax.Array`s, and Flax NNX sees inputs as pytrees of `jax.Array`s and Python references, where references form a graph. Flax NNX's state propagation machinery can track arbitrary updates to the objects as long as they're local to the inputs (updates to globals inside transforms are not supported).\n",
+ "\n",
+ "This means that you can modify graph structure as needed, including updating existing attributes, adding/deleting attributes, swapping attributes, sharing (new) references between objects, sharing `nnx.Variable`s between objects, etc. Sky is the limit!\n",
+ "\n",
+ "The following example demonstrates performing some arbitrary updates to the `Weights` object inside `nnx.vmap`, and verifying that the updates are correctly propagated to the original `Weights` object outside the transformation:"
]
},
{
@@ -383,7 +391,7 @@
"> With great power comes great responsibility.\n",
">
\\- Uncle Ben\n",
"\n",
- "While this feature is very powerful, it must be used with care as it can clash with JAX's underlying assumptions for certain transformations. For example, `jit` expects the structure of the inputs to be stable in order to cache the compiled function, so changing the graph structure inside a `nnx.jit`-ed function cause continuous recompilations and performance degradation. On the other hand, `scan` only allows a fixed `carry` structure, so adding/removing substates declared as carry will cause an error."
+ "While this feature is very powerful, it must be used with care because it can clash with JAX's underlying assumptions for certain transforms. For example, `jit` expects the structure of the inputs to be stable in order to cache the compiled function, so changing the graph structure inside an `nnx.jit`-ed function causes continuous recompilations and performance degradation. On the other hand, `scan` only allows a fixed `carry` structure, so adding/removing sub-states declared as carry will cause an error."
]
},
{
@@ -391,21 +399,22 @@
"id": "0d11d191",
"metadata": {},
"source": [
- "## Transforming Substates (Lift Types)\n",
+ "## Transforming sub-states (lift types)\n",
"\n",
- "Certain JAX transformation allow the use of pytree prefixes to specify how different parts of the inputs/outputs should be transformed. Flax NNX supports pytree prefixes for pytree structures but currently it doesn't have the notion of a prefix for graph objects. Instead, Flax NNX introduces the concept of `Lift Types` which allow specifying how different substates of an object should be transformed. Different transformations support different Lift Types, here is the list of currently supported Lift Types for each transformation:\n",
+ "Certain JAX transforms allow the use of pytree prefixes to specify how different parts of the inputs/outputs should be transformed. Flax NNX supports pytree prefixes for pytree structures but currently it doesn't have the notion of a prefix for graph objects. Instead, Flax NNX introduces the concept of “lift types” which allow specifying how different sub-states of an object should be transformed. Different transforms support different lift types, here is the list of currently supported FLax NNX lift types for each JAX transformation:\n",
"\n",
- "| Lift Type | Transforms |\n",
+ "| Lift type | JAX transforms |\n",
"|------------------|-----------------------------------------|\n",
"| `StateAxes` | `vmap`, `pmap`, `scan` |\n",
- "| `StateSharding` | `jit`, `shard_map` |\n",
+ "| `StateSharding` | `jit`, `shard_map`* |\n",
"| `DiffState` | `grad`, `value_and_grad`, `custom_vjp` |\n",
"\n",
- "> NOTE: `shard_map` is not yet implemented.\n",
+ "> **Note:** * Flax NNX `shard_map` has not been implemented yet at the time of writing this version of the document.\n",
+ "\n",
+ "To specify how to vectorize different sub-states of an object in `nnx.vmap`, the Flax team created a `nnx.StateAxes`. `StateAxes` maps a set of sub-states via Flax NNX [Filters](https://flax-nnx.readthedocs.io/en/latest/guides/filters_guide.html) to their corresponding axes, and you can pass the `nnx.StateAxes` to `in_axes` and `out_axes` as if it/they were a pytree prefix.\n",
"\n",
- "If we want to specify how to vectorize different substates of an object in `nnx.vmap`, we create a `StateAxes` which maps a set of substates via [Filters](https://flax-nnx.readthedocs.io/en/latest/guides/filters_guide.html) to their corresponding axes, and pass the `StateAxes` to `in_axes` and `out_axes` as if it were a pytree prefix. Let's use the previous `stateful_vector_dot` example and\n",
- "vectorize only the `Param` variables and broadcast the `count` variable so we only keep a single count for all the batch elements.\n",
- "To do this we will define a `StateAxes` with a filter that matches the `Param` variables and maps them to axis `0`, and all the `Count` variables to `None`, and pass this `StateAxes` to `in_axes` for the `Weights` object."
+ "Let's use the previous `stateful_vector_dot` example and vectorize only the `nnx.Param` variables and broadcast the `count` variable so we only keep a single count for all the batch elements.\n",
+ "To do this we will define a `nnx.StateAxes` with a filter that matches the `nnx.Param` variables and maps them to axis `0`, and all the `Count` variables to `None`, and pass this `nnx.StateAxes` to `in_axes` for the `Weights` object."
]
},
{
@@ -458,7 +467,7 @@
"id": "1cfd87e1",
"metadata": {},
"source": [
- "Here count is now a scalar since its not being vectorized. Also, note that `StateAxes` can only be used directly on Flax NNX objects, it cannot be used as a prefix for a pytree of objects."
+ "Here, `count` is now a scalar since it's not being vectorized. Also, note that `nnx.StateAxes` can only be used directly on Flax NNX objects, and it cannot be used as a prefix for a pytree of objects."
]
},
{
@@ -466,10 +475,13 @@
"id": "1c8bb104",
"metadata": {},
"source": [
- "### Random State\n",
- "In Flax NNX random state is just regular state. This means that its stored inside Modules that need it and its treated as any other type of state. This is a simplification over Flax Linen where random state was handled by a separate mechanism. In practice Modules simply need to keep a reference to a `Rngs` object that is passed to them during initialization, and use it to generate a unique key for each random operation. For the purposes of this guide, this means that random state can be transformed like any other type of state but we also need be aware of how the state is laid out so we can transform it correctly.\n",
+ "### Random state\n",
"\n",
- "Let's suppose we want change things up a bit and apply the same weights to all elements in the batch but we want to add different random noise to each element. To do this we will add a `Rngs` attribute to `Weights`, created from a `seed` key argument passed during construction, this seed key must be `split` before hand so we can vectorize it succesfully. For pedagogical reasons, we will assign the seed key to a `noise` Stream and sample from it. To vectorize the RNG state we must configure `StateAxes` to map all `RngState` (base class for all variables in `Rngs`) to axis `0`, and `Param` and `Count` to `None`."
+ "In Flax NNX, a random state is just a regular state. This means that it is stored inside `nnx.Module`s that need it, and it is treated as any other type of state. This is a simplification over Flax Linen, where a random state was handled by a separate mechanism. In practice `nnx.Module`s simply need to keep a reference to a `Rngs` object that is passed to them during initialization, and use it to generate a unique key for each random operation. For the purposes of this guide, this means that random state can be transformed like any other type of state but we also need be aware of how the state is laid out so we can transform it correctly.\n",
+ "\n",
+ "Suppose you want to change things up a bit and apply the same weights to all elements in the batch. But you also want to add different random noise to each element.\n",
+ "\n",
+ "To do this, you will add an `Rngs` attribute to `Weights`, created from a `seed` key argument passed during construction. This seed key must be `split` beforehand, so that you can vectorize it successfully. For pedagogical reasons, you will assign the seed key to a `noise` “stream” and sample from it. To vectorize the PRNG state, you must configure `nnx.StateAxes` to map all `RngState`s (a base class for all variables in `Rngs`) to axis `0`, and `nnx.Param` and `Count` to `None`."
]
},
{
@@ -547,7 +559,7 @@
"source": [
"Because `Rngs`'s state is updated in place and automatically propagated by `nnx.vmap`, we will get a different result every time that `noisy_vector_dot` is called.\n",
"\n",
- "In the example above we manually split the random state during construction, this is fine as it makes the intention clear but it also doesn't let us use `Rngs` outside of `vmap` since its state is always split. To solve this we pass an unplit seed and use the `nnx.split_rngs` decorator before `vmap` to split the `RngState` right before each call to the function and then \"lower\" it back so its usable."
+ "In the example above, you manually split the random state during construction. This is fine, as it makes the intention clear, but it also doesn't let you use `Rngs` outside of `nnx.vmap` because its state is always split. To solve this, you can pass an unsplit seed and use the `nnx.split_rngs` decorator before `nnx.vmap` to split the `RngState` right before each call to the function, and then \"lower\" it back so that it becomes usable."
]
},
{
@@ -621,7 +633,8 @@
"metadata": {},
"source": [
"## Consistent aliasing\n",
- "The main issue with allowing for reference semantics in transforms that references can be shared across inputs and outputs, this can be problematic if not taken care of because it would lead to ill-defined or inconsistent behavior. In the example below we have a single `Weights` Module `m` whose reference appears in multiple places in `arg1` and `arg2`. The problem is that we also specified we wanted to vectorize `arg1` in axis `0` and `arg2` in axis `1`, this is fine in JAX due to referential transparency of pytrees but its problematic in Flax NNX since we are trying to vectorize `m` in two different ways. NNX will enforce consistency by raising an error."
+ "\n",
+ "The main issue with allowing for reference semantics in transforms is that references can be shared across inputs and outputs. This can be problematic if it is not taken care of because it would lead to ill-defined or inconsistent behavior. In the example below you have a single `Weights` `nnx.Module` - `m` ` whose reference appears in multiple places in `arg1` and `arg2`. The problem here is that you also specify that you want to vectorize `arg1` in axis `0` and `arg2` in axis `1`. This would be fine in JAX because of referential transparency of pytrees. But this would be problematic in Flax NNX because you are trying to vectorize `m` in two different ways. Flax NNX will enforce consistency by raising an error."
]
},
{
@@ -670,7 +683,7 @@
"id": "46aa978c",
"metadata": {},
"source": [
- "Inconsistent aliasing can also happen between inputs and outputs. In the next example we have a trivial function that accepts and immediately return `arg1`, however `arg1` is vectorized on axis `0` on the input and axis `1` on the output. As expected, this is problematic and Flax NNX will raise an error."
+ "Inconsistent aliasing can also happen between inputs and outputs. In the next example you have a trivial function that accepts and immediately returns `arg1`. However, `arg1` is vectorized on axis `0` on the input, and axis `1` on the output. As expected, this is problematic and Flax NNX will raise an error."
]
},
{
@@ -711,8 +724,15 @@
"id": "13f9aeea",
"metadata": {},
"source": [
- "## Axes Metadata\n",
- "Flax NNX Variables can have hold arbitrary metadata which can be added by simply passing them as keyword arguments to their constructor. This is often used to store `sharding` information which is used by the `nnx.spmd` APIs like `nnx.get_partition_spec` and `nnx.get_named_sharding`. However, its often important to keep this axes-related information in sync to what the actual state of the axes is when transforms are involved, for example, if we vectorize a variable on axis `1` we should remove the `sharding` information at position `1` when inside a `vmap` or `scan` to reflect the fact that the axes is temporarily removed. To achieve this Flax NNX transforms provide a non-standard `transform_metadata` dictionary argument, when the `nnx.PARTITION_NAME` key is present the `sharding` metadata will be updated as specified by `in_axes` and `out_axes`. Let's see an example of this in action:"
+ "## Axis metadata\n",
+ "\n",
+ "Flax NNX `Variable`s can hold arbitrary metadata, which can be added by simply passing it as keyword arguments to its constructor. This is often used to store `sharding` information, as used by the `nnx.spmd` APIs (like `nnx.get_partition_spec` and `nnx.get_named_sharding`).\n",
+ "\n",
+ "However, it is often important to keep this axes-related information in sync to what the actual state of the axes is when transforms are involved. For example, if you vectorize a variable on axis `1`, you should remove the `sharding` information at position `1` when inside a `vmap` or `scan` to reflect the fact that the axes are temporarily removed.\n",
+ "\n",
+ "To achieve this, Flax NNX transforms provide a non-standard `transform_metadata` dictionary argument. And when the `nnx.PARTITION_NAME` key is present, the `sharding` metadata will be updated as specified by `in_axes` and `out_axes`.\n",
+ "\n",
+ "Let's see an example of this in action:"
]
},
{
@@ -754,9 +774,9 @@
"id": "a23bda09",
"metadata": {},
"source": [
- "Here we added a `sharding` metadata to the `Param` variables and used `transform_metadata` to update the `sharding` metadata to reflect the axes changes, specifically we can see that first axis `b` was removed from the `sharding` metadata when inside `vmap` and then added back when outside `vmap`.\n",
+ "Here, you added a `sharding` metadata to the `nnx.Param` variables, and used `transform_metadata` to update the `sharding` metadata to reflect the axis changes. Specifically, you can see that the first axis `b` was removed from the `sharding` metadata when inside of `nnx.vmap`, and then added back when outside of `nnx.vmap`.\n",
"\n",
- "We can verify that this also works when Modules are created inside the transformation, the new `sharding` axes will be added to the Module's Variables outside the transformation, matching the axes of the transformed Variables."
+ "You can verify that this also works when `nnx.Module`s are created inside the transformation - the new `sharding` axes will be added to the `nnx.Module` `nnx.Variable`s outside the transformation, matching the axes of the transformed `nnx.Variable`s."
]
},
{
diff --git a/docs_nnx/guides/transforms.md b/docs_nnx/guides/transforms.md
index 1b185e5006..4b12869b60 100644
--- a/docs_nnx/guides/transforms.md
+++ b/docs_nnx/guides/transforms.md
@@ -10,16 +10,18 @@ jupytext:
jupytext_version: 1.13.8
---
-# Transforms
-JAX transformations in general operate on [Pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) of arrays
-and abide by value semantics, this presents a challenge for Flax NNX which represents Modules as regular Python objects
-that follow reference semantics. To address this, Flax NNX introduces its own set of transformations that extend JAX
-transformations to allow Modules and other Flax NNX objects to be passed in and out of transformations while preserving
+# Transformations
+
+In general, JAX transformations (transforms) operate on [pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) of `jax.Array`s
+and abide by value semantics. This presents a challenge for Flax NNX, which represents `nnx.Module`s as regular Python objects
+that follow reference semantics. To address this, Flax NNX introduced its own set of transforms that extend JAX
+transforms to allow `nnx.Module`s and other Flax NNX objects to be passed in and out of transforms while preserving
reference semantics.
-Flax NNX transformations should feel quite familar to those who have used JAX transformations before as they use the
-same APIs and behave like the JAX transformations when only working with Pytrees of arrays. However, when working with
+Flax NNX transforms should feel quite familiar if you have used JAX transforms before. They use the
+same APIs and behave like the JAX transforms when only working with pytrees of `jax.Array`s. However, when working with
Flax NNX objects, they allow Python's reference semantics to be preserved for these objects, this includes:
+
* Preserving shared references across multiple objects in the inputs and outputs of the transformation.
* Propagating any state changes made to the objects inside the transformation to the objects outside the transformation.
* Enforcing consistency of how objects are transformed when aliases are present across multiple inputs and outputs.
@@ -30,11 +32,12 @@ from jax import numpy as jnp, random
from flax import nnx
```
-Throughout this guide we will use `nnx.vmap` as a case study to demonstrate how Flax NNX transformations work but the principles
-outlined here extend to all transformations.
+Throughout this guide, `nnx.vmap` is used as a case study to demonstrate how Flax NNX transforms work. However, the principles
+outlined in this document extends to all transforms.
-## Basic Example
-To begin, let's look at a simple example of using `nnx.vmap` to extend an elementwise `vector_dot` function to work on
+## Basic example
+
+To begin, let's look at a simple example of using `nnx.vmap` to extend an element wise `vector_dot` function to work on
batched inputs. We will define a `Weights` Module with no methods to hold some parameters, these weights will be passed
as an input to the `vector_dot` function along with some data. Both the weights and data will be batched on axis `0` and we will use
`nnx.vmap` to apply `vector_dot` to each batch element, and the result will be a batched on axis `1`:
@@ -61,10 +64,10 @@ print(f'{y.shape = }')
nnx.display(weights)
```
-Notice that `in_axes` interacts naturally with the `Weights` Module, treating it as if it where a Pytree of arrays. Prefix patterns are also allowed, `in_axes=(0, 0)` would've also worked in this case.
+Notice that `in_axes` interacts naturally with the `Weights` Module, treating it as if it were a pytree of `jax.Array`s. Prefix patterns are also allowed, so `in_axes=(0, 0)` would have also worked in this case.
-Objects are also allowed as outputs of Flax NNX transformations, this can be useful to transform initializers. For example,
-we can define a `create_weights` function to create an single `Weights` Module and use `nnx.vmap` to create a stack of
+Objects are also allowed as outputs of Flax NNX transforms, which can be useful to transform initializers. For example,
+you can define a `create_weights` function to create an single `Weights` `nnx.Module`, and use `nnx.vmap` to create a stack of
`Weights` with the same shapes as before:
```{code-cell} ipython3
@@ -79,8 +82,9 @@ weights = nnx.vmap(create_weights)(seeds)
nnx.display(weights)
```
-## Transforming Methods
-Methods in Python are just functions that take the instance as the first argument, this means that you can decorate methods from `Module` and other Flax NNX subtypes. For example, we can refactor `Weights` from the previous example and decorate `__init__` with `vmap` to do the work of `create_weights`, and add a `__call__` method and decorate it with `vmap` to do the work of `vector_dot`:
+## Transforming methods
+
+Methods in Python are just functions that take the instance as the first argument, this means that you can decorate methods from `Module` and other Flax NNX subtypes. For example, we can refactor `Weights` from the previous example and decorate `__init__` with `vmap` to do the work of `create_weights`, and add a `__call__` method and decorate it with `@nnx.vmap` to do the work of `vector_dot`:
```{code-cell} ipython3
class WeightStack(nnx.Module):
@@ -104,13 +108,14 @@ print(f'{y.shape = }')
nnx.display(weights)
```
-Throughout the rest of the guide we will focus on transforming individual functions, however, note all examples can easily be written in this method style.
+The rest of the guide will focus on transforming individual functions. But do note that all examples can be written in this method style.
+++
## State propagation
-So far our functions have been stateless. However, the real power of Flax NNX transformations comes when we have stateful functions since one of their main features is to propagate state changes to preserve reference semantics. Let's update our example by adding
-a `count` attribute to `Weights` and incrementing it in the new `stateful_vector_dot` function.
+
+So far our functions have been stateless. However, the real power of Flax NNX transforms comes when you have stateful functions, because one of their main features is to propagate state changes to preserve reference semantics. Let's update the previous example by adding
+a `count` attribute to `Weights` and incrementing it in the new `stateful_vector_dot` function:
```{code-cell} ipython3
class Count(nnx.Variable): pass
@@ -139,14 +144,17 @@ y = nnx.vmap(stateful_vector_dot, in_axes=0, out_axes=1)(weights, x)
weights.count
```
-After running `stateful_vector_dot` once we verify that the `count` attribute was correctly updated. Because Weights is vectorized, `count` was initialized as an `arange(10)`, and all of its elements were incremented by 1 inside the transformation. The most important part is that updates were propagated to the original `Weights` object outside the transformation. Nice!
+After running `stateful_vector_dot` once, you verified that the `count` attribute was correctly updated. Because `Weights` was vectorized, `count` was initialized as an `arange(10)`, and all of its elements were incremented by `1` inside the transformation. The most important part is that updates were propagated to the original `Weights` object outside the transformation. Nice!
+++
### Graph updates propagation
-JAX transformations see inputs as pytrees of arrays, and Flax NNX see inputs pytrees of arrays and Python references, where references form a graph. Flax NNX's state propagation machinery can track arbitrary updates to the objects as long as they're local to the inputs (updates to globals inside transforms are not supported). This means that you can modify graph structure as needed, including updating existing attributes, adding/deleting attributes, swapping attributes, sharing (new) references between objects, sharing Variables between objects, etc. The sky is the limit!
-The following example demonstrates performing some arbitrary updates to the `Weights` object inside `nnx.vmap` and verifying that the updates are correctly propagated to the original `Weights` object outside the transformation.
+JAX transforms see inputs as pytrees of `jax.Array`s, and Flax NNX sees inputs as pytrees of `jax.Array`s and Python references, where references form a graph. Flax NNX's state propagation machinery can track arbitrary updates to the objects as long as they're local to the inputs (updates to globals inside transforms are not supported).
+
+This means that you can modify graph structure as needed, including updating existing attributes, adding/deleting attributes, swapping attributes, sharing (new) references between objects, sharing `nnx.Variable`s between objects, etc. Sky is the limit!
+
+The following example demonstrates performing some arbitrary updates to the `Weights` object inside `nnx.vmap`, and verifying that the updates are correctly propagated to the original `Weights` object outside the transformation:
```{code-cell} ipython3
class Count(nnx.Variable): pass
@@ -181,25 +189,26 @@ nnx.display(weights)
> With great power comes great responsibility.
>
\- Uncle Ben
-While this feature is very powerful, it must be used with care as it can clash with JAX's underlying assumptions for certain transformations. For example, `jit` expects the structure of the inputs to be stable in order to cache the compiled function, so changing the graph structure inside a `nnx.jit`-ed function cause continuous recompilations and performance degradation. On the other hand, `scan` only allows a fixed `carry` structure, so adding/removing substates declared as carry will cause an error.
+While this feature is very powerful, it must be used with care because it can clash with JAX's underlying assumptions for certain transforms. For example, `jit` expects the structure of the inputs to be stable in order to cache the compiled function, so changing the graph structure inside an `nnx.jit`-ed function causes continuous recompilations and performance degradation. On the other hand, `scan` only allows a fixed `carry` structure, so adding/removing sub-states declared as carry will cause an error.
+++
-## Transforming Substates (Lift Types)
+## Transforming sub-states (lift types)
-Certain JAX transformation allow the use of pytree prefixes to specify how different parts of the inputs/outputs should be transformed. Flax NNX supports pytree prefixes for pytree structures but currently it doesn't have the notion of a prefix for graph objects. Instead, Flax NNX introduces the concept of `Lift Types` which allow specifying how different substates of an object should be transformed. Different transformations support different Lift Types, here is the list of currently supported Lift Types for each transformation:
+Certain JAX transforms allow the use of pytree prefixes to specify how different parts of the inputs/outputs should be transformed. Flax NNX supports pytree prefixes for pytree structures but currently it doesn't have the notion of a prefix for graph objects. Instead, Flax NNX introduces the concept of “lift types” which allow specifying how different sub-states of an object should be transformed. Different transforms support different lift types, here is the list of currently supported FLax NNX lift types for each JAX transformation:
-| Lift Type | Transforms |
+| Lift type | JAX transforms |
|------------------|-----------------------------------------|
| `StateAxes` | `vmap`, `pmap`, `scan` |
-| `StateSharding` | `jit`, `shard_map` |
+| `StateSharding` | `jit`, `shard_map`* |
| `DiffState` | `grad`, `value_and_grad`, `custom_vjp` |
-> NOTE: `shard_map` is not yet implemented.
+> **Note:** * Flax NNX `shard_map` has not been implemented yet at the time of writing this version of the document.
+
+To specify how to vectorize different sub-states of an object in `nnx.vmap`, the Flax team created a `nnx.StateAxes`. `StateAxes` maps a set of sub-states via Flax NNX [Filters](https://flax-nnx.readthedocs.io/en/latest/guides/filters_guide.html) to their corresponding axes, and you can pass the `nnx.StateAxes` to `in_axes` and `out_axes` as if it/they were a pytree prefix.
-If we want to specify how to vectorize different substates of an object in `nnx.vmap`, we create a `StateAxes` which maps a set of substates via [Filters](https://flax-nnx.readthedocs.io/en/latest/guides/filters_guide.html) to their corresponding axes, and pass the `StateAxes` to `in_axes` and `out_axes` as if it were a pytree prefix. Let's use the previous `stateful_vector_dot` example and
-vectorize only the `Param` variables and broadcast the `count` variable so we only keep a single count for all the batch elements.
-To do this we will define a `StateAxes` with a filter that matches the `Param` variables and maps them to axis `0`, and all the `Count` variables to `None`, and pass this `StateAxes` to `in_axes` for the `Weights` object.
+Let's use the previous `stateful_vector_dot` example and vectorize only the `nnx.Param` variables and broadcast the `count` variable so we only keep a single count for all the batch elements.
+To do this we will define a `nnx.StateAxes` with a filter that matches the `nnx.Param` variables and maps them to axis `0`, and all the `Count` variables to `None`, and pass this `nnx.StateAxes` to `in_axes` for the `Weights` object.
```{code-cell} ipython3
class Weights(nnx.Module):
@@ -227,14 +236,17 @@ y = nnx.vmap(stateful_vector_dot, in_axes=(state_axes, 0), out_axes=1)(weights,
weights.count
```
-Here count is now a scalar since its not being vectorized. Also, note that `StateAxes` can only be used directly on Flax NNX objects, it cannot be used as a prefix for a pytree of objects.
+Here, `count` is now a scalar since it's not being vectorized. Also, note that `nnx.StateAxes` can only be used directly on Flax NNX objects, and it cannot be used as a prefix for a pytree of objects.
+++
-### Random State
-In Flax NNX random state is just regular state. This means that its stored inside Modules that need it and its treated as any other type of state. This is a simplification over Flax Linen where random state was handled by a separate mechanism. In practice Modules simply need to keep a reference to a `Rngs` object that is passed to them during initialization, and use it to generate a unique key for each random operation. For the purposes of this guide, this means that random state can be transformed like any other type of state but we also need be aware of how the state is laid out so we can transform it correctly.
+### Random state
-Let's suppose we want change things up a bit and apply the same weights to all elements in the batch but we want to add different random noise to each element. To do this we will add a `Rngs` attribute to `Weights`, created from a `seed` key argument passed during construction, this seed key must be `split` before hand so we can vectorize it succesfully. For pedagogical reasons, we will assign the seed key to a `noise` Stream and sample from it. To vectorize the RNG state we must configure `StateAxes` to map all `RngState` (base class for all variables in `Rngs`) to axis `0`, and `Param` and `Count` to `None`.
+In Flax NNX, a random state is just a regular state. This means that it is stored inside `nnx.Module`s that need it, and it is treated as any other type of state. This is a simplification over Flax Linen, where a random state was handled by a separate mechanism. In practice `nnx.Module`s simply need to keep a reference to a `Rngs` object that is passed to them during initialization, and use it to generate a unique key for each random operation. For the purposes of this guide, this means that random state can be transformed like any other type of state but we also need be aware of how the state is laid out so we can transform it correctly.
+
+Suppose you want to change things up a bit and apply the same weights to all elements in the batch. But you also want to add different random noise to each element.
+
+To do this, you will add an `Rngs` attribute to `Weights`, created from a `seed` key argument passed during construction. This seed key must be `split` beforehand, so that you can vectorize it successfully. For pedagogical reasons, you will assign the seed key to a `noise` “stream” and sample from it. To vectorize the PRNG state, you must configure `nnx.StateAxes` to map all `RngState`s (a base class for all variables in `Rngs`) to axis `0`, and `nnx.Param` and `Count` to `None`.
```{code-cell} ipython3
class Weights(nnx.Module):
@@ -268,7 +280,7 @@ nnx.display(weights)
Because `Rngs`'s state is updated in place and automatically propagated by `nnx.vmap`, we will get a different result every time that `noisy_vector_dot` is called.
-In the example above we manually split the random state during construction, this is fine as it makes the intention clear but it also doesn't let us use `Rngs` outside of `vmap` since its state is always split. To solve this we pass an unplit seed and use the `nnx.split_rngs` decorator before `vmap` to split the `RngState` right before each call to the function and then "lower" it back so its usable.
+In the example above, you manually split the random state during construction. This is fine, as it makes the intention clear, but it also doesn't let you use `Rngs` outside of `nnx.vmap` because its state is always split. To solve this, you can pass an unsplit seed and use the `nnx.split_rngs` decorator before `nnx.vmap` to split the `RngState` right before each call to the function, and then "lower" it back so that it becomes usable.
```{code-cell} ipython3
weights = Weights(
@@ -298,7 +310,8 @@ nnx.display(weights)
```
## Consistent aliasing
-The main issue with allowing for reference semantics in transforms that references can be shared across inputs and outputs, this can be problematic if not taken care of because it would lead to ill-defined or inconsistent behavior. In the example below we have a single `Weights` Module `m` whose reference appears in multiple places in `arg1` and `arg2`. The problem is that we also specified we wanted to vectorize `arg1` in axis `0` and `arg2` in axis `1`, this is fine in JAX due to referential transparency of pytrees but its problematic in Flax NNX since we are trying to vectorize `m` in two different ways. NNX will enforce consistency by raising an error.
+
+The main issue with allowing for reference semantics in transforms is that references can be shared across inputs and outputs. This can be problematic if it is not taken care of because it would lead to ill-defined or inconsistent behavior. In the example below you have a single `Weights` `nnx.Module` - `m` ` whose reference appears in multiple places in `arg1` and `arg2`. The problem here is that you also specify that you want to vectorize `arg1` in axis `0` and `arg2` in axis `1`. This would be fine in JAX because of referential transparency of pytrees. But this would be problematic in Flax NNX because you are trying to vectorize `m` in two different ways. Flax NNX will enforce consistency by raising an error.
```{code-cell} ipython3
class Weights(nnx.Module):
@@ -319,7 +332,7 @@ except ValueError as e:
print(e)
```
-Inconsistent aliasing can also happen between inputs and outputs. In the next example we have a trivial function that accepts and immediately return `arg1`, however `arg1` is vectorized on axis `0` on the input and axis `1` on the output. As expected, this is problematic and Flax NNX will raise an error.
+Inconsistent aliasing can also happen between inputs and outputs. In the next example you have a trivial function that accepts and immediately returns `arg1`. However, `arg1` is vectorized on axis `0` on the input, and axis `1` on the output. As expected, this is problematic and Flax NNX will raise an error.
```{code-cell} ipython3
@nnx.vmap(in_axes=0, out_axes=1)
@@ -332,8 +345,15 @@ except ValueError as e:
print(e)
```
-## Axes Metadata
-Flax NNX Variables can have hold arbitrary metadata which can be added by simply passing them as keyword arguments to their constructor. This is often used to store `sharding` information which is used by the `nnx.spmd` APIs like `nnx.get_partition_spec` and `nnx.get_named_sharding`. However, its often important to keep this axes-related information in sync to what the actual state of the axes is when transforms are involved, for example, if we vectorize a variable on axis `1` we should remove the `sharding` information at position `1` when inside a `vmap` or `scan` to reflect the fact that the axes is temporarily removed. To achieve this Flax NNX transforms provide a non-standard `transform_metadata` dictionary argument, when the `nnx.PARTITION_NAME` key is present the `sharding` metadata will be updated as specified by `in_axes` and `out_axes`. Let's see an example of this in action:
+## Axis metadata
+
+Flax NNX `Variable`s can hold arbitrary metadata, which can be added by simply passing it as keyword arguments to its constructor. This is often used to store `sharding` information, as used by the `nnx.spmd` APIs (like `nnx.get_partition_spec` and `nnx.get_named_sharding`).
+
+However, it is often important to keep this axes-related information in sync to what the actual state of the axes is when transforms are involved. For example, if you vectorize a variable on axis `1`, you should remove the `sharding` information at position `1` when inside a `vmap` or `scan` to reflect the fact that the axes are temporarily removed.
+
+To achieve this, Flax NNX transforms provide a non-standard `transform_metadata` dictionary argument. And when the `nnx.PARTITION_NAME` key is present, the `sharding` metadata will be updated as specified by `in_axes` and `out_axes`.
+
+Let's see an example of this in action:
```{code-cell} ipython3
class Weights(nnx.Module):
@@ -352,9 +372,9 @@ print(f'Outter {m.param.shape = }')
print(f'Outter {m.param.sharding = }')
```
-Here we added a `sharding` metadata to the `Param` variables and used `transform_metadata` to update the `sharding` metadata to reflect the axes changes, specifically we can see that first axis `b` was removed from the `sharding` metadata when inside `vmap` and then added back when outside `vmap`.
+Here, you added a `sharding` metadata to the `nnx.Param` variables, and used `transform_metadata` to update the `sharding` metadata to reflect the axis changes. Specifically, you can see that the first axis `b` was removed from the `sharding` metadata when inside of `nnx.vmap`, and then added back when outside of `nnx.vmap`.
-We can verify that this also works when Modules are created inside the transformation, the new `sharding` axes will be added to the Module's Variables outside the transformation, matching the axes of the transformed Variables.
+You can verify that this also works when `nnx.Module`s are created inside the transformation - the new `sharding` axes will be added to the `nnx.Module` `nnx.Variable`s outside the transformation, matching the axes of the transformed `nnx.Variable`s.
```{code-cell} ipython3
@nnx.vmap(out_axes=1, axis_size=4, transform_metadata={nnx.PARTITION_NAME: 'b'})