From 0be4e14303d8a3b8f3671f9e6a865476ffac2c21 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Tue, 3 Sep 2024 16:25:11 +0100 Subject: [PATCH] [nnx] improve landing page and nnx_basics messaging --- docs/api_reference/flax.nnx/experimental.rst | 8 -- docs/index.rst | 3 +- docs/nnx/filters_guide.ipynb | 14 +-- docs/nnx/filters_guide.md | 14 +-- docs/nnx/index.rst | 37 ++++--- docs/nnx/mnist_tutorial.ipynb | 23 +++-- docs/nnx/mnist_tutorial.md | 23 +++-- docs/nnx/nnx_basics.ipynb | 31 +++--- docs/nnx/nnx_basics.md | 31 +++--- docs/nnx/surgery.ipynb | 19 ++-- docs/nnx/surgery.md | 103 +++++-------------- docs/nnx/transforms.rst | 31 +++--- 12 files changed, 148 insertions(+), 189 deletions(-) delete mode 100644 docs/api_reference/flax.nnx/experimental.rst diff --git a/docs/api_reference/flax.nnx/experimental.rst b/docs/api_reference/flax.nnx/experimental.rst deleted file mode 100644 index 73140d9bd2..0000000000 --- a/docs/api_reference/flax.nnx/experimental.rst +++ /dev/null @@ -1,8 +0,0 @@ -experimental ------------------------- - -.. automodule:: flax.nnx.experimental -.. currentmodule:: flax.nnx.experimental - -.. autoclass:: StateAxes -.. autofunction:: vmap \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index 75f5d985fe..ce04817a65 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -325,6 +325,5 @@ Notable examples in Flax include: developer_notes/index philosophy contributing - experimental api_reference/index - NNX + Flax NNX diff --git a/docs/nnx/filters_guide.ipynb b/docs/nnx/filters_guide.ipynb index aa307e1b23..21591226ac 100644 --- a/docs/nnx/filters_guide.ipynb +++ b/docs/nnx/filters_guide.ipynb @@ -7,8 +7,10 @@ "source": [ "# Using Filters\n", "\n", - "Filters are used extensively in NNX as a way to create `State` groups in APIs\n", - "such as `nnx.split`, `nnx.state`, and many of the NNX transforms. For example:" + "> **Attention**: This page relates to the new Flax NNX API.\n", + "\n", + "Filters are used extensively in Flax NNX as a way to create `State` groups in APIs\n", + "such as `nnx.split`, `nnx.state`, and many of the Flax NNX transforms. For example:" ] }, { @@ -116,8 +118,8 @@ "metadata": {}, "source": [ "Such function matches any value that is an instance of `Param` or any value that has a \n", - "`type` attribute that is a subclass of `Param`. Internally NNX uses `OfType` which defines \n", - "a callable of this form for a given type:" + "`type` attribute that is a subclass of `Param`. Internally Flax NNX uses `OfType` which\n", + "defines a callable of this form for a given type:" ] }, { @@ -149,11 +151,11 @@ "source": [ "## The Filter DSL\n", "\n", - "To avoid users having to create these functions, NNX exposes a small DSL, formalized \n", + "To avoid users having to create these functions, Flax NNX exposes a small DSL, formalized \n", "as the `nnx.filterlib.Filter` type, which lets users pass types, booleans, ellipsis, \n", "tuples/lists, etc, and converts them to the appropriate predicate internally.\n", "\n", - "Here is a list of all the callable Filters included in NNX and their DSL literals \n", + "Here is a list of all the callable Filters included in Flax NNX and their DSL literals\n", "(when available):\n", "\n", "\n", diff --git a/docs/nnx/filters_guide.md b/docs/nnx/filters_guide.md index c637457907..84bbe3fa7f 100644 --- a/docs/nnx/filters_guide.md +++ b/docs/nnx/filters_guide.md @@ -10,8 +10,10 @@ jupytext: # Using Filters -Filters are used extensively in NNX as a way to create `State` groups in APIs -such as `nnx.split`, `nnx.state`, and many of the NNX transforms. For example: +> **Attention**: This page relates to the new Flax NNX API. + +Filters are used extensively in Flax NNX as a way to create `State` groups in APIs +such as `nnx.split`, `nnx.state`, and many of the Flax NNX transforms. For example: ```{code-cell} ipython3 from flax import nnx @@ -63,8 +65,8 @@ print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }') ``` Such function matches any value that is an instance of `Param` or any value that has a -`type` attribute that is a subclass of `Param`. Internally NNX uses `OfType` which defines -a callable of this form for a given type: +`type` attribute that is a subclass of `Param`. Internally Flax NNX uses `OfType` which +defines a callable of this form for a given type: ```{code-cell} ipython3 is_param = nnx.OfType(nnx.Param) @@ -75,11 +77,11 @@ print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }') ## The Filter DSL -To avoid users having to create these functions, NNX exposes a small DSL, formalized +To avoid users having to create these functions, Flax NNX exposes a small DSL, formalized as the `nnx.filterlib.Filter` type, which lets users pass types, booleans, ellipsis, tuples/lists, etc, and converts them to the appropriate predicate internally. -Here is a list of all the callable Filters included in NNX and their DSL literals +Here is a list of all the callable Filters included in Flax NNX and their DSL literals (when available): diff --git a/docs/nnx/index.rst b/docs/nnx/index.rst index 1b6067b32f..bbca410aaf 100644 --- a/docs/nnx/index.rst +++ b/docs/nnx/index.rst @@ -1,5 +1,5 @@ -NNX +Flax NNX ======== .. div:: sd-text-left sd-font-italic @@ -8,11 +8,15 @@ NNX ---- -NNX is a new Flax API that is designed to make it easier to create, inspect, debug, -and analyze neural networks in JAX. It achieves this by adding first class support +Flax NNX is a new simplified API that is designed to make it easier to create, inspect, +debug, and analyze neural networks in JAX. It achieves this by adding first class support for Python reference semantics, allowing users to express their models using regular -Python objects. NNX takes years of feedback from Linen and brings to Flax a simpler -and more user-friendly experience. +Python objects. Flax NNX is an evolution of the previous Flax Linen APIs, it takes years of +experience to bring a simpler and more user-friendly experience. + +.. note:: + Flax Linen is not going to be deprecated in the near future as most of our users still + rely on this API, however new users are encouraged to use Flax NNX. Features ^^^^^^^^^ @@ -29,7 +33,7 @@ Features .. div:: sd-font-normal - NNX supports the use of regular Python objects, providing an intuitive + Flax NNX supports the use of regular Python objects, providing an intuitive and predictable development experience. .. grid-item:: @@ -42,33 +46,34 @@ Features .. div:: sd-font-normal - NNX relies on Python's object model, which results in simplicity for + Flax NNX relies on Python's object model, which results in simplicity for the user and increases development speed. .. grid-item:: :columns: 12 12 12 6 - .. card:: Streamlined + .. card:: Expressive :class-card: sd-border-0 :shadow: none :class-title: sd-fs-5 .. div:: sd-font-normal - NNX integrates user feedback and hands-on experience with Linen - into a new simplified API. + Flax NNX allows fine-grained control of the model's state via + its `Filter `__ + system. .. grid-item:: :columns: 12 12 12 6 - .. card:: Compatible + .. card:: Familiar :class-card: sd-border-0 :shadow: none :class-title: sd-fs-5 .. div:: sd-font-normal - NNX makes it very easy to integrate objects with regular JAX code + Flax NNX makes it very easy to integrate objects with regular JAX code via the `Functional API `__. Basic usage @@ -114,7 +119,7 @@ Basic usage Installation ^^^^^^^^^^^^ -Install NNX via pip: +Install via pip: .. code-block:: bash @@ -137,7 +142,7 @@ Learn more .. grid-item:: :columns: 6 6 6 4 - .. card:: :material-regular:`rocket_launch;2em` NNX Basics + .. card:: :material-regular:`rocket_launch;2em` Flax NNX Basics :class-card: sd-text-black sd-bg-light :link: nnx_basics.html @@ -151,14 +156,14 @@ Learn more .. grid-item:: :columns: 6 6 6 4 - .. card:: :material-regular:`sync_alt;2em` NNX vs JAX Transformations + .. card:: :material-regular:`sync_alt;2em` Flax vs JAX Transformations :class-card: sd-text-black sd-bg-light :link: transforms.html .. grid-item:: :columns: 6 6 6 4 - .. card:: :material-regular:`transform;2em` Haiku and Linen vs NNX + .. card:: :material-regular:`transform;2em` Haiku and Flax Linen vs Flax NNX :class-card: sd-text-black sd-bg-light :link: haiku_linen_vs_nnx.html diff --git a/docs/nnx/mnist_tutorial.ipynb b/docs/nnx/mnist_tutorial.ipynb index 1e91a1877e..83c8f91203 100644 --- a/docs/nnx/mnist_tutorial.ipynb +++ b/docs/nnx/mnist_tutorial.ipynb @@ -10,8 +10,8 @@ "\n", "# MNIST Tutorial\n", "\n", - "Welcome to NNX! This tutorial will guide you through building and training a simple convolutional \n", - "neural network (CNN) on the MNIST dataset using the NNX API. NNX is a Python neural network library\n", + "Welcome to Flax NNX! This tutorial will guide you through building and training a simple convolutional \n", + "neural network (CNN) on the MNIST dataset using the Flax NNX API. Flax NNX is a Python neural network library\n", "built upon [JAX](https://github.com/google/jax) and currently offered as an experimental module within \n", "[Flax](https://github.com/google/flax)." ] @@ -21,9 +21,10 @@ "id": "1", "metadata": {}, "source": [ - "## 1. Install NNX\n", + "## 1. Install Flax\n", "\n", - "Since NNX is under active development, we recommend using the latest version from the Flax GitHub repository:" + "If `flax` is not installed in your environment, you can install it from PyPI, uncomment and run the \n", + "following cell:" ] }, { @@ -37,7 +38,7 @@ }, "outputs": [], "source": [ - "# !pip install git+https://github.com/google/flax.git" + "# !pip install flax" ] }, { @@ -109,9 +110,9 @@ "id": "5", "metadata": {}, "source": [ - "## 3. Define the Network with NNX\n", + "## 3. Define the Network with Flax NNX\n", "\n", - "Create a convolutional neural network with NNX by subclassing `nnx.Module`." + "Create a convolutional neural network with Flax NNX by subclassing `nnx.Module`." ] }, { @@ -134,7 +135,7 @@ } ], "source": [ - "from flax import nnx # NNX API\n", + "from flax import nnx # Flax NNX API\n", "from functools import partial\n", "\n", "class CNN(nnx.Module):\n", @@ -204,7 +205,7 @@ "source": [ "## 4. Create Optimizer and Metrics\n", "\n", - "In NNX, we create an `Optimizer` object to manage the model's parameters and apply gradients during training. `Optimizer` receives the model's reference so it can update its parameters, and an `optax` optimizer to define the update rules. Additionally, we'll define a `MultiMetric` object to keep track of the `Accuracy` and the `Average` loss." + "In Flax NNX, we create an `Optimizer` object to manage the model's parameters and apply gradients during training. `Optimizer` receives the model's reference so it can update its parameters, and an `optax` optimizer to define the update rules. Additionally, we'll define a `MultiMetric` object to keep track of the `Accuracy` and the `Average` loss." ] }, { @@ -287,9 +288,9 @@ "The [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) decorator traces the `train_step` function for just-in-time compilation with \n", "[XLA](https://www.tensorflow.org/xla), optimizing performance on \n", "hardware accelerators. `nnx.jit` is similar to [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit),\n", - "except it can transforms functions that contain NNX objects as inputs and outputs.\n", + "except it can transforms functions that contain Flax NNX objects as inputs and outputs.\n", "\n", - "**NOTE**: in the above code we performed serveral inplace updates to the model, optimizer, and metrics, and we did not explicitely return the state updates. This is because NNX transforms respect reference semantics for NNX objects, and will propagate the state updates of the objects passed as input arguments. This is a key feature of NNX that allows for a more concise and readable code." + "**NOTE**: in the above code we performed serveral inplace updates to the model, optimizer, and metrics, and we did not explicitely return the state updates. This is because Flax NNX transforms respect reference semantics for Flax NNX objects, and will propagate the state updates of the objects passed as input arguments. This is a key feature of Flax NNX that allows for a more concise and readable code." ] }, { diff --git a/docs/nnx/mnist_tutorial.md b/docs/nnx/mnist_tutorial.md index 1d67b94eb2..6cea6668de 100644 --- a/docs/nnx/mnist_tutorial.md +++ b/docs/nnx/mnist_tutorial.md @@ -14,21 +14,22 @@ jupytext: # MNIST Tutorial -Welcome to NNX! This tutorial will guide you through building and training a simple convolutional -neural network (CNN) on the MNIST dataset using the NNX API. NNX is a Python neural network library +Welcome to Flax NNX! This tutorial will guide you through building and training a simple convolutional +neural network (CNN) on the MNIST dataset using the Flax NNX API. Flax NNX is a Python neural network library built upon [JAX](https://github.com/google/jax) and currently offered as an experimental module within [Flax](https://github.com/google/flax). +++ -## 1. Install NNX +## 1. Install Flax -Since NNX is under active development, we recommend using the latest version from the Flax GitHub repository: +If `flax` is not installed in your environment, you can install it from PyPI, uncomment and run the +following cell: ```{code-cell} ipython3 :tags: [skip-execution] -# !pip install git+https://github.com/google/flax.git +# !pip install flax ``` ## 2. Load the MNIST Dataset @@ -71,12 +72,12 @@ train_ds = train_ds.batch(batch_size, drop_remainder=True).take(train_steps).pre test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1) ``` -## 3. Define the Network with NNX +## 3. Define the Network with Flax NNX -Create a convolutional neural network with NNX by subclassing `nnx.Module`. +Create a convolutional neural network with Flax NNX by subclassing `nnx.Module`. ```{code-cell} ipython3 -from flax import nnx # NNX API +from flax import nnx # Flax NNX API from functools import partial class CNN(nnx.Module): @@ -116,7 +117,7 @@ nnx.display(y) ## 4. Create Optimizer and Metrics -In NNX, we create an `Optimizer` object to manage the model's parameters and apply gradients during training. `Optimizer` receives the model's reference so it can update its parameters, and an `optax` optimizer to define the update rules. Additionally, we'll define a `MultiMetric` object to keep track of the `Accuracy` and the `Average` loss. +In Flax NNX, we create an `Optimizer` object to manage the model's parameters and apply gradients during training. `Optimizer` receives the model's reference so it can update its parameters, and an `optax` optimizer to define the update rules. Additionally, we'll define a `MultiMetric` object to keep track of the `Accuracy` and the `Average` loss. ```{code-cell} ipython3 import optax @@ -162,9 +163,9 @@ def eval_step(model: CNN, metrics: nnx.MultiMetric, batch): The [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) decorator traces the `train_step` function for just-in-time compilation with [XLA](https://www.tensorflow.org/xla), optimizing performance on hardware accelerators. `nnx.jit` is similar to [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit), -except it can transforms functions that contain NNX objects as inputs and outputs. +except it can transforms functions that contain Flax NNX objects as inputs and outputs. -**NOTE**: in the above code we performed serveral inplace updates to the model, optimizer, and metrics, and we did not explicitely return the state updates. This is because NNX transforms respect reference semantics for NNX objects, and will propagate the state updates of the objects passed as input arguments. This is a key feature of NNX that allows for a more concise and readable code. +**NOTE**: in the above code we performed serveral inplace updates to the model, optimizer, and metrics, and we did not explicitely return the state updates. This is because Flax NNX transforms respect reference semantics for Flax NNX objects, and will propagate the state updates of the objects passed as input arguments. This is a key feature of Flax NNX that allows for a more concise and readable code. +++ diff --git a/docs/nnx/nnx_basics.ipynb b/docs/nnx/nnx_basics.ipynb index c4d23736c6..fb326e1dc6 100644 --- a/docs/nnx/nnx_basics.ipynb +++ b/docs/nnx/nnx_basics.ipynb @@ -4,9 +4,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# NNX Basics\n", + "# Flax NNX Basics\n", "\n", - "NNX is a new Flax API that is designed to make it easier to create, inspect, debug,\n", + "Flax NNX is a new simplified API that is designed to make it easier to create, inspect, debug,\n", "and analyze neural networks in JAX. It achieves this by adding first class support\n", "for Python reference semantics, allowing users to express their models using regular\n", "Python objects, which are modeled as PyGraphs (instead of PyTrees), enabling reference\n", @@ -43,18 +43,17 @@ "metadata": {}, "source": [ "## The Module System\n", - "To begin lets see how to create a `Linear` Module using NNX. The main difference between \n", - "NNX and Module systems like Haiku or Linen is that in NNX everything is **explicit**. This \n", + "To begin lets see how to create a `Linear` Module using Flax. The main difference between \n", + "Flax NNX and Module systems like Haiku or Flax Linen is that everything is **explicit**. This \n", "means among other things that 1) the Module itself holds the state (e.g. parameters) directly, \n", "2) the RNG state is threaded by the user, and 3) all shape information must be provided on \n", "initialization (no shape inference).\n", "\n", "As shown next, dynamic state is usually stored in `nnx.Param`s, and static state \n", - "(all types not handled by NNX) such as integers or strings are stored directly. \n", + "(all types not handled by Flax) such as integers or strings are stored directly. \n", "Attributes of type `jax.Array` and `numpy.ndarray` are also treated as dynamic \n", "state, although storing them inside `nnx.Variable`s such as `Param` is preferred.\n", - "Also, the `nnx.Rngs` object by can be used to get new unique keys based on a root \n", - "key passed to the constructor." + "Also, `nnx.Rngs` can be used to get new unique keys starting from a root key." ] }, { @@ -137,7 +136,7 @@ "### Stateful Computation\n", "\n", "Implementing layers such as `BatchNorm` requires performing state updates during the \n", - "forward pass. To implement this in NNX you just create a `Variable` and update its \n", + "forward pass. To implement this in Flax you just create a `Variable` and update its \n", "`.value` during the forward pass." ] }, @@ -176,7 +175,7 @@ "metadata": {}, "source": [ "Mutable references are usually avoided in JAX, however as we'll see in later sections\n", - "NNX provides sound mechanisms to handle them." + "Flax provides sound mechanisms to handle them." ] }, { @@ -232,7 +231,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "In NNX `Dropout` is a stateful module that stores an `Rngs` object so that it can generate\n", + "In Flax `Dropout` is a stateful module that stores an `Rngs` object so that it can generate\n", "new masks during the forward pass without the need for the user to pass a new key each time." ] }, @@ -241,7 +240,7 @@ "metadata": {}, "source": [ "#### Model Surgery\n", - "NNX Modules are mutable by default, this means their structure can be changed at any time, \n", + "Flax NNX Modules are mutable by default, this means their structure can be changed at any time, \n", "this makes model surgery quite easy as any submodule attribute can be replaced with anything\n", "else e.g. new Modules, existing shared Modules, Modules of different types, etc. More over, \n", "`Variable`s can also be modified or replaced / shared.\n", @@ -296,15 +295,15 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## NNX Transforms\n", + "## Transforms\n", "\n", - "NNX Transforms extend JAX transforms to support Modules and other objects.\n", + "Flax Transforms extend JAX transforms to support Modules and other objects.\n", "They are supersets of their equivalent JAX counterpart with the addition of\n", "being aware of the object's state and providing additional APIs to transform \n", - "it. One of the main features of NNX Transforms is the preservation of reference semantics, \n", + "it. One of the main features of Flax Transforms is the preservation of reference semantics, \n", "meaning that any mutation of the object graph that occurs inside the transform is\n", "propagated outisde as long as its legal within the transform rules. In practice this\n", - "means that NNX programs can be express using imperative code, highly simplifying\n", + "means that Flax programs can be express using imperative code, highly simplifying\n", "the user experience.\n", "\n", "In the following example we define a `train_step` function that takes a `MLP` model,\n", @@ -431,7 +430,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "How do NNX transforms achieve this? To understand how NNX objects interact with\n", + "How do Flax transforms achieve this? To understand how Flax objects interact with\n", "JAX transforms lets take a look at the Functional API." ] }, diff --git a/docs/nnx/nnx_basics.md b/docs/nnx/nnx_basics.md index 6a5f2a3888..42b67dd547 100644 --- a/docs/nnx/nnx_basics.md +++ b/docs/nnx/nnx_basics.md @@ -8,9 +8,9 @@ jupytext: jupytext_version: 1.13.8 --- -# NNX Basics +# Flax NNX Basics -NNX is a new Flax API that is designed to make it easier to create, inspect, debug, +Flax NNX is a new simplified API that is designed to make it easier to create, inspect, debug, and analyze neural networks in JAX. It achieves this by adding first class support for Python reference semantics, allowing users to express their models using regular Python objects, which are modeled as PyGraphs (instead of PyTrees), enabling reference @@ -30,18 +30,17 @@ import jax.numpy as jnp ``` ## The Module System -To begin lets see how to create a `Linear` Module using NNX. The main difference between -NNX and Module systems like Haiku or Linen is that in NNX everything is **explicit**. This +To begin lets see how to create a `Linear` Module using Flax. The main difference between +Flax NNX and Module systems like Haiku or Flax Linen is that everything is **explicit**. This means among other things that 1) the Module itself holds the state (e.g. parameters) directly, 2) the RNG state is threaded by the user, and 3) all shape information must be provided on initialization (no shape inference). As shown next, dynamic state is usually stored in `nnx.Param`s, and static state -(all types not handled by NNX) such as integers or strings are stored directly. +(all types not handled by Flax) such as integers or strings are stored directly. Attributes of type `jax.Array` and `numpy.ndarray` are also treated as dynamic state, although storing them inside `nnx.Variable`s such as `Param` is preferred. -Also, the `nnx.Rngs` object by can be used to get new unique keys based on a root -key passed to the constructor. +Also, `nnx.Rngs` can be used to get new unique keys starting from a root key. ```{code-cell} ipython3 class Linear(nnx.Module): @@ -81,7 +80,7 @@ The above visualization by `nnx.display` is generated using the awesome [Treesco ### Stateful Computation Implementing layers such as `BatchNorm` requires performing state updates during the -forward pass. To implement this in NNX you just create a `Variable` and update its +forward pass. To implement this in Flax you just create a `Variable` and update its `.value` during the forward pass. ```{code-cell} ipython3 @@ -101,7 +100,7 @@ print(f'{counter.count.value = }') ``` Mutable references are usually avoided in JAX, however as we'll see in later sections -NNX provides sound mechanisms to handle them. +Flax provides sound mechanisms to handle them. +++ @@ -131,13 +130,13 @@ y = model(x=jnp.ones((3, 2))) nnx.display(model) ``` -In NNX `Dropout` is a stateful module that stores an `Rngs` object so that it can generate +In Flax `Dropout` is a stateful module that stores an `Rngs` object so that it can generate new masks during the forward pass without the need for the user to pass a new key each time. +++ #### Model Surgery -NNX Modules are mutable by default, this means their structure can be changed at any time, +Flax NNX Modules are mutable by default, this means their structure can be changed at any time, this makes model surgery quite easy as any submodule attribute can be replaced with anything else e.g. new Modules, existing shared Modules, Modules of different types, etc. More over, `Variable`s can also be modified or replaced / shared. @@ -169,15 +168,15 @@ y = model(x=jnp.ones((3, 2))) nnx.display(model) ``` -## NNX Transforms +## Transforms -NNX Transforms extend JAX transforms to support Modules and other objects. +Flax Transforms extend JAX transforms to support Modules and other objects. They are supersets of their equivalent JAX counterpart with the addition of being aware of the object's state and providing additional APIs to transform -it. One of the main features of NNX Transforms is the preservation of reference semantics, +it. One of the main features of Flax Transforms is the preservation of reference semantics, meaning that any mutation of the object graph that occurs inside the transform is propagated outisde as long as its legal within the transform rules. In practice this -means that NNX programs can be express using imperative code, highly simplifying +means that Flax programs can be express using imperative code, highly simplifying the user experience. In the following example we define a `train_step` function that takes a `MLP` model, @@ -255,7 +254,7 @@ print(f'{y.shape = }') nnx.display(model) ``` -How do NNX transforms achieve this? To understand how NNX objects interact with +How do Flax transforms achieve this? To understand how Flax objects interact with JAX transforms lets take a look at the Functional API. +++ diff --git a/docs/nnx/surgery.ipynb b/docs/nnx/surgery.ipynb index 5cc0b10a14..00a1839ec5 100644 --- a/docs/nnx/surgery.ipynb +++ b/docs/nnx/surgery.ipynb @@ -6,9 +6,11 @@ "source": [ "# Model surgery\n", "\n", + "> **Attention**: This page relates to the new Flax NNX API.\n", + "\n", "In this guide you will learn how to do model surgery with Flax NNX with several real-scenario use cases:\n", "\n", - "* __Python module manipulation__: Pythonic ways to manipulate sub-modules given a model.\n", + "* __Pythonic module manipulation__: Pythonic ways to manipulate sub-modules given a model.\n", "\n", "* __Manipulating an abstract model or state__: A key trick to play with Flax NNX modules and states without memory allocation.\n", "\n", @@ -100,7 +102,7 @@ "model = TwoLayerMLP(4, rngs=nnx.Rngs(0))\n", "def awesome_layer(x): return x\n", "model.linear2 = awesome_layer\n", - "np.testing.assert_allclose(model(x), model.linear1(x))\n" + "np.testing.assert_allclose(model(x), model.linear1(x))" ] }, { @@ -112,7 +114,7 @@ "For more complex model surgery, a key technique is creating and manipulating an abstract model or state without allocating any real parameter data. This makes trial iteration faster and removes any concern on memory constraints.\n", "\n", "To create an abstract model,\n", - "* Create a function that returns a valid NNX model; and\n", + "* Create a function that returns a valid Flax NNX model; and\n", "* Run `nnx.eval_shape` (not `jax.eval_shape`) upon it.\n", "\n", "Now you can use `nnx.split` as usual to get its abstract state. Note that all the fields that should be `jax.Array` in a real model are now an abstract `jax.ShapeDtypeStruct` with only shape/dtype/sharding information." @@ -188,7 +190,7 @@ "\n", "With the abstract state technique in hand, you can do arbitrary manipulation on any checkpoint (or runtime parameter pytree) to make them fit with your given model code, then call `nnx.update` to merge them.\n", "\n", - "This can be helpful when you are trying to change model code significantly (for example, when migrating from Flax Linen to Flax NNX) and old weights are no longer naturally compatible. Let's run a simple example here:" + "This can be helpful when you are trying to change model code significantly (for example, when migrating from Flax Linen to Flax NNX), and old weights are no longer naturally compatible. Let's run a simple example here:" ] }, { @@ -197,7 +199,7 @@ "metadata": {}, "outputs": [], "source": [ - "# Save a version of a model into a checkpoint.\n", + "# Save a version of model into a checkpoint\n", "checkpointer = orbax.PyTreeCheckpointer()\n", "old_model = TwoLayerMLP(4, rngs=nnx.Rngs(0))\n", "checkpointer.save(f'/tmp/nnx-surgery-state', nnx.state(model), force=True)" @@ -314,7 +316,7 @@ "source": [ "## Partial initialization\n", "\n", - "In some cases (such as with Low-Rank Adapation (LoRA)), you may want to randomly-initialize only *part of* your model parameters. This can be achieved through naive partial initialization or memory-efficient partial initialization." + "In some cases (such as with LoRA), you may want to randomly-initialize only *part of* your model parameters. This can be achieved through naive partial initialization or memory-efficient partial initialization." ] }, { @@ -383,7 +385,7 @@ } ], "source": [ - "# Some pretrained model state.\n", + "# Some pretrained model state\n", "old_state = nnx.state(TwoLayerMLP(4, rngs=nnx.Rngs(0)))\n", "\n", "# Use `nnx.jit` (which wraps `jax.jit`) to automatically skip unused arrays - memory efficient!\n", @@ -405,6 +407,9 @@ } ], "metadata": { + "jupytext": { + "formats": "ipynb,md:myst" + }, "language_info": { "codemirror_mode": { "name": "ipython", diff --git a/docs/nnx/surgery.md b/docs/nnx/surgery.md index cd7a73a85b..e829f850ce 100644 --- a/docs/nnx/surgery.md +++ b/docs/nnx/surgery.md @@ -1,5 +1,17 @@ +--- +jupytext: + formats: ipynb,md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.13.8 +--- + # Model surgery +> **Attention**: This page relates to the new Flax NNX API. + In this guide you will learn how to do model surgery with Flax NNX with several real-scenario use cases: * __Pythonic module manipulation__: Pythonic ways to manipulate sub-modules given a model. @@ -10,8 +22,7 @@ In this guide you will learn how to do model surgery with Flax NNX with several * __Partial initialization__: How to initialize only a part of the model from scratch using a naive method or a memory-efficient method. - -```python +```{code-cell} ipython3 from typing import * from pprint import pprint import functools @@ -30,8 +41,7 @@ import orbax.checkpoint as orbax key = jax.random.key(0) ``` - -```python +```{code-cell} ipython3 class TwoLayerMLP(nnx.Module): def __init__(self, dim, rngs: nnx.Rngs): self.linear1 = nnx.Linear(dim, dim, rngs=rngs) @@ -48,8 +58,7 @@ Doing model surgery is easiest when you already have a fully fleshed-out model l You can perform a variety of Pythonic operations on its sub-modules, such as sub-module swapping, module sharing, variable sharing, and monkey-patching: - -```python +```{code-cell} ipython3 model = TwoLayerMLP(4, rngs=nnx.Rngs(0)) x = jax.random.normal(jax.random.key(42), (3, 4)) np.testing.assert_allclose(model(x), model.linear2(model.linear1(x))) @@ -77,7 +86,6 @@ model = TwoLayerMLP(4, rngs=nnx.Rngs(0)) def awesome_layer(x): return x model.linear2 = awesome_layer np.testing.assert_allclose(model(x), model.linear1(x)) - ``` ## Creating an abstract model or state without memory allocation @@ -85,46 +93,20 @@ np.testing.assert_allclose(model(x), model.linear1(x)) For more complex model surgery, a key technique is creating and manipulating an abstract model or state without allocating any real parameter data. This makes trial iteration faster and removes any concern on memory constraints. To create an abstract model, -* Create a function that returns a valid NNX model; and +* Create a function that returns a valid Flax NNX model; and * Run `nnx.eval_shape` (not `jax.eval_shape`) upon it. Now you can use `nnx.split` as usual to get its abstract state. Note that all the fields that should be `jax.Array` in a real model are now an abstract `jax.ShapeDtypeStruct` with only shape/dtype/sharding information. - -```python +```{code-cell} ipython3 abs_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(0))) gdef, abs_state = nnx.split(abs_model) pprint(abs_state) ``` - State({ - 'linear1': { - 'bias': VariableState( - type=Param, - value=ShapeDtypeStruct(shape=(4,), dtype=float32) - ), - 'kernel': VariableState( - type=Param, - value=ShapeDtypeStruct(shape=(4, 4), dtype=float32) - ) - }, - 'linear2': { - 'bias': VariableState( - type=Param, - value=ShapeDtypeStruct(shape=(4,), dtype=float32) - ), - 'kernel': VariableState( - type=Param, - value=ShapeDtypeStruct(shape=(4, 4), dtype=float32) - ) - } - }) - - When you fill every `VariableState` leaf's `value`s with real jax arrays, the abstract model becomes equivalent to a real model. - -```python +```{code-cell} ipython3 model = TwoLayerMLP(4, rngs=nnx.Rngs(0)) abs_state['linear1']['kernel'].value = model.linear1.kernel abs_state['linear1']['bias'].value = model.linear1.bias @@ -140,9 +122,8 @@ With the abstract state technique in hand, you can do arbitrary manipulation on This can be helpful when you are trying to change model code significantly (for example, when migrating from Flax Linen to Flax NNX), and old weights are no longer naturally compatible. Let's run a simple example here: - -```python -# Save a version of a model into a checkpoint. +```{code-cell} ipython3 +# Save a version of model into a checkpoint checkpointer = orbax.PyTreeCheckpointer() old_model = TwoLayerMLP(4, rngs=nnx.Rngs(0)) checkpointer.save(f'/tmp/nnx-surgery-state', nnx.state(model), force=True) @@ -150,8 +131,7 @@ checkpointer.save(f'/tmp/nnx-surgery-state', nnx.state(model), force=True) In this new model, the sub-modules are renamed from `linear(1|2)` to `layer(1|2)`. Since the pytree structure changed, it's impossible to load the old checkpoint with the new model state structure: - -```python +```{code-cell} ipython3 class ModifiedTwoLayerMLP(nnx.Module): def __init__(self, dim, rngs: nnx.Rngs): self.layer1 = nnx.Linear(dim, dim, rngs=rngs) # no longer linear1! @@ -169,17 +149,9 @@ except Exception as e: print(f'This will throw error: {type(e)}: {e}') ``` - This will throw error: : 'layer1' - - - /Users/ivyzheng/envs/py310/lib/python3.10/site-packages/orbax/checkpoint/type_handlers.py:1401: UserWarning: Couldn't find sharding info under RestoreArgs. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file instead of directly from RestoreArgs. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with. - warnings.warn( - - But you can load the parameter tree as a raw dictionary, make the renames, and generate a new state that is guaranteed to be compatible with your new model definition. - -```python +```{code-cell} ipython3 def module_from_variables_dict(module_factory, variables, map_key_fn): if map_key_fn is None: map_key_fn = lambda path: path @@ -209,31 +181,20 @@ restored_model = module_from_variables_dict( np.testing.assert_allclose(restored_model(jnp.ones((3, 4))), old_model(jnp.ones((3, 4)))) ``` - {'linear1': {'bias': {'raw_value': Array([0., 0., 0., 0.], dtype=float32)}, - 'kernel': {'raw_value': Array([[-0.80345297, -0.34071913, -0.9408296 , 0.01005968], - [ 0.26146442, 1.1247735 , 0.54563737, -0.374164 ], - [ 1.0281805 , -0.6798804 , -0.1488401 , 0.05694951], - [-0.44308168, -0.60587114, 0.434087 , -0.40541083]], dtype=float32)}}, - 'linear2': {'bias': {'raw_value': Array([0., 0., 0., 0.], dtype=float32)}, - 'kernel': {'raw_value': Array([[ 0.21010089, 0.8289361 , 0.04589564, 0.5422644 ], - [ 0.41914317, 0.84359694, -0.47937787, -0.49135214], - [-0.46072108, 0.4630125 , 0.39276958, -0.9441406 ], - [-0.6690758 , -0.18474789, -0.57622856, 0.4821079 ]], dtype=float32)}}} - - ## Partial initialization In some cases (such as with LoRA), you may want to randomly-initialize only *part of* your model parameters. This can be achieved through naive partial initialization or memory-efficient partial initialization. ++++ + ### Naive partial initialization You can simply initialize the whole model, then swap pre-trained parameters in. But this approach could allocate additional memory midway, if your modification requires re-creating module parameters that you will later discard. See this example below. > Note: You can use `jax.live_arrays()` to check all the arrays live in memory at any given time. This call can be messed up when you run a single notebook cell multiple times (due to garbage-collecting old python variables), but restarting the kernel and running from scratch will always yield same output. - -```python -# Some pretrained model state. +```{code-cell} ipython3 +# Some pretrained model state old_state = nnx.state(TwoLayerMLP(4, rngs=nnx.Rngs(0))) simple_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(42))) @@ -248,17 +209,11 @@ print(f'Number of jax arrays in memory at end: {len(jax.live_arrays())}' ' (2 discarded - only lora_a & lora_b are used in model)') ``` - Number of jax arrays in memory at start: 34 - Number of jax arrays in memory midway: 38 (4 new created in LoRALinear - kernel, bias, lora_a & lora_b) - Number of jax arrays in memory at end: 36 (2 discarded - only lora_a & lora_b are used in model) - - ### Memory-efficient partial initialization Use `nnx.jit`'s efficiently compiled code to make sure only the state parameters you need are initialized: - -```python +```{code-cell} ipython3 # Some pretrained model state old_state = nnx.state(TwoLayerMLP(4, rngs=nnx.Rngs(0))) @@ -278,7 +233,3 @@ good_model = partial_init(old_state, nnx.Rngs(42)) print(f'Number of jax arrays in memory at end: {len(jax.live_arrays())}' ' (2 new created - lora_a and lora_b)') ``` - - Number of jax arrays in memory at start: 40 - Number of jax arrays in memory at end: 42 (2 new created - lora_a and lora_b) - diff --git a/docs/nnx/transforms.rst b/docs/nnx/transforms.rst index 8d015a365b..2b5edbd548 100644 --- a/docs/nnx/transforms.rst +++ b/docs/nnx/transforms.rst @@ -1,13 +1,16 @@ -NNX vs JAX Transformations +Flax NNX vs JAX Transformations ========================== -In this guide, you will learn the differences using NNX and JAX transformations, and how to +.. attention:: + This page relates to the new Flax NNX API. + +In this guide, you will learn the differences using Flax NNX and JAX transformations, and how to seamlessly switch between them or use them together. We will be focusing on the ``jit`` and ``grad`` function transformations in this guide. First, let's set up imports and generate some dummy data: -.. testcode:: NNX, JAX +.. testcode:: Flax NNX, JAX from flax import nnx import jax @@ -18,28 +21,28 @@ First, let's set up imports and generate some dummy data: Differences between NNX and JAX transformations *********************************************** -The primary difference between NNX and JAX transformations is that NNX transformations allow you to -transform functions that take in NNX graph objects as arguments (`Module`, `Rngs`, `Optimizer`, etc), +The primary difference between Flax NNX and JAX transformations is that Flax NNX transformations allow you to +transform functions that take in Flax NNX graph objects as arguments (`Module`, `Rngs`, `Optimizer`, etc), even those whose state will be mutated, whereas they aren't recognized in JAX transformations. -Therefore NNX transformations can transform functions that are not pure and make mutations and +Therefore Flax NNX transformations can transform functions that are not pure and make mutations and side-effects. -NNX's `Functional API `_ +Flax NNX's `Functional API `_ provides a way to convert graph structures to pytrees and back. By doing this at every function boundary you can effectively use graph structures with any JAX transform and propagate state updates -in a way consistent with functional purity. NNX custom transforms such as ``nnx.jit`` and ``nnx.grad`` +in a way consistent with functional purity. Flax NNX custom transforms such as ``nnx.jit`` and ``nnx.grad`` simply remove the boilerplate, as a result the code looks stateful. Below is an example of using the ``nnx.jit`` and ``nnx.grad`` transformations compared to using the -``jax.jit`` and ``jax.grad`` transformations. Notice the function signature of NNX-transformed +``jax.jit`` and ``jax.grad`` transformations. Notice the function signature of Flax NNX-transformed functions can accept the ``nnx.Linear`` module directly and can make stateful updates to the module, whereas the function signature of JAX-transformed functions can only accept the pytree-registered ``State`` and ``GraphDef`` objects and must return an updated copy of them to maintain the purity of the transformed function. .. codediff:: - :title: NNX transforms, JAX transforms - :groups: NNX, JAX + :title: Flax NNX transforms, JAX transforms + :groups: Flax NNX, JAX :sync: @nnx.jit @@ -76,15 +79,15 @@ the transformed function. graphdef, state = train_step(graphdef, state, x, y) #! -Mixing NNX and JAX transformations +Mixing Flax NNX and JAX transformations ********************************** -NNX and JAX transformations can be mixed together, so long as the JAX-transformed function is +Flax NNX and JAX transformations can be mixed together, so long as the JAX-transformed function is pure and has valid argument types that are recognized by JAX. .. codediff:: :title: Using ``nnx.jit`` with ``jax.grad``, Using ``jax.jit`` with ``nnx.grad`` - :groups: NNX, JAX + :groups: Flax NNX, JAX :sync: @nnx.jit