From 2330f4a554d9191bf74ec0fe492bdf28d59ce150 Mon Sep 17 00:00:00 2001 From: 8bitmp3 <19637339+8bitmp3@users.noreply.github.com> Date: Mon, 14 Oct 2024 13:45:57 +0000 Subject: [PATCH] Ugrade Flax NNX Checkpointing guide --- .../checkpointing-checkpoint.ipynb | 382 ++++++++++++++++++ docs_nnx/guides/checkpointing.ipynb | 112 +++-- docs_nnx/guides/checkpointing.md | 93 +++-- 3 files changed, 523 insertions(+), 64 deletions(-) create mode 100644 docs_nnx/guides/.ipynb_checkpoints/checkpointing-checkpoint.ipynb diff --git a/docs_nnx/guides/.ipynb_checkpoints/checkpointing-checkpoint.ipynb b/docs_nnx/guides/.ipynb_checkpoints/checkpointing-checkpoint.ipynb new file mode 100644 index 0000000000..100dcffc34 --- /dev/null +++ b/docs_nnx/guides/.ipynb_checkpoints/checkpointing-checkpoint.ipynb @@ -0,0 +1,382 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Save and load checkpoints\n", + "\n", + "Flax does not actively maintain a library for saving and loading model checkpoints to disk. We recommend you to use external libraries like [Orbax](https://orbax.readthedocs.io/en/latest/index.html) to do it.\n", + "\n", + "This guide will cover a few user-side scenarios of saving and loading Flax NNX model checkpoints using Orbax. The Orbax API we use here are for demonstration purposes only - please check out the [Orbax website](https://orbax.readthedocs.io/en/latest/index.html) for the most up-to-date recommended API.\n", + "\n", + "> Note: Flax's legacy `flax.training.checkpoints` package is deprecated. Its documentation resides in the [legacy documentation site](https://flax-linen.readthedocs.io/en/latest/guides/training_techniques/use_checkpointing.html)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Setup\n", + "\n", + "We first set up a checkpoint directory and an example NNX model." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from flax import nnx\n", + "import orbax.checkpoint as ocp\n", + "import jax\n", + "from jax import numpy as jnp\n", + "import numpy as np\n", + "\n", + "ckpt_dir = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/')" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "class TwoLayerMLP(nnx.Module):\n", + " def __init__(self, dim, rngs: nnx.Rngs):\n", + " self.linear1 = nnx.Linear(dim, dim, rngs=rngs, use_bias=False)\n", + " self.linear2 = nnx.Linear(dim, dim, rngs=rngs, use_bias=False)\n", + "\n", + " def __call__(self, x):\n", + " x = self.linear1(x)\n", + " return self.linear2(x)\n", + "\n", + "# Create this model and show we can run it\n", + "model = TwoLayerMLP(4, rngs=nnx.Rngs(0))\n", + "x = jax.random.normal(jax.random.key(42), (3, 4))\n", + "assert model(x).shape == (3, 4)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Save checkpoints\n", + "\n", + "JAX checkpointing libraries save [Pytrees](https://jax.readthedocs.io/en/latest/pytrees.html), which is a pure, possibly nested container of JAX arrays (or, \"tensors\" as some other frameworks would put it). In the context of machine learning, the checkpoint is usually a pytree of model parameters and other data such as optimizer states.\n", + "\n", + "In Flax NNX, you can obtain such a pytree from an [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) by calling [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split) and pick up the returned [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State)." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "_, state = nnx.split(model)\n", + "nnx.display(state)\n", + "\n", + "checkpointer = ocp.StandardCheckpointer()\n", + "checkpointer.save(ckpt_dir / 'state', state)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Restore checkpoints\n", + "\n", + "Note that you saved the checkpoint as a class of `nnx.State`, which is also nested with `nnx.VariableState` and `nnx.Param` Python classes. At restoration time, you need to have these classes ready in your runtime, and tell the checkpointing library to restore your pytree back to that structure.\n", + "\n", + "You can achive this by creating an abstract NNX model (without allocating any memory for arrays) and show its abstract variable state to the checkpointing library.\n", + "\n", + "Once you have the state, use `nnx.merge` to obtain your NNX model and use it as usual." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The abstract NNX state (all leaves are abstract arrays):\n" + ] + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "NNX State restored: \n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/ivyzheng/envs/flax-head/lib/python3.11/site-packages/orbax/checkpoint/type_handlers.py:1439: UserWarning: Couldn't find sharding info under RestoreArgs. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file instead of directly from RestoreArgs. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.\n", + " warnings.warn(\n" + ] + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Restore the checkpoint back to its `nnx.State` structure - need an abstract reference\n", + "abstract_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(0)))\n", + "graphdef, abstract_state = nnx.split(abstract_model)\n", + "print('The abstract NNX state (all leaves are abstract arrays):')\n", + "nnx.display(abstract_state)\n", + "\n", + "state_restored = checkpointer.restore(ckpt_dir / 'state', abstract_state)\n", + "jax.tree.map(np.testing.assert_array_equal, state, state_restored)\n", + "print('NNX State restored: ')\n", + "nnx.display(state_restored)\n", + "\n", + "# The model is good to use as ever!\n", + "model = nnx.merge(graphdef, state_restored)\n", + "assert model(x).shape == (3, 4)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Save and restore as pure dictionaries\n", + "\n", + "You might prefer to work with Python built-in container types when interacting with checkpoint libraries. In that case, you can use the `nnx.State.to_pure_dict` and `nnx.State.replace_by_pure_dict` API to convert an `nnx.State` to and from pure nested dictionaries." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.\n" + ] + } + ], + "source": [ + "# Save as pure dict\n", + "pure_dict_state = state.to_pure_dict()\n", + "nnx.display(pure_dict_state)\n", + "checkpointer.save(ckpt_dir / 'pure_dict', pure_dict_state)\n", + "\n", + "# Restore as pure dict\n", + "restored_pure_dict = checkpointer.restore(ckpt_dir / 'pure_dict')\n", + "abstract_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(0)))\n", + "graphdef, abstract_state = nnx.split(abstract_model)\n", + "abstract_state.replace_by_pure_dict(restored_pure_dict)\n", + "model = nnx.merge(graphdef, abstract_state)\n", + "assert model(x).shape == (3, 4) # the model still works!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Restore when checkpoint structures differ\n", + "\n", + "The ability to load a checkpoint as a pure nested dictionary can come in handy when you want to load some outdated checkpoints that no longer matches with your current model code. Check out this simple example below.\n", + "\n", + "This pattern also works if you saved the checkpoint as `nnx.State` instead of pure dictionary. Check out the [checkpoint surgery guide](https://flax.readthedocs.io/en/latest/guides/surgery.html#checkpoint-surgery) for a code example. The only difference is you need to re-process your raw diciontary a bit before calling `restore_from_pure_dict`." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.\n" + ] + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "class ModifiedTwoLayerMLP(nnx.Module):\n", + " \"\"\"A modified version of TwoLayerMLP, which requires bias arrays.\"\"\"\n", + " def __init__(self, dim, rngs: nnx.Rngs):\n", + " self.linear1 = nnx.Linear(dim, dim, rngs=rngs, use_bias=True) # We need bias now!\n", + " self.linear2 = nnx.Linear(dim, dim, rngs=rngs, use_bias=True) # We need bias now!\n", + "\n", + " def __call__(self, x):\n", + " x = self.linear1(x)\n", + " return self.linear2(x)\n", + "\n", + "# Accomodate your old checkpoint to the new code...\n", + "restored_pure_dict = checkpointer.restore(ckpt_dir / 'pure_dict')\n", + "restored_pure_dict['linear1']['bias'] = jnp.zeros((4,))\n", + "restored_pure_dict['linear2']['bias'] = jnp.zeros((4,))\n", + "\n", + "# Same restore code as above\n", + "abstract_model = nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0)))\n", + "graphdef, abstract_state = nnx.split(abstract_model)\n", + "abstract_state.replace_by_pure_dict(restored_pure_dict)\n", + "model = nnx.merge(graphdef, abstract_state)\n", + "assert model(x).shape == (3, 4) # the new model works!\n", + "\n", + "nnx.display(model.linear1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Multi-process checkpointing\n", + "\n", + "In multi-host/multi-process environment, you would want to restore your checkpoint as sharded across multiple devices. Check out [this section of the Flax scale-up guide](https://flax.readthedocs.io/en/latest/guides/flax_gspmd.html#load-sharded-model-from-a-checkpoint) to learn how to derive a sharding tree and use it to load your checkpoint." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Other checkpointing features\n", + "\n", + "This guide only use the simplest Orbax save/load API ([`orbax.checkpoint.StandardCheckpointer`](https://orbax.readthedocs.io/en/latest/api_reference/checkpoint.checkpointers.html#standardcheckpointer)) to showcase all the tricks on the Flax modeling side. Feel free to use other tools or libraries as you see fit.\n", + "\n", + "Check out [the Orbax website](https://orbax.readthedocs.io/en/latest/index.html) for other commonly used features, such as:\n", + "\n", + "* [CheckpointManager](https://orbax.readthedocs.io/en/latest/guides/checkpoint/api_refactor.html) to track checkpoints from different steps\n", + "\n", + "* [Asynchronous checkpointing](https://orbax.readthedocs.io/en/latest/guides/checkpoint/async_checkpointing.html)\n", + "\n", + "* [Transformations](https://orbax.readthedocs.io/en/latest/guides/checkpoint/transformations.html): a way to modify pytree structure during loading time (not after loading time, as in this guide)" + ] + } + ], + "metadata": { + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs_nnx/guides/checkpointing.ipynb b/docs_nnx/guides/checkpointing.ipynb index 100dcffc34..b86be694f5 100644 --- a/docs_nnx/guides/checkpointing.ipynb +++ b/docs_nnx/guides/checkpointing.ipynb @@ -6,11 +6,18 @@ "source": [ "# Save and load checkpoints\n", "\n", - "Flax does not actively maintain a library for saving and loading model checkpoints to disk. We recommend you to use external libraries like [Orbax](https://orbax.readthedocs.io/en/latest/index.html) to do it.\n", + "This guide demonstrates how to save and load Flax NNX model checkpoints with [Orbax](https://orbax.readthedocs.io/). You will learn how to:\n", "\n", - "This guide will cover a few user-side scenarios of saving and loading Flax NNX model checkpoints using Orbax. The Orbax API we use here are for demonstration purposes only - please check out the [Orbax website](https://orbax.readthedocs.io/en/latest/index.html) for the most up-to-date recommended API.\n", + "* Save checkpoints.\n", + "* Restore checkpoints.\n", + "* Restore checkpoints if checkpoint structures differ. \n", + "* Perform multi-process checkpointing. \n", "\n", - "> Note: Flax's legacy `flax.training.checkpoints` package is deprecated. Its documentation resides in the [legacy documentation site](https://flax-linen.readthedocs.io/en/latest/guides/training_techniques/use_checkpointing.html)." + "The Orbax API examples used throughout the guide are for demonstration purposes, and for the most up-to-date recommended APIs refer to the [Orbax website](https://orbax.readthedocs.io/).\n", + "\n", + "> **Note:** The Flax team recommends using [Orbax](https://orbax.readthedocs.io/en/latest/index.html) for saving and loading checkpoints to disk, as we do not actively maintain a library for these functionalities.\n", + "\n", + "> **Note:** If you are looking for Flax Linen's legacy `flax.training.checkpoints` package, it was deprecated in 2023 in favor of Orbax. The documentation resides on the [Flax Linen site](https://flax-linen.readthedocs.io/en/latest/guides/training_techniques/use_checkpointing.html)." ] }, { @@ -19,7 +26,7 @@ "source": [ "### Setup\n", "\n", - "We first set up a checkpoint directory and an example NNX model." + "Import the necessary dependencies, set up a checkpoint directory and an example Flax NNX model - `TwoLayerMLP` - by subclassing `flax.nnx.Module`." ] }, { @@ -52,7 +59,7 @@ " x = self.linear1(x)\n", " return self.linear2(x)\n", "\n", - "# Create this model and show we can run it\n", + "# Instantiate the model and show we can run it.\n", "model = TwoLayerMLP(4, rngs=nnx.Rngs(0))\n", "x = jax.random.normal(jax.random.key(42), (3, 4))\n", "assert model(x).shape == (3, 4)" @@ -64,9 +71,9 @@ "source": [ "## Save checkpoints\n", "\n", - "JAX checkpointing libraries save [Pytrees](https://jax.readthedocs.io/en/latest/pytrees.html), which is a pure, possibly nested container of JAX arrays (or, \"tensors\" as some other frameworks would put it). In the context of machine learning, the checkpoint is usually a pytree of model parameters and other data such as optimizer states.\n", + "JAX checkpointing libraries, such as Orbax, can save and load any given JAX [pytree](https://jax.readthedocs.io/en/latest/pytrees.html), which is a pure, possibly nested container of [`jax.Array`s)](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) (or, \"tensors\" as some other frameworks would put it). In the context of machine learning, the checkpoint is usually a pytree of model parameters and other data, such as optimizer states.\n", "\n", - "In Flax NNX, you can obtain such a pytree from an [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) by calling [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split) and pick up the returned [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State)." + "In Flax NNX, you can obtain such a pytree from an [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) by calling [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split), and picking up the returned [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State)." ] }, { @@ -111,13 +118,20 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Restore checkpoints\n", + "
\n", + "\n", + "\n", "\n", - "Note that you saved the checkpoint as a class of `nnx.State`, which is also nested with `nnx.VariableState` and `nnx.Param` Python classes. At restoration time, you need to have these classes ready in your runtime, and tell the checkpointing library to restore your pytree back to that structure.\n", + "
\n", "\n", - "You can achive this by creating an abstract NNX model (without allocating any memory for arrays) and show its abstract variable state to the checkpointing library.\n", "\n", - "Once you have the state, use `nnx.merge` to obtain your NNX model and use it as usual." + "## Restore checkpoints\n", + "\n", + "Note that you saved the checkpoint as a Flax class of [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State), which is also nested with [`nnx.VariableState`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.VariableState) and [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) classes.\n", + "\n", + "At checkpoint restoration time, you need to have these classes ready in your runtime, and instruct the checkpointing library (Orbax) to restore your pytree back to that structure. This can be achieved as follows:\n", + "- First, create an abstract Flax NNX model (without allocating any memory for arrays), and show its abstract variable state to the checkpointing library.\n", + "- Once you have the state, use [`nnx.merge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.merge) to obtain your Flax NNX model, and use it as usual." ] }, { @@ -185,7 +199,7 @@ } ], "source": [ - "# Restore the checkpoint back to its `nnx.State` structure - need an abstract reference\n", + "# Restore the checkpoint back to its `nnx.State` structure - need an abstract reference.\n", "abstract_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(0)))\n", "graphdef, abstract_state = nnx.split(abstract_model)\n", "print('The abstract NNX state (all leaves are abstract arrays):')\n", @@ -196,7 +210,7 @@ "print('NNX State restored: ')\n", "nnx.display(state_restored)\n", "\n", - "# The model is good to use as ever!\n", + "# The model is now good to use!\n", "model = nnx.merge(graphdef, state_restored)\n", "assert model(x).shape == (3, 4)" ] @@ -205,9 +219,31 @@ "cell_type": "markdown", "metadata": {}, "source": [ + " The abstract NNX state (all leaves are abstract arrays):\n", + "\n", + "\n", + "\n", + "
\n", + "\n", + "\n", + " NNX State restored: \n", + "\n", + "\n", + " /Users/ivyzheng/envs/flax-head/lib/python3.11/site-packages/orbax/checkpoint/type_handlers.py:1439: UserWarning: Couldn't find sharding info under RestoreArgs. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file instead of directly from RestoreArgs. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.\n", + " warnings.warn(\n", + "\n", + "\n", + "\n", + "
\n", + "\n", + "\n", + "\n", + "
\n", + "\n", + "\n", "## Save and restore as pure dictionaries\n", "\n", - "You might prefer to work with Python built-in container types when interacting with checkpoint libraries. In that case, you can use the `nnx.State.to_pure_dict` and `nnx.State.replace_by_pure_dict` API to convert an `nnx.State` to and from pure nested dictionaries." + "When interacting with checkpoint libraries (like Orbax), you may prefer to work with Python built-in container types. In this case, you can use the `nnx.State.to_pure_dict` and `nnx.State.replace_by_pure_dict` API to convert an `nnx.State` to and from pure nested dictionaries." ] }, { @@ -253,7 +289,7 @@ "nnx.display(pure_dict_state)\n", "checkpointer.save(ckpt_dir / 'pure_dict', pure_dict_state)\n", "\n", - "# Restore as pure dict\n", + "# Restore as a pure dictionary.\n", "restored_pure_dict = checkpointer.restore(ckpt_dir / 'pure_dict')\n", "abstract_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(0)))\n", "graphdef, abstract_state = nnx.split(abstract_model)\n", @@ -266,11 +302,21 @@ "cell_type": "markdown", "metadata": {}, "source": [ + "
\n", + "\n", + "\n", + "\n", + "
\n", + "\n", + "\n", + " WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.\n", + "\n", + "\n", "## Restore when checkpoint structures differ\n", "\n", - "The ability to load a checkpoint as a pure nested dictionary can come in handy when you want to load some outdated checkpoints that no longer matches with your current model code. Check out this simple example below.\n", + "The ability to load a checkpoint as a pure nested dictionary can come in handy when you want to load some outdated checkpoints that no longer match with your current model code. Check out this simple example below.\n", "\n", - "This pattern also works if you saved the checkpoint as `nnx.State` instead of pure dictionary. Check out the [checkpoint surgery guide](https://flax.readthedocs.io/en/latest/guides/surgery.html#checkpoint-surgery) for a code example. The only difference is you need to re-process your raw diciontary a bit before calling `restore_from_pure_dict`." + "This pattern also works if you saved the checkpoint as an `nnx.State` instead of a pure dictionary. Check out the [Checkpoint surgery section](https://flax.readthedocs.io/en/latest/guides/surgery.html#checkpoint-surgery) of the [Model Surgery](https://flax.readthedocs.io/en/latest/guides/surgery.html) guide for an example with code. The only difference is you need to reprocess your raw dictionary a bit before calling `restore_from_pure_dict`." ] }, { @@ -321,7 +367,7 @@ " x = self.linear1(x)\n", " return self.linear2(x)\n", "\n", - "# Accomodate your old checkpoint to the new code...\n", + "# Accommodate your old checkpoint to the new code.\n", "restored_pure_dict = checkpointer.restore(ckpt_dir / 'pure_dict')\n", "restored_pure_dict['linear1']['bias'] = jnp.zeros((4,))\n", "restored_pure_dict['linear2']['bias'] = jnp.zeros((4,))\n", @@ -331,7 +377,7 @@ "graphdef, abstract_state = nnx.split(abstract_model)\n", "abstract_state.replace_by_pure_dict(restored_pure_dict)\n", "model = nnx.merge(graphdef, abstract_state)\n", - "assert model(x).shape == (3, 4) # the new model works!\n", + "assert model(x).shape == (3, 4) # The new model works!\n", "\n", "nnx.display(model.linear1)" ] @@ -340,9 +386,22 @@ "cell_type": "markdown", "metadata": {}, "source": [ + " WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.\n", + "\n", + "\n", + "\n", + "
\n", + "\n", + "\n", + "\n", + "
\n", + "\n", + "\n", "## Multi-process checkpointing\n", "\n", - "In multi-host/multi-process environment, you would want to restore your checkpoint as sharded across multiple devices. Check out [this section of the Flax scale-up guide](https://flax.readthedocs.io/en/latest/guides/flax_gspmd.html#load-sharded-model-from-a-checkpoint) to learn how to derive a sharding tree and use it to load your checkpoint." + "In a multi-host/multi-process environment, you would want to restore your checkpoint as sharded across multiple devices. Check out [this section](https://flax.readthedocs.io/en/latest/guides/flax_gspmd.html#load-sharded-model-from-a-checkpoint) of the Flax [Scale-up](https://flax.readthedocs.io/en/latest/guides/flax_gspmd.html) guide to learn how to derive a sharding pytree and use it to load your checkpoint.\n", + "\n", + "> **Note:** JAX provides several ways to scale up your code on multiple hosts at the same time. This usually happens when the number of devices (CPU/GPU/TPU) is so large that different devices are managed by different hosts (CPU). Check out JAX’s [Introduction to parallel programming](https://jax.readthedocs.io/en/latest/sharded-computation.html), [Using JAX in multi-host and multi-process environments](https://jax.readthedocs.io/en/latest/multi_process.html), [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), and [Manual parallelism with `shard_map`](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html)." ] }, { @@ -351,19 +410,22 @@ "source": [ "## Other checkpointing features\n", "\n", - "This guide only use the simplest Orbax save/load API ([`orbax.checkpoint.StandardCheckpointer`](https://orbax.readthedocs.io/en/latest/api_reference/checkpoint.checkpointers.html#standardcheckpointer)) to showcase all the tricks on the Flax modeling side. Feel free to use other tools or libraries as you see fit.\n", + "This guide only uses the simplest ([`orbax.checkpoint.StandardCheckpointer`](https://orbax.readthedocs.io/en/latest/api_reference/checkpoint.checkpointers.html#standardcheckpointer)) API to show how to save and load on the Flax modeling side. Feel free to use other tools or libraries as you see fit.\n", "\n", - "Check out [the Orbax website](https://orbax.readthedocs.io/en/latest/index.html) for other commonly used features, such as:\n", + "In addition, check out [the Orbax website](https://orbax.readthedocs.io/en/latest/index.html) for other commonly used features, such as:\n", "\n", - "* [CheckpointManager](https://orbax.readthedocs.io/en/latest/guides/checkpoint/api_refactor.html) to track checkpoints from different steps\n", + "* [`CheckpointManager`](https://orbax.readthedocs.io/en/latest/guides/checkpoint/api_refactor.html) to track checkpoints from different steps.\n", "\n", - "* [Asynchronous checkpointing](https://orbax.readthedocs.io/en/latest/guides/checkpoint/async_checkpointing.html)\n", + "* [Asynchronous checkpointing](https://orbax.readthedocs.io/en/latest/guides/checkpoint/async_checkpointing.html).\n", "\n", - "* [Transformations](https://orbax.readthedocs.io/en/latest/guides/checkpoint/transformations.html): a way to modify pytree structure during loading time (not after loading time, as in this guide)" + "* [Orbax transformations](https://orbax.readthedocs.io/en/latest/guides/checkpoint/transformations.html): A way to modify pytree structure during loading time (instead of after loading time, which is demonstrated in this guide)." ] } ], "metadata": { + "jupytext": { + "formats": "ipynb,md:myst" + }, "language_info": { "codemirror_mode": { "name": "ipython", diff --git a/docs_nnx/guides/checkpointing.md b/docs_nnx/guides/checkpointing.md index 1ae5978b22..22d1b62676 100644 --- a/docs_nnx/guides/checkpointing.md +++ b/docs_nnx/guides/checkpointing.md @@ -1,17 +1,35 @@ +--- +jupytext: + formats: ipynb,md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.13.8 +--- + # Save and load checkpoints -Flax does not actively maintain a library for saving and loading model checkpoints to disk. We recommend you to use external libraries like [Orbax](https://orbax.readthedocs.io/en/latest/index.html) to do it. +This guide demonstrates how to save and load Flax NNX model checkpoints with [Orbax](https://orbax.readthedocs.io/). You will learn how to: -This guide will cover a few user-side scenarios of saving and loading Flax NNX model checkpoints using Orbax. The Orbax API we use here are for demonstration purposes only - please check out the [Orbax website](https://orbax.readthedocs.io/en/latest/index.html) for the most up-to-date recommended API. +* Save checkpoints. +* Restore checkpoints. +* Restore checkpoints if checkpoint structures differ. +* Perform multi-process checkpointing. -> Note: Flax's legacy `flax.training.checkpoints` package is deprecated. Its documentation resides in the [legacy documentation site](https://flax-linen.readthedocs.io/en/latest/guides/training_techniques/use_checkpointing.html). +The Orbax API examples used throughout the guide are for demonstration purposes, and for the most up-to-date recommended APIs refer to the [Orbax website](https://orbax.readthedocs.io/). -### Setup +> **Note:** The Flax team recommends using [Orbax](https://orbax.readthedocs.io/en/latest/index.html) for saving and loading checkpoints to disk, as we do not actively maintain a library for these functionalities. + +> **Note:** If you are looking for Flax Linen's legacy `flax.training.checkpoints` package, it was deprecated in 2023 in favor of Orbax. The documentation resides on the [Flax Linen site](https://flax-linen.readthedocs.io/en/latest/guides/training_techniques/use_checkpointing.html). + ++++ -We first set up a checkpoint directory and an example NNX model. +### Setup +Import the necessary dependencies, set up a checkpoint directory and an example Flax NNX model - `TwoLayerMLP` - by subclassing `flax.nnx.Module`. -```python +```{code-cell} ipython3 from flax import nnx import orbax.checkpoint as ocp import jax @@ -21,8 +39,7 @@ import numpy as np ckpt_dir = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/') ``` - -```python +```{code-cell} ipython3 class TwoLayerMLP(nnx.Module): def __init__(self, dim, rngs: nnx.Rngs): self.linear1 = nnx.Linear(dim, dim, rngs=rngs, use_bias=False) @@ -32,7 +49,7 @@ class TwoLayerMLP(nnx.Module): x = self.linear1(x) return self.linear2(x) -# Create this model and show we can run it +# Instantiate the model and show we can run it. model = TwoLayerMLP(4, rngs=nnx.Rngs(0)) x = jax.random.normal(jax.random.key(42), (3, 4)) assert model(x).shape == (3, 4) @@ -40,12 +57,11 @@ assert model(x).shape == (3, 4) ## Save checkpoints -JAX checkpointing libraries save [Pytrees](https://jax.readthedocs.io/en/latest/pytrees.html), which is a pure, possibly nested container of JAX arrays (or, "tensors" as some other frameworks would put it). In the context of machine learning, the checkpoint is usually a pytree of model parameters and other data such as optimizer states. +JAX checkpointing libraries, such as Orbax, can save and load any given JAX [pytree](https://jax.readthedocs.io/en/latest/pytrees.html), which is a pure, possibly nested container of [`jax.Array`s)](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) (or, "tensors" as some other frameworks would put it). In the context of machine learning, the checkpoint is usually a pytree of model parameters and other data, such as optimizer states. -In Flax NNX, you can obtain such a pytree from an [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) by calling [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split) and pick up the returned [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State). +In Flax NNX, you can obtain such a pytree from an [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) by calling [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split), and picking up the returned [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State). - -```python +```{code-cell} ipython3 _, state = nnx.split(model) nnx.display(state) @@ -53,7 +69,6 @@ checkpointer = ocp.StandardCheckpointer() checkpointer.save(ckpt_dir / 'state', state) ``` -
@@ -63,15 +78,14 @@ checkpointer.save(ckpt_dir / 'state', state) ## Restore checkpoints -Note that you saved the checkpoint as a class of `nnx.State`, which is also nested with `nnx.VariableState` and `nnx.Param` Python classes. At restoration time, you need to have these classes ready in your runtime, and tell the checkpointing library to restore your pytree back to that structure. +Note that you saved the checkpoint as a Flax class of [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State), which is also nested with [`nnx.VariableState`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.VariableState) and [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) classes. -You can achive this by creating an abstract NNX model (without allocating any memory for arrays) and show its abstract variable state to the checkpointing library. +At checkpoint restoration time, you need to have these classes ready in your runtime, and instruct the checkpointing library (Orbax) to restore your pytree back to that structure. This can be achieved as follows: +- First, create an abstract Flax NNX model (without allocating any memory for arrays), and show its abstract variable state to the checkpointing library. +- Once you have the state, use [`nnx.merge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.merge) to obtain your Flax NNX model, and use it as usual. -Once you have the state, use `nnx.merge` to obtain your NNX model and use it as usual. - - -```python -# Restore the checkpoint back to its `nnx.State` structure - need an abstract reference +```{code-cell} ipython3 +# Restore the checkpoint back to its `nnx.State` structure - need an abstract reference. abstract_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(0))) graphdef, abstract_state = nnx.split(abstract_model) print('The abstract NNX state (all leaves are abstract arrays):') @@ -82,7 +96,7 @@ jax.tree.map(np.testing.assert_array_equal, state, state_restored) print('NNX State restored: ') nnx.display(state_restored) -# The model is good to use as ever! +# The model is now good to use! model = nnx.merge(graphdef, state_restored) assert model(x).shape == (3, 4) ``` @@ -111,16 +125,15 @@ assert model(x).shape == (3, 4) ## Save and restore as pure dictionaries -You might prefer to work with Python built-in container types when interacting with checkpoint libraries. In that case, you can use the `nnx.State.to_pure_dict` and `nnx.State.replace_by_pure_dict` API to convert an `nnx.State` to and from pure nested dictionaries. +When interacting with checkpoint libraries (like Orbax), you may prefer to work with Python built-in container types. In this case, you can use the `nnx.State.to_pure_dict` and `nnx.State.replace_by_pure_dict` API to convert an `nnx.State` to and from pure nested dictionaries. - -```python +```{code-cell} ipython3 # Save as pure dict pure_dict_state = state.to_pure_dict() nnx.display(pure_dict_state) checkpointer.save(ckpt_dir / 'pure_dict', pure_dict_state) -# Restore as pure dict +# Restore as a pure dictionary. restored_pure_dict = checkpointer.restore(ckpt_dir / 'pure_dict') abstract_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(0))) graphdef, abstract_state = nnx.split(abstract_model) @@ -129,7 +142,6 @@ model = nnx.merge(graphdef, abstract_state) assert model(x).shape == (3, 4) # the model still works! ``` -
@@ -142,12 +154,11 @@ assert model(x).shape == (3, 4) # the model still works! ## Restore when checkpoint structures differ -The ability to load a checkpoint as a pure nested dictionary can come in handy when you want to load some outdated checkpoints that no longer matches with your current model code. Check out this simple example below. +The ability to load a checkpoint as a pure nested dictionary can come in handy when you want to load some outdated checkpoints that no longer match with your current model code. Check out this simple example below. -This pattern also works if you saved the checkpoint as `nnx.State` instead of pure dictionary. Check out the [checkpoint surgery guide](https://flax.readthedocs.io/en/latest/guides/surgery.html#checkpoint-surgery) for a code example. The only difference is you need to re-process your raw diciontary a bit before calling `restore_from_pure_dict`. +This pattern also works if you saved the checkpoint as an `nnx.State` instead of a pure dictionary. Check out the [Checkpoint surgery section](https://flax.readthedocs.io/en/latest/guides/surgery.html#checkpoint-surgery) of the [Model Surgery](https://flax.readthedocs.io/en/latest/guides/surgery.html) guide for an example with code. The only difference is you need to reprocess your raw dictionary a bit before calling `restore_from_pure_dict`. - -```python +```{code-cell} ipython3 class ModifiedTwoLayerMLP(nnx.Module): """A modified version of TwoLayerMLP, which requires bias arrays.""" def __init__(self, dim, rngs: nnx.Rngs): @@ -158,7 +169,7 @@ class ModifiedTwoLayerMLP(nnx.Module): x = self.linear1(x) return self.linear2(x) -# Accomodate your old checkpoint to the new code... +# Accommodate your old checkpoint to the new code. restored_pure_dict = checkpointer.restore(ckpt_dir / 'pure_dict') restored_pure_dict['linear1']['bias'] = jnp.zeros((4,)) restored_pure_dict['linear2']['bias'] = jnp.zeros((4,)) @@ -168,7 +179,7 @@ abstract_model = nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0)) graphdef, abstract_state = nnx.split(abstract_model) abstract_state.replace_by_pure_dict(restored_pure_dict) model = nnx.merge(graphdef, abstract_state) -assert model(x).shape == (3, 4) # the new model works! +assert model(x).shape == (3, 4) # The new model works! nnx.display(model.linear1) ``` @@ -186,16 +197,20 @@ nnx.display(model.linear1) ## Multi-process checkpointing -In multi-host/multi-process environment, you would want to restore your checkpoint as sharded across multiple devices. Check out [this section of the Flax scale-up guide](https://flax.readthedocs.io/en/latest/guides/flax_gspmd.html#load-sharded-model-from-a-checkpoint) to learn how to derive a sharding tree and use it to load your checkpoint. +In a multi-host/multi-process environment, you would want to restore your checkpoint as sharded across multiple devices. Check out [this section](https://flax.readthedocs.io/en/latest/guides/flax_gspmd.html#load-sharded-model-from-a-checkpoint) of the Flax [Scale-up](https://flax.readthedocs.io/en/latest/guides/flax_gspmd.html) guide to learn how to derive a sharding pytree and use it to load your checkpoint. + +> **Note:** JAX provides several ways to scale up your code on multiple hosts at the same time. This usually happens when the number of devices (CPU/GPU/TPU) is so large that different devices are managed by different hosts (CPU). Check out JAX’s [Introduction to parallel programming](https://jax.readthedocs.io/en/latest/sharded-computation.html), [Using JAX in multi-host and multi-process environments](https://jax.readthedocs.io/en/latest/multi_process.html), [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), and [Manual parallelism with `shard_map`](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html). + ++++ ## Other checkpointing features -This guide only use the simplest Orbax save/load API ([`orbax.checkpoint.StandardCheckpointer`](https://orbax.readthedocs.io/en/latest/api_reference/checkpoint.checkpointers.html#standardcheckpointer)) to showcase all the tricks on the Flax modeling side. Feel free to use other tools or libraries as you see fit. +This guide only uses the simplest ([`orbax.checkpoint.StandardCheckpointer`](https://orbax.readthedocs.io/en/latest/api_reference/checkpoint.checkpointers.html#standardcheckpointer)) API to show how to save and load on the Flax modeling side. Feel free to use other tools or libraries as you see fit. -Check out [the Orbax website](https://orbax.readthedocs.io/en/latest/index.html) for other commonly used features, such as: +In addition, check out [the Orbax website](https://orbax.readthedocs.io/en/latest/index.html) for other commonly used features, such as: -* [CheckpointManager](https://orbax.readthedocs.io/en/latest/guides/checkpoint/api_refactor.html) to track checkpoints from different steps +* [`CheckpointManager`](https://orbax.readthedocs.io/en/latest/guides/checkpoint/api_refactor.html) to track checkpoints from different steps. -* [Asynchronous checkpointing](https://orbax.readthedocs.io/en/latest/guides/checkpoint/async_checkpointing.html) +* [Asynchronous checkpointing](https://orbax.readthedocs.io/en/latest/guides/checkpoint/async_checkpointing.html). -* [Transformations](https://orbax.readthedocs.io/en/latest/guides/checkpoint/transformations.html): a way to modify pytree structure during loading time (not after loading time, as in this guide) +* [Orbax transformations](https://orbax.readthedocs.io/en/latest/guides/checkpoint/transformations.html): A way to modify pytree structure during loading time (instead of after loading time, which is demonstrated in this guide).