diff --git a/docs/nnx/nnx_basics.ipynb b/docs/nnx/nnx_basics.ipynb index 2cc10b0fdf..7e240ae7e1 100644 --- a/docs/nnx/nnx_basics.ipynb +++ b/docs/nnx/nnx_basics.ipynb @@ -105,7 +105,19 @@ { "data": { "text/html": [ - "
(Loading...)
" + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" ], "text/plain": [ "" @@ -199,7 +211,19 @@ { "data": { "text/html": [ - "
(Loading...)
" + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" ], "text/plain": [ "" @@ -258,7 +282,19 @@ { "data": { "text/html": [ - "
(Loading...)
" + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" ], "text/plain": [ "" @@ -365,22 +401,26 @@ " using the optimizer alone.\n", "\n", "#### Scan over layers\n", - "Next lets take a look at a different example using `nnx.vmap` to create an\n", - "`MLP` stack and `nnx.scan` to iteratively apply each layer in the stack to the\n", - "input (scan over layers). \n", + "Next lets take a look at a different example, which uses\n", + "[nnx.vmap](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.vmap)\n", + "to create a stack of multiple MLP layers and\n", + "[nnx.scan](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.scan)\n", + "to iteratively apply each layer of the stack to the input.\n", "\n", "Notice the following:\n", - "1. The `create_model` function creates a (single) `MLP` object that is lifted by\n", - " `nnx.vmap` to have an additional dimension of size `axis_size`.\n", - "2. The `forward` function indexes the `MLP` object's state to get a different set of\n", - " parameters at each step.\n", - "3. `nnx.scan` automatically propagates the state updates for the `BatchNorm` and \n", - "`Dropout` layers from within `forward` to the `model` reference outside." + "1. The `create_model` function takes in a key and returns an `MLP` object, since we create 5 keys\n", + " and use `nnx.vmap` over `create_model` a stack of 5 `MLP` objects is created.\n", + "2. We use `nnx.scan` to iteratively apply each `MLP` in the stack to the input `x`.\n", + "3. The `nnx.scan` API (consciously) deviates from `jax.lax.scan` and instead mimicks `vmap` which is\n", + " more expressive. `nnx.scan` allows specifying multiple inputs, the scan axes of each input/output,\n", + " and the position of the carry.\n", + "4. State updates for the `BatchNorm` and `Dropout` layers are automatically propagated \n", + " by `nnx.scan`." 
] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -393,7 +433,19 @@ { "data": { "text/html": [ - "
(Loading...)
" + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" ], "text/plain": [ "" @@ -404,15 +456,14 @@ } ], "source": [ - "from functools import partial\n", - "\n", - "@partial(nnx.vmap, axis_size=5)\n", - "def create_model(rngs: nnx.Rngs):\n", - " return MLP(10, 32, 10, rngs=rngs)\n", + "@nnx.vmap(in_axes=0, out_axes=0)\n", + "def create_model(key: jax.Array):\n", + " return MLP(10, 32, 10, rngs=nnx.Rngs(key))\n", "\n", - "model = create_model(nnx.Rngs(0))\n", + "keys = jax.random.split(jax.random.key(0), 5)\n", + "model = create_model(keys)\n", "\n", - "@nnx.scan\n", + "@nnx.scan(in_axes=(nnx.Carry, 0), out_axes=0)\n", "def forward(x, model: MLP):\n", " x = model(x)\n", " return x, None\n", @@ -457,7 +508,19 @@ { "data": { "text/html": [ - "
(Loading...)
" + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" ], "text/plain": [ "" @@ -506,7 +569,19 @@ { "data": { "text/html": [ - "
(Loading...)
" + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" ], "text/plain": [ "" @@ -518,7 +593,7 @@ { "data": { "text/html": [ - "
(Loading...)
" + "
" ], "text/plain": [ "" @@ -626,7 +701,31 @@ { "data": { "text/html": [ - "
(Loading...)
" + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" ], "text/plain": [ "" @@ -638,7 +737,7 @@ { "data": { "text/html": [ - "
(Loading...)
" + "
" ], "text/plain": [ "" diff --git a/docs/nnx/nnx_basics.md b/docs/nnx/nnx_basics.md index 70b3ff7540..f90468c3a9 100644 --- a/docs/nnx/nnx_basics.md +++ b/docs/nnx/nnx_basics.md @@ -219,28 +219,31 @@ Theres a couple of things happening in this example that are worth mentioning: using the optimizer alone. #### Scan over layers -Next lets take a look at a different example using `nnx.vmap` to create an -`MLP` stack and `nnx.scan` to iteratively apply each layer in the stack to the -input (scan over layers). +Next lets take a look at a different example, which uses +[nnx.vmap](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.vmap) +to create a stack of multiple MLP layers and +[nnx.scan](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.scan) +to iteratively apply each layer of the stack to the input. Notice the following: -1. The `create_model` function creates a (single) `MLP` object that is lifted by - `nnx.vmap` to have an additional dimension of size `axis_size`. -2. The `forward` function indexes the `MLP` object's state to get a different set of - parameters at each step. -3. `nnx.scan` automatically propagates the state updates for the `BatchNorm` and -`Dropout` layers from within `forward` to the `model` reference outside. +1. The `create_model` function takes in a key and returns an `MLP` object, since we create 5 keys + and use `nnx.vmap` over `create_model` a stack of 5 `MLP` objects is created. +2. We use `nnx.scan` to iteratively apply each `MLP` in the stack to the input `x`. +3. The `nnx.scan` API (consciously) deviates from `jax.lax.scan` and instead mimicks `vmap` which is + more expressive. `nnx.scan` allows specifying multiple inputs, the scan axes of each input/output, + and the position of the carry. +4. State updates for the `BatchNorm` and `Dropout` layers are automatically propagated + by `nnx.scan`. 
```{code-cell} ipython3 -from functools import partial +@nnx.vmap(in_axes=0, out_axes=0) +def create_model(key: jax.Array): + return MLP(10, 32, 10, rngs=nnx.Rngs(key)) -@partial(nnx.vmap, axis_size=5) -def create_model(rngs: nnx.Rngs): - return MLP(10, 32, 10, rngs=rngs) +keys = jax.random.split(jax.random.key(0), 5) +model = create_model(keys) -model = create_model(nnx.Rngs(0)) - -@nnx.scan +@nnx.scan(in_axes=(nnx.Carry, 0), out_axes=0) def forward(x, model: MLP): x = model(x) return x, None
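To see why the "scan over layers" pattern in this PR works, it can help to strip away the Flax machinery. The sketch below is a minimal NumPy analogue, not the PR's Flax NNX code: stacked parameters with a leading layer axis stand in for what `nnx.vmap` over `create_model` produces, and a plain Python loop threading the activation through one layer per step stands in for what `nnx.scan` does with the carry. The single tanh layer per step is a simplification of the `MLP` used in the docs; all names here are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
num_layers, din, dout = 5, 10, 10  # mirrors the 5-layer MLP(10, 32, 10) stack, simplified

# Stacked parameters: the leading axis indexes the layer, like the vmapped model stack.
ws = rng.normal(size=(num_layers, din, dout)) * 0.1
bs = np.zeros((num_layers, dout))

def forward(x, w, b):
    # One scan step: apply a single layer to the carried activation.
    return np.tanh(x @ w + b)

x = np.ones((3, din))  # a batch of 3 inputs is the initial carry
for i in range(num_layers):  # the iteration nnx.scan performs over the stacked axis
    x = forward(x, ws[i], bs[i])

print(x.shape)  # (3, 10): same batch shape, transformed by 5 stacked layers
```

The difference in the real code is that `nnx.scan` compiles this loop into a single `jax.lax.scan` step and, per point 4 above, also carries the mutable `BatchNorm`/`Dropout` state across steps automatically.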