diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 776f5c3d82..93f256016a 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -23,7 +23,7 @@ repos:
hooks:
- id: check-toml
- id: trailing-whitespace
- exclude: ^docs*/.*\.md$
+ exclude: ^docs.*\.md$
- repo: https://github.com/kynan/nbstripout
rev: 0.6.1
hooks:
diff --git a/docs_nnx/guides/checkpointing.ipynb b/docs_nnx/guides/checkpointing.ipynb
new file mode 100644
index 0000000000..100dcffc34
--- /dev/null
+++ b/docs_nnx/guides/checkpointing.ipynb
@@ -0,0 +1,382 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Save and load checkpoints\n",
+ "\n",
+ "Flax does not actively maintain a library for saving and loading model checkpoints to disk. We recommend you to use external libraries like [Orbax](https://orbax.readthedocs.io/en/latest/index.html) to do it.\n",
+ "\n",
+ "This guide will cover a few user-side scenarios of saving and loading Flax NNX model checkpoints using Orbax. The Orbax API we use here are for demonstration purposes only - please check out the [Orbax website](https://orbax.readthedocs.io/en/latest/index.html) for the most up-to-date recommended API.\n",
+ "\n",
+ "> Note: Flax's legacy `flax.training.checkpoints` package is deprecated. Its documentation resides in the [legacy documentation site](https://flax-linen.readthedocs.io/en/latest/guides/training_techniques/use_checkpointing.html)."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Setup\n",
+ "\n",
+ "We first set up a checkpoint directory and an example NNX model."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from flax import nnx\n",
+ "import orbax.checkpoint as ocp\n",
+ "import jax\n",
+ "from jax import numpy as jnp\n",
+ "import numpy as np\n",
+ "\n",
+ "ckpt_dir = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class TwoLayerMLP(nnx.Module):\n",
+ " def __init__(self, dim, rngs: nnx.Rngs):\n",
+ " self.linear1 = nnx.Linear(dim, dim, rngs=rngs, use_bias=False)\n",
+ " self.linear2 = nnx.Linear(dim, dim, rngs=rngs, use_bias=False)\n",
+ "\n",
+ " def __call__(self, x):\n",
+ " x = self.linear1(x)\n",
+ " return self.linear2(x)\n",
+ "\n",
+ "# Create this model and show we can run it\n",
+ "model = TwoLayerMLP(4, rngs=nnx.Rngs(0))\n",
+ "x = jax.random.normal(jax.random.key(42), (3, 4))\n",
+ "assert model(x).shape == (3, 4)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Save checkpoints\n",
+ "\n",
+ "JAX checkpointing libraries save [Pytrees](https://jax.readthedocs.io/en/latest/pytrees.html), which is a pure, possibly nested container of JAX arrays (or, \"tensors\" as some other frameworks would put it). In the context of machine learning, the checkpoint is usually a pytree of model parameters and other data such as optimizer states.\n",
+ "\n",
+ "In Flax NNX, you can obtain such a pytree from an [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) by calling [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split) and pick up the returned [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "_, state = nnx.split(model)\n",
+ "nnx.display(state)\n",
+ "\n",
+ "checkpointer = ocp.StandardCheckpointer()\n",
+ "checkpointer.save(ckpt_dir / 'state', state)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Restore checkpoints\n",
+ "\n",
+ "Note that you saved the checkpoint as a class of `nnx.State`, which is also nested with `nnx.VariableState` and `nnx.Param` Python classes. At restoration time, you need to have these classes ready in your runtime, and tell the checkpointing library to restore your pytree back to that structure.\n",
+ "\n",
+ "You can achive this by creating an abstract NNX model (without allocating any memory for arrays) and show its abstract variable state to the checkpointing library.\n",
+ "\n",
+ "Once you have the state, use `nnx.merge` to obtain your NNX model and use it as usual."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The abstract NNX state (all leaves are abstract arrays):\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "NNX State restored: \n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/Users/ivyzheng/envs/flax-head/lib/python3.11/site-packages/orbax/checkpoint/type_handlers.py:1439: UserWarning: Couldn't find sharding info under RestoreArgs. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file instead of directly from RestoreArgs. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.\n",
+ " warnings.warn(\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# Restore the checkpoint back to its `nnx.State` structure - need an abstract reference\n",
+ "abstract_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(0)))\n",
+ "graphdef, abstract_state = nnx.split(abstract_model)\n",
+ "print('The abstract NNX state (all leaves are abstract arrays):')\n",
+ "nnx.display(abstract_state)\n",
+ "\n",
+ "state_restored = checkpointer.restore(ckpt_dir / 'state', abstract_state)\n",
+ "jax.tree.map(np.testing.assert_array_equal, state, state_restored)\n",
+ "print('NNX State restored: ')\n",
+ "nnx.display(state_restored)\n",
+ "\n",
+ "# The model is good to use as ever!\n",
+ "model = nnx.merge(graphdef, state_restored)\n",
+ "assert model(x).shape == (3, 4)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Save and restore as pure dictionaries\n",
+ "\n",
+ "You might prefer to work with Python built-in container types when interacting with checkpoint libraries. In that case, you can use the `nnx.State.to_pure_dict` and `nnx.State.replace_by_pure_dict` API to convert an `nnx.State` to and from pure nested dictionaries."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Save as pure dict\n",
+ "pure_dict_state = state.to_pure_dict()\n",
+ "nnx.display(pure_dict_state)\n",
+ "checkpointer.save(ckpt_dir / 'pure_dict', pure_dict_state)\n",
+ "\n",
+ "# Restore as pure dict\n",
+ "restored_pure_dict = checkpointer.restore(ckpt_dir / 'pure_dict')\n",
+ "abstract_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(0)))\n",
+ "graphdef, abstract_state = nnx.split(abstract_model)\n",
+ "abstract_state.replace_by_pure_dict(restored_pure_dict)\n",
+ "model = nnx.merge(graphdef, abstract_state)\n",
+ "assert model(x).shape == (3, 4) # the model still works!"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Restore when checkpoint structures differ\n",
+ "\n",
+ "The ability to load a checkpoint as a pure nested dictionary can come in handy when you want to load some outdated checkpoints that no longer matches with your current model code. Check out this simple example below.\n",
+ "\n",
+ "This pattern also works if you saved the checkpoint as `nnx.State` instead of pure dictionary. Check out the [checkpoint surgery guide](https://flax.readthedocs.io/en/latest/guides/surgery.html#checkpoint-surgery) for a code example. The only difference is you need to re-process your raw diciontary a bit before calling `restore_from_pure_dict`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "class ModifiedTwoLayerMLP(nnx.Module):\n",
+ " \"\"\"A modified version of TwoLayerMLP, which requires bias arrays.\"\"\"\n",
+ " def __init__(self, dim, rngs: nnx.Rngs):\n",
+ " self.linear1 = nnx.Linear(dim, dim, rngs=rngs, use_bias=True) # We need bias now!\n",
+ " self.linear2 = nnx.Linear(dim, dim, rngs=rngs, use_bias=True) # We need bias now!\n",
+ "\n",
+ " def __call__(self, x):\n",
+ " x = self.linear1(x)\n",
+ " return self.linear2(x)\n",
+ "\n",
+ "# Accomodate your old checkpoint to the new code...\n",
+ "restored_pure_dict = checkpointer.restore(ckpt_dir / 'pure_dict')\n",
+ "restored_pure_dict['linear1']['bias'] = jnp.zeros((4,))\n",
+ "restored_pure_dict['linear2']['bias'] = jnp.zeros((4,))\n",
+ "\n",
+ "# Same restore code as above\n",
+ "abstract_model = nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0)))\n",
+ "graphdef, abstract_state = nnx.split(abstract_model)\n",
+ "abstract_state.replace_by_pure_dict(restored_pure_dict)\n",
+ "model = nnx.merge(graphdef, abstract_state)\n",
+ "assert model(x).shape == (3, 4) # the new model works!\n",
+ "\n",
+ "nnx.display(model.linear1)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Multi-process checkpointing\n",
+ "\n",
+ "In multi-host/multi-process environment, you would want to restore your checkpoint as sharded across multiple devices. Check out [this section of the Flax scale-up guide](https://flax.readthedocs.io/en/latest/guides/flax_gspmd.html#load-sharded-model-from-a-checkpoint) to learn how to derive a sharding tree and use it to load your checkpoint."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Other checkpointing features\n",
+ "\n",
+ "This guide only use the simplest Orbax save/load API ([`orbax.checkpoint.StandardCheckpointer`](https://orbax.readthedocs.io/en/latest/api_reference/checkpoint.checkpointers.html#standardcheckpointer)) to showcase all the tricks on the Flax modeling side. Feel free to use other tools or libraries as you see fit.\n",
+ "\n",
+ "Check out [the Orbax website](https://orbax.readthedocs.io/en/latest/index.html) for other commonly used features, such as:\n",
+ "\n",
+ "* [CheckpointManager](https://orbax.readthedocs.io/en/latest/guides/checkpoint/api_refactor.html) to track checkpoints from different steps\n",
+ "\n",
+ "* [Asynchronous checkpointing](https://orbax.readthedocs.io/en/latest/guides/checkpoint/async_checkpointing.html)\n",
+ "\n",
+ "* [Transformations](https://orbax.readthedocs.io/en/latest/guides/checkpoint/transformations.html): a way to modify pytree structure during loading time (not after loading time, as in this guide)"
+ ]
+ }
+ ],
+ "metadata": {
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.9"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/docs_nnx/guides/checkpointing.md b/docs_nnx/guides/checkpointing.md
new file mode 100644
index 0000000000..1ae5978b22
--- /dev/null
+++ b/docs_nnx/guides/checkpointing.md
@@ -0,0 +1,201 @@
+# Save and load checkpoints
+
+Flax does not actively maintain a library for saving and loading model checkpoints to disk. We recommend you to use external libraries like [Orbax](https://orbax.readthedocs.io/en/latest/index.html) to do it.
+
+This guide will cover a few user-side scenarios of saving and loading Flax NNX model checkpoints using Orbax. The Orbax API we use here are for demonstration purposes only - please check out the [Orbax website](https://orbax.readthedocs.io/en/latest/index.html) for the most up-to-date recommended API.
+
+> Note: Flax's legacy `flax.training.checkpoints` package is deprecated. Its documentation resides in the [legacy documentation site](https://flax-linen.readthedocs.io/en/latest/guides/training_techniques/use_checkpointing.html).
+
+### Setup
+
+We first set up a checkpoint directory and an example NNX model.
+
+
+```python
+from flax import nnx
+import orbax.checkpoint as ocp
+import jax
+from jax import numpy as jnp
+import numpy as np
+
+ckpt_dir = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/')
+```
+
+
+```python
+class TwoLayerMLP(nnx.Module):
+ def __init__(self, dim, rngs: nnx.Rngs):
+ self.linear1 = nnx.Linear(dim, dim, rngs=rngs, use_bias=False)
+ self.linear2 = nnx.Linear(dim, dim, rngs=rngs, use_bias=False)
+
+ def __call__(self, x):
+ x = self.linear1(x)
+ return self.linear2(x)
+
+# Create this model and show we can run it
+model = TwoLayerMLP(4, rngs=nnx.Rngs(0))
+x = jax.random.normal(jax.random.key(42), (3, 4))
+assert model(x).shape == (3, 4)
+```
+
+## Save checkpoints
+
+JAX checkpointing libraries save [Pytrees](https://jax.readthedocs.io/en/latest/pytrees.html), which is a pure, possibly nested container of JAX arrays (or, "tensors" as some other frameworks would put it). In the context of machine learning, the checkpoint is usually a pytree of model parameters and other data such as optimizer states.
+
+In Flax NNX, you can obtain such a pytree from an [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) by calling [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split) and pick up the returned [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State).
+
+
+```python
+_, state = nnx.split(model)
+nnx.display(state)
+
+checkpointer = ocp.StandardCheckpointer()
+checkpointer.save(ckpt_dir / 'state', state)
+```
+
+
+
+
+
+
+
+
+
+## Restore checkpoints
+
+Note that you saved the checkpoint as a class of `nnx.State`, which is also nested with `nnx.VariableState` and `nnx.Param` Python classes. At restoration time, you need to have these classes ready in your runtime, and tell the checkpointing library to restore your pytree back to that structure.
+
+You can achive this by creating an abstract NNX model (without allocating any memory for arrays) and show its abstract variable state to the checkpointing library.
+
+Once you have the state, use `nnx.merge` to obtain your NNX model and use it as usual.
+
+
+```python
+# Restore the checkpoint back to its `nnx.State` structure - need an abstract reference
+abstract_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(0)))
+graphdef, abstract_state = nnx.split(abstract_model)
+print('The abstract NNX state (all leaves are abstract arrays):')
+nnx.display(abstract_state)
+
+state_restored = checkpointer.restore(ckpt_dir / 'state', abstract_state)
+jax.tree.map(np.testing.assert_array_equal, state, state_restored)
+print('NNX State restored: ')
+nnx.display(state_restored)
+
+# The model is good to use as ever!
+model = nnx.merge(graphdef, state_restored)
+assert model(x).shape == (3, 4)
+```
+
+ The abstract NNX state (all leaves are abstract arrays):
+
+
+
+
+
+
+ NNX State restored:
+
+
+ /Users/ivyzheng/envs/flax-head/lib/python3.11/site-packages/orbax/checkpoint/type_handlers.py:1439: UserWarning: Couldn't find sharding info under RestoreArgs. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file instead of directly from RestoreArgs. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.
+ warnings.warn(
+
+
+
+
+
+
+
+
+
+
+## Save and restore as pure dictionaries
+
+You might prefer to work with Python built-in container types when interacting with checkpoint libraries. In that case, you can use the `nnx.State.to_pure_dict` and `nnx.State.replace_by_pure_dict` API to convert an `nnx.State` to and from pure nested dictionaries.
+
+
+```python
+# Save as pure dict
+pure_dict_state = state.to_pure_dict()
+nnx.display(pure_dict_state)
+checkpointer.save(ckpt_dir / 'pure_dict', pure_dict_state)
+
+# Restore as pure dict
+restored_pure_dict = checkpointer.restore(ckpt_dir / 'pure_dict')
+abstract_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(0)))
+graphdef, abstract_state = nnx.split(abstract_model)
+abstract_state.replace_by_pure_dict(restored_pure_dict)
+model = nnx.merge(graphdef, abstract_state)
+assert model(x).shape == (3, 4) # the model still works!
+```
+
+
+
+
+
+
+
+
+
+ WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.
+
+
+## Restore when checkpoint structures differ
+
+The ability to load a checkpoint as a pure nested dictionary can come in handy when you want to load some outdated checkpoints that no longer matches with your current model code. Check out this simple example below.
+
+This pattern also works if you saved the checkpoint as `nnx.State` instead of pure dictionary. Check out the [checkpoint surgery guide](https://flax.readthedocs.io/en/latest/guides/surgery.html#checkpoint-surgery) for a code example. The only difference is you need to re-process your raw diciontary a bit before calling `restore_from_pure_dict`.
+
+
+```python
+class ModifiedTwoLayerMLP(nnx.Module):
+ """A modified version of TwoLayerMLP, which requires bias arrays."""
+ def __init__(self, dim, rngs: nnx.Rngs):
+ self.linear1 = nnx.Linear(dim, dim, rngs=rngs, use_bias=True) # We need bias now!
+ self.linear2 = nnx.Linear(dim, dim, rngs=rngs, use_bias=True) # We need bias now!
+
+ def __call__(self, x):
+ x = self.linear1(x)
+ return self.linear2(x)
+
+# Accomodate your old checkpoint to the new code...
+restored_pure_dict = checkpointer.restore(ckpt_dir / 'pure_dict')
+restored_pure_dict['linear1']['bias'] = jnp.zeros((4,))
+restored_pure_dict['linear2']['bias'] = jnp.zeros((4,))
+
+# Same restore code as above
+abstract_model = nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0)))
+graphdef, abstract_state = nnx.split(abstract_model)
+abstract_state.replace_by_pure_dict(restored_pure_dict)
+model = nnx.merge(graphdef, abstract_state)
+assert model(x).shape == (3, 4) # the new model works!
+
+nnx.display(model.linear1)
+```
+
+ WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.
+
+
+
+
+
+
+
+
+
+
+## Multi-process checkpointing
+
+In multi-host/multi-process environment, you would want to restore your checkpoint as sharded across multiple devices. Check out [this section of the Flax scale-up guide](https://flax.readthedocs.io/en/latest/guides/flax_gspmd.html#load-sharded-model-from-a-checkpoint) to learn how to derive a sharding tree and use it to load your checkpoint.
+
+## Other checkpointing features
+
+This guide only use the simplest Orbax save/load API ([`orbax.checkpoint.StandardCheckpointer`](https://orbax.readthedocs.io/en/latest/api_reference/checkpoint.checkpointers.html#standardcheckpointer)) to showcase all the tricks on the Flax modeling side. Feel free to use other tools or libraries as you see fit.
+
+Check out [the Orbax website](https://orbax.readthedocs.io/en/latest/index.html) for other commonly used features, such as:
+
+* [CheckpointManager](https://orbax.readthedocs.io/en/latest/guides/checkpoint/api_refactor.html) to track checkpoints from different steps
+
+* [Asynchronous checkpointing](https://orbax.readthedocs.io/en/latest/guides/checkpoint/async_checkpointing.html)
+
+* [Transformations](https://orbax.readthedocs.io/en/latest/guides/checkpoint/transformations.html): a way to modify pytree structure during loading time (not after loading time, as in this guide)
diff --git a/docs_nnx/guides/surgery.ipynb b/docs_nnx/guides/surgery.ipynb
index 00a1839ec5..8213abea2a 100644
--- a/docs_nnx/guides/surgery.ipynb
+++ b/docs_nnx/guides/surgery.ipynb
@@ -221,15 +221,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "This will throw error: : 'layer1'\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/Users/ivyzheng/envs/py310/lib/python3.10/site-packages/orbax/checkpoint/type_handlers.py:1401: UserWarning: Couldn't find sharding info under RestoreArgs. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file instead of directly from RestoreArgs. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.\n",
- " warnings.warn(\n"
+ "This will throw error: : Dict key mismatch; expected keys: ['linear1', 'linear2']; dict: {'layer1': {'bias': {'value': RestoreArgs(restore_type=None, dtype=None)}, 'kernel': {'value': RestoreArgs(restore_type=None, dtype=None)}}, 'layer2': {'bias': {'value': RestoreArgs(restore_type=None, dtype=None)}, 'kernel': {'value': RestoreArgs(restore_type=None, dtype=None)}}}.\n"
]
}
],
@@ -267,45 +259,46 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "{'linear1': {'bias': {'raw_value': Array([0., 0., 0., 0.], dtype=float32)},\n",
- " 'kernel': {'raw_value': Array([[-0.80345297, -0.34071913, -0.9408296 , 0.01005968],\n",
+ "{'linear1': {'bias': {'value': Array([0., 0., 0., 0.], dtype=float32)},\n",
+ " 'kernel': {'value': Array([[-0.80345297, -0.34071913, -0.9408296 , 0.01005968],\n",
" [ 0.26146442, 1.1247735 , 0.54563737, -0.374164 ],\n",
" [ 1.0281805 , -0.6798804 , -0.1488401 , 0.05694951],\n",
" [-0.44308168, -0.60587114, 0.434087 , -0.40541083]], dtype=float32)}},\n",
- " 'linear2': {'bias': {'raw_value': Array([0., 0., 0., 0.], dtype=float32)},\n",
- " 'kernel': {'raw_value': Array([[ 0.21010089, 0.8289361 , 0.04589564, 0.5422644 ],\n",
+ " 'linear2': {'bias': {'value': Array([0., 0., 0., 0.], dtype=float32)},\n",
+ " 'kernel': {'value': Array([[ 0.21010089, 0.8289361 , 0.04589564, 0.5422644 ],\n",
" [ 0.41914317, 0.84359694, -0.47937787, -0.49135214],\n",
" [-0.46072108, 0.4630125 , 0.39276958, -0.9441406 ],\n",
" [-0.6690758 , -0.18474789, -0.57622856, 0.4821079 ]], dtype=float32)}}}\n"
]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/Users/ivyzheng/envs/flax-head/lib/python3.11/site-packages/orbax/checkpoint/type_handlers.py:1439: UserWarning: Couldn't find sharding info under RestoreArgs. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file instead of directly from RestoreArgs. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.\n",
+ " warnings.warn(\n"
+ ]
}
],
"source": [
- "def module_from_variables_dict(module_factory, variables, map_key_fn):\n",
- " if map_key_fn is None:\n",
- " map_key_fn = lambda path: path\n",
- " mdl = nnx.eval_shape(module_factory)\n",
- " graph_def, state = nnx.split(mdl)\n",
- " state = state.flat_state()\n",
- " for path, val in flax.traverse_util.flatten_dict(variables).items():\n",
- " mapped_path = map_key_fn(path)\n",
- " if mapped_path not in state:\n",
- " raise ValueError(f\"{mapped_path} doesn't exist in {state.keys()}\")\n",
- " state[mapped_path].value = val\n",
- " state = nnx.State.from_flat_path(state)\n",
- " return nnx.merge(graph_def, state)\n",
- "\n",
- "# Make your local change on the checkpoint.\n",
- "raw = checkpointer.restore('/tmp/nnx-surgery-state')\n",
- "pprint(raw)\n",
- "raw['layer1'], raw['layer2'] = raw['linear1'], raw['linear2']\n",
- "del raw['linear1'], raw['linear2']\n",
- "\n",
- "restored_model = module_from_variables_dict(\n",
- " lambda: nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0))),\n",
- " raw,\n",
- " lambda path: path[:-1] if path[-1] == 'raw_value' else path\n",
- ")\n",
+ "def process_raw_dict(raw_state_dict):\n",
+ " flattened = nnx.traversals.flatten_mapping(raw_state_dict)\n",
+ " # Cut off the '.value' postfix on every leaf path.\n",
+ " flattened = {(path[:-1] if path[-1] == 'value' else path): value\n",
+ " for path, value in flattened.items()}\n",
+ " return nnx.traversals.unflatten_mapping(flattened)\n",
+ "\n",
+ "# Make your local change on the checkpoint dictionary.\n",
+ "raw_dict = checkpointer.restore('/tmp/nnx-surgery-state')\n",
+ "pprint(raw_dict)\n",
+ "raw_dict['layer1'] = raw_dict.pop('linear1')\n",
+ "raw_dict['layer2'] = raw_dict.pop('linear2')\n",
+ "\n",
+ "# Fit it into the model state.\n",
+ "abs_model = nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0)))\n",
+ "graph_def, state = nnx.split(abs_model)\n",
+ "state.replace_by_pure_dict(process_raw_dict(raw_dict))\n",
+ "restored_model = nnx.merge(graph_def, state)\n",
"\n",
"np.testing.assert_allclose(restored_model(jnp.ones((3, 4))), old_model(jnp.ones((3, 4))))"
]
@@ -339,9 +332,9 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Number of jax arrays in memory at start: 34\n",
- "Number of jax arrays in memory midway: 38 (4 new created in LoRALinear - kernel, bias, lora_a & lora_b)\n",
- "Number of jax arrays in memory at end: 36 (2 discarded - only lora_a & lora_b are used in model)\n"
+ "Number of jax arrays in memory at start: 38\n",
+ "Number of jax arrays in memory midway: 42 (4 new created in LoRALinear - kernel, bias, lora_a & lora_b)\n",
+ "Number of jax arrays in memory at end: 40 (2 discarded - only lora_a & lora_b are used in model)\n"
]
}
],
@@ -379,8 +372,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Number of jax arrays in memory at start: 40\n",
- "Number of jax arrays in memory at end: 42 (2 new created - lora_a and lora_b)\n"
+ "Number of jax arrays in memory at start: 44\n",
+ "Number of jax arrays in memory at end: 46 (2 new created - lora_a and lora_b)\n"
]
}
],
@@ -389,7 +382,7 @@
"old_state = nnx.state(TwoLayerMLP(4, rngs=nnx.Rngs(0)))\n",
"\n",
"# Use `nnx.jit` (which wraps `jax.jit`) to automatically skip unused arrays - memory efficient!\n",
- "@functools.partial(nnx.jit, donate_argnums=0, static_argnums=1)\n",
+ "@nnx.jit(donate_argnums=0)\n",
"def partial_init(old_state, rngs):\n",
" model = TwoLayerMLP(4, rngs=rngs)\n",
" # Create a new state.\n",
@@ -404,6 +397,20 @@
"print(f'Number of jax arrays in memory at end: {len(jax.live_arrays())}'\n",
" ' (2 new created - lora_a and lora_b)')"
]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
}
],
"metadata": {
@@ -420,7 +427,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.10.14"
+ "version": "3.11.9"
}
},
"nbformat": 4,
diff --git a/docs_nnx/guides/surgery.md b/docs_nnx/guides/surgery.md
index e829f850ce..3c1aa786ae 100644
--- a/docs_nnx/guides/surgery.md
+++ b/docs_nnx/guides/surgery.md
@@ -152,31 +152,24 @@ except Exception as e:
But you can load the parameter tree as a raw dictionary, make the renames, and generate a new state that is guaranteed to be compatible with your new model definition.
```{code-cell} ipython3
-def module_from_variables_dict(module_factory, variables, map_key_fn):
- if map_key_fn is None:
- map_key_fn = lambda path: path
- mdl = nnx.eval_shape(module_factory)
- graph_def, state = nnx.split(mdl)
- state = state.flat_state()
- for path, val in flax.traverse_util.flatten_dict(variables).items():
- mapped_path = map_key_fn(path)
- if mapped_path not in state:
- raise ValueError(f"{mapped_path} doesn't exist in {state.keys()}")
- state[mapped_path].value = val
- state = nnx.State.from_flat_path(state)
- return nnx.merge(graph_def, state)
-
-# Make your local change on the checkpoint.
-raw = checkpointer.restore('/tmp/nnx-surgery-state')
-pprint(raw)
-raw['layer1'], raw['layer2'] = raw['linear1'], raw['linear2']
-del raw['linear1'], raw['linear2']
-
-restored_model = module_from_variables_dict(
- lambda: nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0))),
- raw,
- lambda path: path[:-1] if path[-1] == 'raw_value' else path
-)
+def process_raw_dict(raw_state_dict):
+ flattened = nnx.traversals.flatten_mapping(raw_state_dict)
+ # Cut off the '.value' postfix on every leaf path.
+ flattened = {(path[:-1] if path[-1] == 'value' else path): value
+ for path, value in flattened.items()}
+ return nnx.traversals.unflatten_mapping(flattened)
+
+# Make your local change on the checkpoint dictionary.
+raw_dict = checkpointer.restore('/tmp/nnx-surgery-state')
+pprint(raw_dict)
+raw_dict['layer1'] = raw_dict.pop('linear1')
+raw_dict['layer2'] = raw_dict.pop('linear2')
+
+# Fit it into the model state.
+abs_model = nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0)))
+graph_def, state = nnx.split(abs_model)
+state.replace_by_pure_dict(process_raw_dict(raw_dict))
+restored_model = nnx.merge(graph_def, state)
np.testing.assert_allclose(restored_model(jnp.ones((3, 4))), old_model(jnp.ones((3, 4))))
```
@@ -218,7 +211,7 @@ Use `nnx.jit`'s efficiently compiled code to make sure only the state parameters
old_state = nnx.state(TwoLayerMLP(4, rngs=nnx.Rngs(0)))
# Use `nnx.jit` (which wraps `jax.jit`) to automatically skip unused arrays - memory efficient!
-@functools.partial(nnx.jit, donate_argnums=0, static_argnums=1)
+@nnx.jit(donate_argnums=0)
def partial_init(old_state, rngs):
model = TwoLayerMLP(4, rngs=rngs)
# Create a new state.
@@ -233,3 +226,11 @@ good_model = partial_init(old_state, nnx.Rngs(42))
print(f'Number of jax arrays in memory at end: {len(jax.live_arrays())}'
' (2 new created - lora_a and lora_b)')
```
+
+```{code-cell} ipython3
+
+```
+
+```{code-cell} ipython3
+
+```
diff --git a/flax/nnx/nn/linear.py b/flax/nnx/nn/linear.py
index 1a35058b7a..dd6a18a56b 100644
--- a/flax/nnx/nn/linear.py
+++ b/flax/nnx/nn/linear.py
@@ -370,6 +370,7 @@ def __call__(self, inputs: Array) -> Array:
(((inputs.ndim - 1,), (0,)), ((), ())),
precision=self.precision,
)
+ assert self.use_bias == (bias is not None)
if bias is not None:
y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))
return y