diff --git a/docs/nnx/nnx_basics.ipynb b/docs/nnx/nnx_basics.ipynb
index 2cc10b0fdf..7e240ae7e1 100644
--- a/docs/nnx/nnx_basics.ipynb
+++ b/docs/nnx/nnx_basics.ipynb
@@ -105,7 +105,19 @@
{
"data": {
"text/html": [
- "
(Loading...)
"
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "
"
],
"text/plain": [
""
@@ -199,7 +211,19 @@
{
"data": {
"text/html": [
- "(Loading...)
"
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "
"
],
"text/plain": [
""
@@ -258,7 +282,19 @@
{
"data": {
"text/html": [
- "(Loading...)
"
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "
"
],
"text/plain": [
""
@@ -365,22 +401,26 @@
" using the optimizer alone.\n",
"\n",
"#### Scan over layers\n",
- "Next lets take a look at a different example using `nnx.vmap` to create an\n",
- "`MLP` stack and `nnx.scan` to iteratively apply each layer in the stack to the\n",
- "input (scan over layers). \n",
+ "Next let's take a look at a different example, which uses\n",
+ "[nnx.vmap](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.vmap)\n",
+ "to create a stack of multiple MLP layers and\n",
+ "[nnx.scan](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.scan)\n",
+ "to iteratively apply each layer of the stack to the input.\n",
"\n",
"Notice the following:\n",
- "1. The `create_model` function creates a (single) `MLP` object that is lifted by\n",
- " `nnx.vmap` to have an additional dimension of size `axis_size`.\n",
- "2. The `forward` function indexes the `MLP` object's state to get a different set of\n",
- " parameters at each step.\n",
- "3. `nnx.scan` automatically propagates the state updates for the `BatchNorm` and \n",
- "`Dropout` layers from within `forward` to the `model` reference outside."
+ "1. The `create_model` function takes in a key and returns an `MLP` object. Since we create 5 keys\n",
+ " and use `nnx.vmap` over `create_model`, a stack of 5 `MLP` objects is created.\n",
+ "2. We use `nnx.scan` to iteratively apply each `MLP` in the stack to the input `x`.\n",
+ "3. The `nnx.scan` API (consciously) deviates from `jax.lax.scan` and instead mimics `vmap`, which is\n",
+ " more expressive. `nnx.scan` allows specifying multiple inputs, the scan axes of each input/output,\n",
+ " and the position of the carry.\n",
+ "4. State updates for the `BatchNorm` and `Dropout` layers are automatically propagated \n",
+ " by `nnx.scan`."
]
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 15,
"metadata": {},
"outputs": [
{
@@ -393,7 +433,19 @@
{
"data": {
"text/html": [
- "(Loading...)
"
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "
"
],
"text/plain": [
""
@@ -404,15 +456,14 @@
}
],
"source": [
- "from functools import partial\n",
- "\n",
- "@partial(nnx.vmap, axis_size=5)\n",
- "def create_model(rngs: nnx.Rngs):\n",
- " return MLP(10, 32, 10, rngs=rngs)\n",
+ "@nnx.vmap(in_axes=0, out_axes=0)\n",
+ "def create_model(key: jax.Array):\n",
+ " return MLP(10, 32, 10, rngs=nnx.Rngs(key))\n",
"\n",
- "model = create_model(nnx.Rngs(0))\n",
+ "keys = jax.random.split(jax.random.key(0), 5)\n",
+ "model = create_model(keys)\n",
"\n",
- "@nnx.scan\n",
+ "@nnx.scan(in_axes=(nnx.Carry, 0), out_axes=0)\n",
"def forward(x, model: MLP):\n",
" x = model(x)\n",
" return x, None\n",
@@ -457,7 +508,19 @@
{
"data": {
"text/html": [
- "(Loading...)
"
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "
"
],
"text/plain": [
""
@@ -506,7 +569,19 @@
{
"data": {
"text/html": [
- "(Loading...)
"
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "
"
],
"text/plain": [
""
@@ -518,7 +593,7 @@
{
"data": {
"text/html": [
- "(Loading...)
"
+ "
"
],
"text/plain": [
""
@@ -626,7 +701,31 @@
{
"data": {
"text/html": [
- "(Loading...)
"
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "
"
],
"text/plain": [
""
@@ -638,7 +737,7 @@
{
"data": {
"text/html": [
- "(Loading...)
"
+ "
"
],
"text/plain": [
""
diff --git a/docs/nnx/nnx_basics.md b/docs/nnx/nnx_basics.md
index 70b3ff7540..f90468c3a9 100644
--- a/docs/nnx/nnx_basics.md
+++ b/docs/nnx/nnx_basics.md
@@ -219,28 +219,31 @@ Theres a couple of things happening in this example that are worth mentioning:
using the optimizer alone.
#### Scan over layers
-Next lets take a look at a different example using `nnx.vmap` to create an
-`MLP` stack and `nnx.scan` to iteratively apply each layer in the stack to the
-input (scan over layers).
+Next let's take a look at a different example, which uses
+[nnx.vmap](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.vmap)
+to create a stack of multiple MLP layers and
+[nnx.scan](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.scan)
+to iteratively apply each layer of the stack to the input.
Notice the following:
-1. The `create_model` function creates a (single) `MLP` object that is lifted by
- `nnx.vmap` to have an additional dimension of size `axis_size`.
-2. The `forward` function indexes the `MLP` object's state to get a different set of
- parameters at each step.
-3. `nnx.scan` automatically propagates the state updates for the `BatchNorm` and
-`Dropout` layers from within `forward` to the `model` reference outside.
+1. The `create_model` function takes in a key and returns an `MLP` object. Since we create 5 keys
+ and use `nnx.vmap` over `create_model`, a stack of 5 `MLP` objects is created.
+2. We use `nnx.scan` to iteratively apply each `MLP` in the stack to the input `x`.
+3. The `nnx.scan` API (consciously) deviates from `jax.lax.scan` and instead mimics `vmap`, which is
+ more expressive. `nnx.scan` allows specifying multiple inputs, the scan axes of each input/output,
+ and the position of the carry.
+4. State updates for the `BatchNorm` and `Dropout` layers are automatically propagated
+ by `nnx.scan`.
```{code-cell} ipython3
-from functools import partial
+@nnx.vmap(in_axes=0, out_axes=0)
+def create_model(key: jax.Array):
+ return MLP(10, 32, 10, rngs=nnx.Rngs(key))
-@partial(nnx.vmap, axis_size=5)
-def create_model(rngs: nnx.Rngs):
- return MLP(10, 32, 10, rngs=rngs)
+keys = jax.random.split(jax.random.key(0), 5)
+model = create_model(keys)
-model = create_model(nnx.Rngs(0))
-
-@nnx.scan
+@nnx.scan(in_axes=(nnx.Carry, 0), out_axes=0)
def forward(x, model: MLP):
x = model(x)
return x, None