From caa061d04531efc8cd6f3dd2438f60ca5554a0ab Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Wed, 2 Oct 2024 16:16:36 +0100 Subject: [PATCH] add Why NNX doc --- docs_nnx/guides/why.ipynb | 770 -------------------------------------- docs_nnx/guides/why.md | 409 -------------------- docs_nnx/index.rst | 1 + docs_nnx/why.rst | 416 ++++++++++++++++++++ 4 files changed, 417 insertions(+), 1179 deletions(-) delete mode 100644 docs_nnx/guides/why.ipynb delete mode 100644 docs_nnx/guides/why.md create mode 100644 docs_nnx/why.rst diff --git a/docs_nnx/guides/why.ipynb b/docs_nnx/guides/why.ipynb deleted file mode 100644 index d38fe6c809..0000000000 --- a/docs_nnx/guides/why.ipynb +++ /dev/null @@ -1,770 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Why NNX?\n", - "\n", - "\n", - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/flax/nnx/docs/why.ipynb)\n", - "\n", - "Four years ago we developed the Flax \"Linen\" API to support modeling research on JAX, with a focus on scaling scaling and performance. We've learned a lot from our users over these years.\n", - "\n", - "We introduced some ideas that have proven to be good:\n", - " - Organizing variables into [collections](https://flax.readthedocs.io/en/latest/glossary.html#term-Variable-collections) or types to support JAX transforms and segregation of different data types in training loops.\n", - " - Automatic and efficient [PRNG management](https://flax.readthedocs.io/en/latest/glossary.html#term-RNG-sequences) (with support for splitting/broadcast control across map transforms)\n", - " - [Variable Metadata](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.with_partitioning) for SPMD annotations, optimizer metadata, and other uses.\n", - "\n", - "However, one choice we made was to use functional \"define by call\" semantics for NN programming via the lazy initialization of parameters. This made for concise (`compact`) implementation code, allowed for a single specification when transforming a layer, and aligned our API with Haiku. Lazy initialization meant that the semantics of modules and variables in Flax were non-pythonic and often surprising. It also led to implementation complexity and obscured the core ideas of transformations on neural nets.\n", - "\n", - "NNX is an attempt to keep the features that made Linen useful while introducing some new principles:\n", - "\n", - "- Regular Python semantics for Modules, including (within JIT boundaries) support for mutability and shared references.\n", - "- A simple API to interact directly with the JAX, this includes the ability to easily implement custom lifted Modules and other purely functional tricks.\n", - "\n", - "We'd love to hear from any of our users about their thoughts on these ideas.\n", - "\n", - "[[nnx on github](https://github.com/google/flax/tree/main/flax/nnx)]\n", - "[[this doc on github](https://github.com/google/flax/blob/main/flax/nnx/docs/why.ipynb)]" - ] - }, - { - "cell_type": "code", - "execution_count": 108, - "metadata": {}, - "outputs": [], - "source": [ - "! pip install -U git+https://github.com/google/flax.git\n", - "from functools import partial\n", - "import jax\n", - "from jax import random, numpy as jnp\n", - "from flax import nnx" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### NNX is Pythonic\n", - "The main feature of NNX Module is that it adheres to Python semantics. This means that:\n", - "\n", - "* fields are mutable so you can perform inplace updates\n", - "* Module references can be shared between multiple Modules\n", - "* Module construction implies parameter initialization\n", - "* Module methods can be called directly" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "outputId": "d8ef66d5-6866-4d5c-94c2-d22512bfe718" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "model = CounterLinear(\n", - " linear=Linear(\n", - " in_features=4,\n", - " out_features=4,\n", - " use_bias=True,\n", - " dtype=None,\n", - " param_dtype=,\n", - " precision=None,\n", - " kernel_init=.init at 0x7f3dc9ad3370>,\n", - " bias_init=,\n", - " dot_general=\n", - " )\n", - ")\n" - ] - } - ], - "source": [ - "class Count(nnx.Variable): # custom Variable types define the \"collections\"\n", - " pass\n", - "\n", - "\n", - "class CounterLinear(nnx.Module):\n", - " def __init__(self, din, dout, *, rngs): # explicit RNG threading\n", - " self.linear = nnx.Linear(din, dout, rngs=rngs)\n", - " self.count = Count(jnp.zeros((), jnp.int32)) # typed Variable collections\n", - "\n", - " def __call__(self, x):\n", - " self.count.value += 1 # in-place stateful updates\n", - " return self.linear(x)\n", - "\n", - "\n", - "model = CounterLinear(4, 4, rngs=nnx.Rngs(0)) # no special `init` method\n", - "y = model(jnp.ones((2, 4))) # call methods directly\n", - "\n", - "print(f'{model = }')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Because NNX Modules contain their own state, they are very easily to inspect:" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "outputId": "10a46b0f-2993-4677-c26d-36a4ddf33449" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "model.count = Array(1, dtype=int32)\n", - "model.linear.kernel = Array([[ 0.4541089 , -0.5264876 , -0.36505195, -0.57566494],\n", - " [ 0.38802508, 0.5655534 , 0.4870657 , 0.2267774 ],\n", - " [-0.9015767 , 0.24465278, -0.5844707 , 0.18421966],\n", - " [-0.06992685, -0.64693886, 0.20232596, 1.1200062 ]], dtype=float32)\n" - ] - } - ], - "source": [ - "print(f'{model.count = }')\n", - "print(f'{model.linear.kernel = }')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Intuitive Surgery\n", - "\n", - "In NNX surgery can be done at the Module level by simply updating / replacing existing fields." - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": { - "outputId": "e6f86be8-3537-4c48-f471-316ee0fb6c45" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([[1.7531997, 1.6318591, 2.1417565, 3.120555 ],\n", - " [1.7531997, 1.6318591, 2.1417565, 3.120555 ]], dtype=float32)" - ] - }, - "execution_count": 23, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# pretend this came from a checkpoint or elsewhere:\n", - "pretrained_weight = random.uniform(random.key(0), (4, 4))\n", - "\n", - "# you can replace weights directly\n", - "model.linear.kernel = pretrained_weight\n", - "y = model(jnp.ones((2, 4)))\n", - "y" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "metadata": { - "outputId": "5190ac7b-12f7-4400-d5bb-f91b97a557b6" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([[1.624419 , 0.8313738 , 0.37612876, 1.9937458 ],\n", - " [1.624419 , 0.8313738 , 0.37612876, 1.9937458 ]], dtype=float32)" - ] - }, - "execution_count": 26, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def load_pretrained_fragment():\n", - " # pretend this inits / loads some fragment of a model\n", - " replacement = nnx.Linear(4, 4, rngs=nnx.Rngs(1))\n", - " return replacement\n", - "\n", - "# you can replace modules directly\n", - "model.linear = load_pretrained_fragment()\n", - "y = model(jnp.ones((2, 4)))\n", - "y" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Not only is this easier than messing with dictionary structures and aligning that with code changes, but one can even replace a field with a completely different Module type, or even change the architecture (e.g. share two Modules that were not shared before)." - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "metadata": {}, - "outputs": [], - "source": [ - "rngs = nnx.Rngs(0)\n", - "model = nnx.Sequence(\n", - " [\n", - " nnx.Conv(1, 16, [3, 3], padding='SAME', rngs=rngs),\n", - " partial(nnx.max_pool, window_shape=(2, 2), strides=(2, 2)),\n", - " nnx.Conv(16, 32, [3, 3], padding='SAME', rngs=rngs),\n", - " partial(nnx.max_pool, window_shape=(2, 2), strides=(2, 2)),\n", - " lambda x: x.reshape((x.shape[0], -1)), # flatten\n", - " nnx.Linear(32 * 7 * 7, 10, rngs=rngs),\n", - " ]\n", - ")\n", - "\n", - "y = model(jnp.ones((2, 28, 28, 1)))\n", - "\n", - "# Do some weird surgery of the stack:\n", - "for i, layer in enumerate(model):\n", - " if isinstance(layer, nnx.Conv):\n", - " model[i] = nnx.Linear(layer.in_features, layer.out_features, rngs=rngs)\n", - "\n", - "y = model(jnp.ones((2, 28, 28, 1)))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Note that here we are replacing `Conv` with `Linear` as a silly example, but in reality you would do things like replacing a layer with its quantized version, or changing a layer with an optimized version, etc." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Interacting with JAX is easy\n", - "\n", - "While NNX Modules inherently follow reference semantics, they can be easily converted into a pure functional representation that can be used with JAX transformations and other value-based, functional code.\n", - "\n", - "NNX has two very simple APIs to interact with JAX: `split` and `merge`.\n", - "\n", - "The `Module.split` method allows you to convert into a `State` dict-like object that contains the dynamic state of the Module, and a `GraphDef` object that contains the graphdef structure of the Module." - ] - }, - { - "cell_type": "code", - "execution_count": 96, - "metadata": { - "outputId": "9a3f378b-739e-4f45-9968-574651200ede" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "state = State({\n", - " 'count': Array(0, dtype=int32),\n", - " 'linear/bias': Array([0., 0., 0., 0.], dtype=float32),\n", - " 'linear/kernel': Array([[ 0.4541089 , -0.5264876 , -0.36505195, -0.57566494],\n", - " [ 0.38802508, 0.5655534 , 0.4870657 , 0.2267774 ],\n", - " [-0.9015767 , 0.24465278, -0.5844707 , 0.18421966],\n", - " [-0.06992685, -0.64693886, 0.20232596, 1.1200062 ]], dtype=float32)\n", - "})\n", - "\n", - "graphdef = GraphDef(\n", - " type=CounterLinear,\n", - " index=0,\n", - " static_fields=(),\n", - " variables=(('count', Count(\n", - " value=Empty\n", - " )),),\n", - " submodules=(\n", - " ('linear', GraphDef(\n", - " type=Linear,\n", - " index=1,\n", - " static_fields=(('bias_init', ), ('dot_general', ), ('dtype', None), ('in_features', 4), ('kernel_init', .init at 0x7f3dc9ad3370>), ('out_features', 4), ('param_dtype', ), ('precision', None), ('use_bias', True)),\n", - " variables=(('bias', Param(\n", - " value=Empty\n", - " )), ('kernel', Param(\n", - " value=Empty\n", - " ))),\n", - " submodules=()\n", - " ))\n", - " )\n", - ")\n" - ] - } - ], - "source": [ - "model = CounterLinear(4, 4, rngs=nnx.Rngs(0))\n", - "\n", - "graphdef, state = model.split()\n", - "\n", - "# state is a dictionary-like JAX pytree\n", - "print(f'{state = }')\n", - "\n", - "# graphdef is also a JAX pytree, but containing no data, just metadata\n", - "print(f'\\n{graphdef = }')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The `GraphDef.merge` method allows you to take a `GraphDef` and one or more `State` objects and merge them back into a `Module` object.\n", - "\n", - "Using `split` and `merge` in conjunction allows you to carry your Module in and out of any JAX transformation. Here is a simple jitted `forward` function as an example:" - ] - }, - { - "cell_type": "code", - "execution_count": 97, - "metadata": { - "outputId": "0007d357-152a-449e-bcb9-b1b5a91d2d8d" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "y.shape = (2, 4)\n", - "state[\"count\"] = Array(1, dtype=int32)\n" - ] - } - ], - "source": [ - "@jax.jit\n", - "def forward(graphdef: nnx.GraphDef, state: nnx.State, x: jax.Array):\n", - " model = graphdef.merge(state)\n", - " y = model(x)\n", - " state, _ = model.split()\n", - " return y, state\n", - "\n", - "x = jnp.ones((2, 4))\n", - "y, state = forward(graphdef,state, x)\n", - "\n", - "print(f'{y.shape = }')\n", - "print(f'{state[\"count\"] = }')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Custom lifting and transformation\n", - "\n", - "By using the same mechanism inside Module methods you can implement lifted Modules, that is, Modules that use a JAX transformation to have a distinct behavior.\n", - "\n", - "One of Linen's current pain points is that it is not easy to interact with JAX transformations that are not currently supported by the framework. NNX makes it very easy to implement custom lifted Modules or bespoke custom functional transforms for specific use cases.\n", - "\n", - "As an example here we will create a `LinearEnsemble` Module that uses `jax.vmap` both during `__init__` and `__call__` to vectorize the computation over multiple `CounterLinear` models (defined above). The example is a little bit longer, but notice how each method conceptually very simple.\n", - "\n", - "It uses the single additional method `update` to locally modify model state." - ] - }, - { - "cell_type": "code", - "execution_count": 98, - "metadata": { - "outputId": "fdd212d7-4994-4fa5-d922-5a7d7cfad3e3" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "y.shape = (8, 4)\n", - "ensemble.models.count = Array(1, dtype=int32)\n", - "state = State({\n", - " 'models/count': (),\n", - " 'models/linear/bias': (8, 4),\n", - " 'models/linear/kernel': (8, 4, 4)\n", - "})\n" - ] - } - ], - "source": [ - "class LinearEnsemble(nnx.Module):\n", - " def __init__(self, din, dout, *, num_models, rngs: nnx.Rngs):\n", - " # get raw rng seeds\n", - " keys = rngs.fork(num_models) # split all keys into `num_models`\n", - "\n", - " # define pure init fn and vmap\n", - " def vmap_init(keys):\n", - " return CounterLinear(din, dout, rngs=nnx.Rngs(keys)).split(\n", - " nnx.Param, Count\n", - " )\n", - " params, counts, graphdef = jax.vmap(\n", - " vmap_init, in_axes=(0,), out_axes=(0, None, None)\n", - " )(keys)\n", - "\n", - " # update wrapped submodule reference\n", - " self.models = graphdef.merge(params, counts)\n", - "\n", - " def __call__(self, x):\n", - " # get module values, define pure fn,\n", - " # notice that we split the data into two collections by their types.\n", - " params, counts, graphdef = self.models.split(nnx.Param, Count)\n", - "\n", - " # define pure init fn and vmap\n", - " def vmap_apply(x, params, counts, graphdef):\n", - " model = graphdef.merge(params, counts)\n", - " y = model(x)\n", - " params, counts, graphdef = model.split(nnx.Param, Count)\n", - " return y, params, counts, graphdef\n", - "\n", - " y, params, counts, graphdef = jax.vmap(\n", - " vmap_apply,\n", - " in_axes=(None, 0, None, None),\n", - " out_axes=(0, 0, None, None)\n", - " )(x, params, counts, graphdef)\n", - "\n", - " # update wrapped module\n", - " # uses `update` to integrate the new state\n", - " self.models.update(params, counts, graphdef)\n", - " return y\n", - "\n", - "x = jnp.ones((4,))\n", - "ensemble = LinearEnsemble(4, 4, num_models=8, rngs=nnx.Rngs(0))\n", - "\n", - "# forward pass\n", - "y = ensemble(x)\n", - "\n", - "print(f'{y.shape = }')\n", - "print(f'{ensemble.models.count = }')\n", - "print(f'state = {jax.tree.map(jnp.shape, ensemble.get_state())}')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Convenience lifted transforms" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Like linen, for convenience we still provide simple lifted transforms for standard JAX transforms, usable as class transforms and decorators. We've endeavored to simplify the API for scan and vmap compared to the flax specifications." - ] - }, - { - "cell_type": "code", - "execution_count": 112, - "metadata": { - "outputId": "c4800a49-efd1-4ee5-e703-6e63e18da4cb" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "State({\n", - " 'scan_module/bias': Array([[0., 0.],\n", - " [0., 0.],\n", - " [0., 0.],\n", - " [0., 0.]], dtype=float32),\n", - " 'scan_module/kernel': Array([[[-0.32325608, 0.16164146],\n", - " [ 0.46505648, -0.34060344]],\n", - " \n", - " [[-1.1558908 , 1.2445341 ],\n", - " [-1.3710847 , -0.1787171 ]],\n", - " \n", - " [[-0.68510336, 0.25847596],\n", - " [ 1.0730107 , -0.11857361]],\n", - " \n", - " [[-0.01770882, 0.5472832 ],\n", - " [-0.84826714, 0.17867221]]], dtype=float32)\n", - "})" - ] - }, - "execution_count": 112, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# class transform:\n", - "ScannedLinear = nnx.Scan.constructor(nnx.Linear, variable_axes={nnx.Param: 0}, length=4)\n", - "\n", - "scanned = ScannedLinear(2, 2, rngs=nnx.Rngs(0))\n", - "scanned.get_state()" - ] - }, - { - "cell_type": "code", - "execution_count": 113, - "metadata": { - "outputId": "9efd6e71-d180-4674-ade0-2b02057a400b" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "State({\n", - " 'model/bias': Array([[0., 0.],\n", - " [0., 0.],\n", - " [0., 0.],\n", - " [0., 0.]], dtype=float32),\n", - " 'model/kernel': Array([[[-0.32325608, 0.16164146],\n", - " [ 0.46505648, -0.34060344]],\n", - " \n", - " [[-1.1558908 , 1.2445341 ],\n", - " [-1.3710847 , -0.1787171 ]],\n", - " \n", - " [[-0.68510336, 0.25847596],\n", - " [ 1.0730107 , -0.11857361]],\n", - " \n", - " [[-0.01770882, 0.5472832 ],\n", - " [-0.84826714, 0.17867221]]], dtype=float32)\n", - "})" - ] - }, - "execution_count": 113, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# method decorators:\n", - "\n", - "class ScannedLinear(nnx.Module):\n", - "\n", - " @partial(nnx.scan, variable_axes={nnx.Param: 0}, length=4)\n", - " def __init__(self, din, dout, *, rngs: nnx.Rngs):\n", - " self.model = nnx.Linear(din, dout, rngs=nnx.Rngs(rngs))\n", - "\n", - " @partial(nnx.scan, variable_axes={nnx.Param: 0}, length=4)\n", - " def __call__(self, x):\n", - " return self.model(x)\n", - "\n", - "scanned = ScannedLinear(2, 2, rngs=nnx.Rngs(0))\n", - "scanned.get_state()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Aside: Why aren't Modules Pytrees?\n", - "\n", - "A common questions is why aren't NNX Modules registered as Pytrees? (in the style of Equinox, Treex, PytreeClass, etc.) It _is_ trivial to define a pytree registration in terms of `split`/`merge`.\n", - "\n", - "The problem is that Pytrees impose value semantics (referencial transparency) while Modules assume reference semantics, and therefore it is dangerous in general to automatically treat Modules as Pytrees.\n", - "\n", - "As an example, lets take a look at what would happen if we allowed this very simple program to be valid:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "@jax.jit\n", - "def f(m1: nnx.Module, m2: nnx.Module):\n", - " return m1, m2" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Here we are just creating a jitted function `f` that takes in two Modules `(m1, m2)` and returns them as is. What could go wrong?\n", - "\n", - "There are two main problems with this:\n", - "* Shared references are not maintained, that is, if `m1.shared` is the same as `m2.shared` outside `f`, this will NOT be true both inside `f`, and at the output of `f`.\n", - "* Even if you accept this fact and added code to compensate for this, `f` would now behave differently depending on whether its being `jit`ted or not, this is an undesirable asymmetry and `jit` would no longer be a no-op." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Standardized \"Hooks\"\n", - "\n", - "NNX introduces a standard getter/setter/creator interface for custom variables (similar to Haiku hooks). This is used internally to support SPMD metadata for managing sharding information, but is available for user-defined applications." - ] - }, - { - "cell_type": "code", - "execution_count": 59, - "metadata": { - "outputId": "c4e6586a-bfe0-4f26-d05b-8c9e395971b2" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "self.kernel.shape = (4, 8)\n", - "outer kernel shape = (8, 4)\n" - ] - } - ], - "source": [ - "class TransposedParam(nnx.Variable):\n", - " def create_value(self, value):\n", - " return value.T # called on variable creation to transform initial value\n", - " def get_value(self):\n", - " return self.value.T # called when value fetched via module getattr\n", - " def set_value(self, value):\n", - " return self.replace(value=value.T) # called when setting value from module setattr\n", - "\n", - "\n", - "class OddLinear(nnx.Module):\n", - " def __init__(self, din, dout, *, rngs):\n", - " self.kernel = TransposedParam(random.uniform(rngs.params(), (din, dout)))\n", - " self.bias = nnx.Param(jnp.zeros((dout,)))\n", - "\n", - " def __call__(self, x):\n", - " print(f'{self.kernel.shape = }')\n", - " return x @ self.kernel + self.bias\n", - "\n", - "\n", - "model = OddLinear(4, 8, rngs=nnx.Rngs(0))\n", - "y = model(jnp.ones((2, 4)))\n", - "\n", - "print(f'outer kernel shape = {model.split()[0][\"kernel\"].shape}')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "SPMD metadata is handled using `nnx.with_partitioning` helpers, but it's easy to add one's own metadata schema:" - ] - }, - { - "cell_type": "code", - "execution_count": 114, - "metadata": { - "outputId": "ef312738-0f56-4c0e-9aaf-3319d131f1a2" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "state.variables['kernel'].meta='foo'\n", - "state.variables['kernel'].other_meta=0\n", - "state.variables['bias'].meta='bar'\n", - "state.variables['bias'].other_meta=1\n" - ] - } - ], - "source": [ - "class MetadataParam(nnx.Param):\n", - " def __init__(self, *args, **kwargs):\n", - " for key in kwargs:\n", - " setattr(self, key, kwargs[key])\n", - " super().__init__(*args)\n", - "\n", - "\n", - "class AnnotatedLinear(nnx.Module):\n", - " def __init__(self, din, dout, *, rngs):\n", - " self.kernel = TransposedParam(random.uniform(rngs.params(), (din, dout)), meta='foo', other_meta=0)\n", - " self.bias = TransposedParam(jnp.zeros((dout,)), meta='bar', other_meta=1)\n", - "\n", - " def __call__(self, x):\n", - " return x @ self.kernel + self.bias\n", - "\n", - "\n", - "model = AnnotatedLinear(4, 8, rngs=nnx.Rngs(0))\n", - "y = model(jnp.ones((2, 4)))\n", - "\n", - "graphdef, state = model.split()\n", - "\n", - "print(f\"{state.variables['kernel'].meta=}\\n{state.variables['kernel'].other_meta=}\")\n", - "print(f\"{state.variables['bias'].meta=}\\n{state.variables['bias'].other_meta=}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Shape Inference\n", - "\n", - "Shape inference is still possible in NNX using abstract evaluation when it's really needed, it just isn't automatic." - ] - }, - { - "cell_type": "code", - "execution_count": 129, - "metadata": { - "outputId": "942a3788-bcbf-426d-87e6-c5a041172c64" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "State({\n", - " 'encoder/bias': (4,),\n", - " 'encoder/kernel': (3, 3, 3, 4),\n", - " 'linear/bias': (4,),\n", - " 'linear/kernel': (144, 4)\n", - "})" - ] - }, - "execution_count": 129, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def batched_flatten(x):\n", - " return jnp.reshape(x, (x.shape[0], -1))\n", - "\n", - "class Example(nnx.Module):\n", - " def __init__(self, *,\n", - " in_filters=3,\n", - " out_filters=4,\n", - " input_shape=None, # provide an example input size\n", - " rngs):\n", - " self.encoder = nnx.Conv(in_filters, out_filters,\n", - " kernel_size=(3, 3),\n", - " strides=(1, 1),\n", - " padding=\"SAME\",\n", - " rngs=rngs)\n", - " # calculate the flattened shape post-conv using jax.eval_shape\n", - " encoded_shape = jax.eval_shape(\n", - " lambda x: batched_flatten(self.encoder(x)),\n", - " jax.ShapeDtypeStruct(input_shape, jnp.float32)\n", - " ).shape\n", - " # use this shape information to continue initializing\n", - " self.linear = nnx.Linear(encoded_shape[-1], 4, rngs=rngs)\n", - "\n", - " def __call__(self, x):\n", - " x = self.encoder(x)\n", - " x = batched_flatten(x)\n", - " return self.linear(x)\n", - "\n", - "model = Example(in_filters=3,\n", - " out_filters=4,\n", - " input_shape=(2, 6, 6, 3),\n", - " rngs=nnx.Rngs(0))\n", - "\n", - "graphdef, state = model.split()\n", - "jax.tree.map(jnp.shape, state)" - ] - } - ], - "metadata": { - "jupytext": { - "formats": "ipynb,md:myst" - }, - "language_info": { - "name": "python", - "version": "3.10.13" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/docs_nnx/guides/why.md b/docs_nnx/guides/why.md deleted file mode 100644 index b080319be7..0000000000 --- a/docs_nnx/guides/why.md +++ /dev/null @@ -1,409 +0,0 @@ ---- -jupytext: - formats: ipynb,md:myst - text_representation: - extension: .md - format_name: myst - format_version: 0.13 - jupytext_version: 1.13.8 ---- - -# Why NNX? - - -[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/flax/nnx/docs/why.ipynb) - -Four years ago we developed the Flax "Linen" API to support modeling research on JAX, with a focus on scaling scaling and performance. We've learned a lot from our users over these years. - -We introduced some ideas that have proven to be good: - - Organizing variables into [collections](https://flax.readthedocs.io/en/latest/glossary.html#term-Variable-collections) or types to support JAX transforms and segregation of different data types in training loops. - - Automatic and efficient [PRNG management](https://flax.readthedocs.io/en/latest/glossary.html#term-RNG-sequences) (with support for splitting/broadcast control across map transforms) - - [Variable Metadata](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.with_partitioning) for SPMD annotations, optimizer metadata, and other uses. - -However, one choice we made was to use functional "define by call" semantics for NN programming via the lazy initialization of parameters. This made for concise (`compact`) implementation code, allowed for a single specification when transforming a layer, and aligned our API with Haiku. Lazy initialization meant that the semantics of modules and variables in Flax were non-pythonic and often surprising. It also led to implementation complexity and obscured the core ideas of transformations on neural nets. - -NNX is an attempt to keep the features that made Linen useful while introducing some new principles: - -- Regular Python semantics for Modules, including (within JIT boundaries) support for mutability and shared references. -- A simple API to interact directly with the JAX, this includes the ability to easily implement custom lifted Modules and other purely functional tricks. - -We'd love to hear from any of our users about their thoughts on these ideas. - -[[nnx on github](https://github.com/google/flax/tree/main/flax/nnx)] -[[this doc on github](https://github.com/google/flax/blob/main/flax/nnx/docs/why.ipynb)] - -```{code-cell} -! pip install -U git+https://github.com/google/flax.git -from functools import partial -import jax -from jax import random, numpy as jnp -from flax import nnx -``` - -### NNX is Pythonic -The main feature of NNX Module is that it adheres to Python semantics. This means that: - -* fields are mutable so you can perform inplace updates -* Module references can be shared between multiple Modules -* Module construction implies parameter initialization -* Module methods can be called directly - -```{code-cell} -:outputId: d8ef66d5-6866-4d5c-94c2-d22512bfe718 - -class Count(nnx.Variable): # custom Variable types define the "collections" - pass - - -class CounterLinear(nnx.Module): - def __init__(self, din, dout, *, rngs): # explicit RNG threading - self.linear = nnx.Linear(din, dout, rngs=rngs) - self.count = Count(jnp.zeros((), jnp.int32)) # typed Variable collections - - def __call__(self, x): - self.count.value += 1 # in-place stateful updates - return self.linear(x) - - -model = CounterLinear(4, 4, rngs=nnx.Rngs(0)) # no special `init` method -y = model(jnp.ones((2, 4))) # call methods directly - -print(f'{model = }') -``` - -Because NNX Modules contain their own state, they are very easily to inspect: - -```{code-cell} -:outputId: 10a46b0f-2993-4677-c26d-36a4ddf33449 - -print(f'{model.count = }') -print(f'{model.linear.kernel = }') -``` - -#### Intuitive Surgery - -In NNX surgery can be done at the Module level by simply updating / replacing existing fields. - -```{code-cell} -:outputId: e6f86be8-3537-4c48-f471-316ee0fb6c45 - -# pretend this came from a checkpoint or elsewhere: -pretrained_weight = random.uniform(random.key(0), (4, 4)) - -# you can replace weights directly -model.linear.kernel = pretrained_weight -y = model(jnp.ones((2, 4))) -y -``` - -```{code-cell} -:outputId: 5190ac7b-12f7-4400-d5bb-f91b97a557b6 - -def load_pretrained_fragment(): - # pretend this inits / loads some fragment of a model - replacement = nnx.Linear(4, 4, rngs=nnx.Rngs(1)) - return replacement - -# you can replace modules directly -model.linear = load_pretrained_fragment() -y = model(jnp.ones((2, 4))) -y -``` - -Not only is this easier than messing with dictionary structures and aligning that with code changes, but one can even replace a field with a completely different Module type, or even change the architecture (e.g. share two Modules that were not shared before). - -```{code-cell} -rngs = nnx.Rngs(0) -model = nnx.Sequence( - [ - nnx.Conv(1, 16, [3, 3], padding='SAME', rngs=rngs), - partial(nnx.max_pool, window_shape=(2, 2), strides=(2, 2)), - nnx.Conv(16, 32, [3, 3], padding='SAME', rngs=rngs), - partial(nnx.max_pool, window_shape=(2, 2), strides=(2, 2)), - lambda x: x.reshape((x.shape[0], -1)), # flatten - nnx.Linear(32 * 7 * 7, 10, rngs=rngs), - ] -) - -y = model(jnp.ones((2, 28, 28, 1))) - -# Do some weird surgery of the stack: -for i, layer in enumerate(model): - if isinstance(layer, nnx.Conv): - model[i] = nnx.Linear(layer.in_features, layer.out_features, rngs=rngs) - -y = model(jnp.ones((2, 28, 28, 1))) -``` - -Note that here we are replacing `Conv` with `Linear` as a silly example, but in reality you would do things like replacing a layer with its quantized version, or changing a layer with an optimized version, etc. - -+++ - -### Interacting with JAX is easy - -While NNX Modules inherently follow reference semantics, they can be easily converted into a pure functional representation that can be used with JAX transformations and other value-based, functional code. - -NNX has two very simple APIs to interact with JAX: `split` and `merge`. - -The `Module.split` method allows you to convert into a `State` dict-like object that contains the dynamic state of the Module, and a `GraphDef` object that contains the graphdef structure of the Module. - -```{code-cell} -:outputId: 9a3f378b-739e-4f45-9968-574651200ede - -model = CounterLinear(4, 4, rngs=nnx.Rngs(0)) - -graphdef, state = model.split() - -# state is a dictionary-like JAX pytree -print(f'{state = }') - -# graphdef is also a JAX pytree, but containing no data, just metadata -print(f'\n{graphdef = }') -``` - -The `GraphDef.merge` method allows you to take a `GraphDef` and one or more `State` objects and merge them back into a `Module` object. - -Using `split` and `merge` in conjunction allows you to carry your Module in and out of any JAX transformation. Here is a simple jitted `forward` function as an example: - -```{code-cell} -:outputId: 0007d357-152a-449e-bcb9-b1b5a91d2d8d - -@jax.jit -def forward(graphdef: nnx.GraphDef, state: nnx.State, x: jax.Array): - model = graphdef.merge(state) - y = model(x) - state, _ = model.split() - return y, state - -x = jnp.ones((2, 4)) -y, state = forward(graphdef,state, x) - -print(f'{y.shape = }') -print(f'{state["count"] = }') -``` - -#### Custom lifting and transformation - -By using the same mechanism inside Module methods you can implement lifted Modules, that is, Modules that use a JAX transformation to have a distinct behavior. - -One of Linen's current pain points is that it is not easy to interact with JAX transformations that are not currently supported by the framework. NNX makes it very easy to implement custom lifted Modules or bespoke custom functional transforms for specific use cases. - -As an example here we will create a `LinearEnsemble` Module that uses `jax.vmap` both during `__init__` and `__call__` to vectorize the computation over multiple `CounterLinear` models (defined above). The example is a little bit longer, but notice how each method conceptually very simple. - -It uses the single additional method `update` to locally modify model state. - -```{code-cell} -:outputId: fdd212d7-4994-4fa5-d922-5a7d7cfad3e3 - -class LinearEnsemble(nnx.Module): - def __init__(self, din, dout, *, num_models, rngs: nnx.Rngs): - # get raw rng seeds - keys = rngs.fork(num_models) # split all keys into `num_models` - - # define pure init fn and vmap - def vmap_init(keys): - return CounterLinear(din, dout, rngs=nnx.Rngs(keys)).split( - nnx.Param, Count - ) - params, counts, graphdef = jax.vmap( - vmap_init, in_axes=(0,), out_axes=(0, None, None) - )(keys) - - # update wrapped submodule reference - self.models = graphdef.merge(params, counts) - - def __call__(self, x): - # get module values, define pure fn, - # notice that we split the data into two collections by their types. - params, counts, graphdef = self.models.split(nnx.Param, Count) - - # define pure init fn and vmap - def vmap_apply(x, params, counts, graphdef): - model = graphdef.merge(params, counts) - y = model(x) - params, counts, graphdef = model.split(nnx.Param, Count) - return y, params, counts, graphdef - - y, params, counts, graphdef = jax.vmap( - vmap_apply, - in_axes=(None, 0, None, None), - out_axes=(0, 0, None, None) - )(x, params, counts, graphdef) - - # update wrapped module - # uses `update` to integrate the new state - self.models.update(params, counts, graphdef) - return y - -x = jnp.ones((4,)) -ensemble = LinearEnsemble(4, 4, num_models=8, rngs=nnx.Rngs(0)) - -# forward pass -y = ensemble(x) - -print(f'{y.shape = }') -print(f'{ensemble.models.count = }') -print(f'state = {jax.tree.map(jnp.shape, ensemble.get_state())}') -``` - -#### Convenience lifted transforms - -+++ - -Like linen, for convenience we still provide simple lifted transforms for standard JAX transforms, usable as class transforms and decorators. We've endeavored to simplify the API for scan and vmap compared to the flax specifications. - -```{code-cell} -:outputId: c4800a49-efd1-4ee5-e703-6e63e18da4cb - -# class transform: -ScannedLinear = nnx.Scan.constructor(nnx.Linear, variable_axes={nnx.Param: 0}, length=4) - -scanned = ScannedLinear(2, 2, rngs=nnx.Rngs(0)) -scanned.get_state() -``` - -```{code-cell} -:outputId: 9efd6e71-d180-4674-ade0-2b02057a400b - -# method decorators: - -class ScannedLinear(nnx.Module): - - @partial(nnx.scan, variable_axes={nnx.Param: 0}, length=4) - def __init__(self, din, dout, *, rngs: nnx.Rngs): - self.model = nnx.Linear(din, dout, rngs=nnx.Rngs(rngs)) - - @partial(nnx.scan, variable_axes={nnx.Param: 0}, length=4) - def __call__(self, x): - return self.model(x) - -scanned = ScannedLinear(2, 2, rngs=nnx.Rngs(0)) -scanned.get_state() -``` - -#### Aside: Why aren't Modules Pytrees? - -A common questions is why aren't NNX Modules registered as Pytrees? (in the style of Equinox, Treex, PytreeClass, etc.) It _is_ trivial to define a pytree registration in terms of `split`/`merge`. - -The problem is that Pytrees impose value semantics (referencial transparency) while Modules assume reference semantics, and therefore it is dangerous in general to automatically treat Modules as Pytrees. - -As an example, lets take a look at what would happen if we allowed this very simple program to be valid: - -```{code-cell} -@jax.jit -def f(m1: nnx.Module, m2: nnx.Module): - return m1, m2 -``` - -Here we are just creating a jitted function `f` that takes in two Modules `(m1, m2)` and returns them as is. What could go wrong? - -There are two main problems with this: -* Shared references are not maintained, that is, if `m1.shared` is the same as `m2.shared` outside `f`, this will NOT be true both inside `f`, and at the output of `f`. -* Even if you accept this fact and added code to compensate for this, `f` would now behave differently depending on whether its being `jit`ted or not, this is an undesirable asymmetry and `jit` would no longer be a no-op. - -+++ - -### Standardized "Hooks" - -NNX introduces a standard getter/setter/creator interface for custom variables (similar to Haiku hooks). This is used internally to support SPMD metadata for managing sharding information, but is available for user-defined applications. - -```{code-cell} -:outputId: c4e6586a-bfe0-4f26-d05b-8c9e395971b2 - -class TransposedParam(nnx.Variable): - def create_value(self, value): - return value.T # called on variable creation to transform initial value - def get_value(self): - return self.value.T # called when value fetched via module getattr - def set_value(self, value): - return self.replace(value=value.T) # called when setting value from module setattr - - -class OddLinear(nnx.Module): - def __init__(self, din, dout, *, rngs): - self.kernel = TransposedParam(random.uniform(rngs.params(), (din, dout))) - self.bias = nnx.Param(jnp.zeros((dout,))) - - def __call__(self, x): - print(f'{self.kernel.shape = }') - return x @ self.kernel + self.bias - - -model = OddLinear(4, 8, rngs=nnx.Rngs(0)) -y = model(jnp.ones((2, 4))) - -print(f'outer kernel shape = {model.split()[0]["kernel"].shape}') -``` - -SPMD metadata is handled using `nnx.with_partitioning` helpers, but it's easy to add one's own metadata schema: - -```{code-cell} -:outputId: ef312738-0f56-4c0e-9aaf-3319d131f1a2 - -class MetadataParam(nnx.Param): - def __init__(self, *args, **kwargs): - for key in kwargs: - setattr(self, key, kwargs[key]) - super().__init__(*args) - - -class AnnotatedLinear(nnx.Module): - def __init__(self, din, dout, *, rngs): - self.kernel = TransposedParam(random.uniform(rngs.params(), (din, dout)), meta='foo', other_meta=0) - self.bias = TransposedParam(jnp.zeros((dout,)), meta='bar', other_meta=1) - - def __call__(self, x): - return x @ self.kernel + self.bias - - -model = AnnotatedLinear(4, 8, rngs=nnx.Rngs(0)) -y = model(jnp.ones((2, 4))) - -graphdef, state = model.split() - -print(f"{state.variables['kernel'].meta=}\n{state.variables['kernel'].other_meta=}") -print(f"{state.variables['bias'].meta=}\n{state.variables['bias'].other_meta=}") -``` - -## Shape Inference - -Shape inference is still possible in NNX using abstract evaluation when it's really needed, it just isn't automatic. - -```{code-cell} -:outputId: 942a3788-bcbf-426d-87e6-c5a041172c64 - -def batched_flatten(x): - return jnp.reshape(x, (x.shape[0], -1)) - -class Example(nnx.Module): - def __init__(self, *, - in_filters=3, - out_filters=4, - input_shape=None, # provide an example input size - rngs): - self.encoder = nnx.Conv(in_filters, out_filters, - kernel_size=(3, 3), - strides=(1, 1), - padding="SAME", - rngs=rngs) - # calculate the flattened shape post-conv using jax.eval_shape - encoded_shape = jax.eval_shape( - lambda x: batched_flatten(self.encoder(x)), - jax.ShapeDtypeStruct(input_shape, jnp.float32) - ).shape - # use this shape information to continue initializing - self.linear = nnx.Linear(encoded_shape[-1], 4, rngs=rngs) - - def __call__(self, x): - x = self.encoder(x) - x = batched_flatten(x) - return self.linear(x) - -model = Example(in_filters=3, - out_filters=4, - input_shape=(2, 6, 6, 3), - rngs=nnx.Rngs(0)) - -graphdef, state = model.split() -jax.tree.map(jnp.shape, state) -``` diff --git a/docs_nnx/index.rst b/docs_nnx/index.rst index 8ee8676d4d..b3dfe5c852 100644 --- a/docs_nnx/index.rst +++ b/docs_nnx/index.rst @@ -185,6 +185,7 @@ Learn more nnx_basics mnist_tutorial + why guides/index examples/index The Flax philosophy diff --git a/docs_nnx/why.rst b/docs_nnx/why.rst new file mode 100644 index 0000000000..cfc1d9a941 --- /dev/null +++ b/docs_nnx/why.rst @@ -0,0 +1,416 @@ +Why NNX? +======== + +Years ago we developed the Flax "Linen" API to support modeling research on JAX, with a focus on scaling scaling +and performance. We've learned a lot from our users over these years. We introduced some ideas that have proven to be good: + +* Organizing variables into `collections `_. +* Automatic and efficient `PRNG management `_. +* `Variable Metadata `_ + for SPMD annotations, optimizer metadata, etc. + +One choice we made was to use functional (``compact``) semantics for NN programming via the lazy initialization of parameters, +this made for concise implementation code and aligned our API with Haiku. However, this also meant that the semantics of +modules and variables in Flax were non-pythonic and often surprising. It also led to implementation complexity and obscured +the core ideas of transformations on neural nets. + +.. testsetup:: Linen, NNX + + import jax + from jax import random, numpy as jnp + from flax import nnx + import flax.linen as nn + +Introducing Flax NNX +-------------------- +Flax NNX is an attempt to keep the features that made Linen useful while introducing some new principles. +The central idea behind Flax NNX is to introduce reference semantics into JAX. These are its main features: + +- **Pythonic**: supports regular Python semantics for Modules, including for mutability and shared references. +- **Simple**: many of the complex APIs in Flax Linen are either simplified using Python idioms or removed entirely. +- **Better JAX integration**: both by making custom transforms adopt the same APIs as JAX transforms, and by making + it easier to use JAX transforms directly. + +Here's an example of a simple Flax NNX program that illustrates many of the points above: + +.. testcode:: NNX + + from flax import nnx + import optax + + + class Model(nnx.Module): + def __init__(self, din, dmid, dout, rngs: nnx.Rngs): + self.linear = nnx.Linear(din, dmid, rngs=rngs) + self.bn = nnx.BatchNorm(dmid, rngs=rngs) + self.dropout = nnx.Dropout(0.2, rngs=rngs) + self.linear_out = nnx.Linear(dmid, dout, rngs=rngs) + + def __call__(self, x): + x = nnx.relu(self.dropout(self.bn(self.linear(x)))) + return self.linear_out(x) + + model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # eager initialization + optimizer = nnx.Optimizer(model, optax.adam(1e-3)) # reference sharing + + @nnx.jit # automatic state management for JAX transforms + def train_step(model, optimizer, x, y): + def loss_fn(model): + y_pred = model(x) # call methods directly + return ((y_pred - y) ** 2).mean() + + loss, grads = nnx.value_and_grad(loss_fn)(model) + optimizer.update(grads) # in-place updates + + return loss + +Improvements +------------ +Through the rest of this document, we'll key examples of how Flax NNX improves on Flax Linen. + +Inspection +^^^^^^^^^^ +The first improvement is that Flax NNX modules are regular Python objects, so you can easily +construct and inspect them. Because Flax Linen Modules are lazy, some attributes are not available +upon construction and are only accesible at runtime. This makes it hard to inspect and debug. + +.. codediff:: + :title: Linen, NNX + :sync: + + class Block(nn.Module): + def setup(self): + self.linear = nn.Dense(10) + + block = Block() + + try: + block.linear # AttributeError: "Block" object has no attribute "linear". + except AttributeError as e: + pass + + + + + + ... + + --- + + class Block(nnx.Module): + def __init__(self, rngs): + self.linear = nnx.Linear(5, 10, rngs=rngs) + + block = Block(nnx.Rngs(0)) + + + block.linear + # Linear( + # kernel=Param( + # value=Array(shape=(5, 10), dtype=float32) + # ), + # bias=Param( + # value=Array(shape=(10,), dtype=float32) + # ), + # ... + +Notice that in Flax NNX there is no shape inference so both the input and output shapes must be provided +to the Linear module. This is a tradeoff that allows for more explicit and predictable behavior. + +Running Computation +^^^^^^^^^^^^^^^^^^^ +In Flax Linen, all top-level computation must be done through the ``init`` or ``apply`` methods and the +parameters or any other type of state is handled as a separate structure. This creates an asymmetry +between code that runs inside ``apply`` that can run methods and other Modules directly, and code +outside of ``apply`` that must use the ``apply`` method. In Flax NNX, there's no special context +as parameters are held as attributes and methods can be called directly. + +.. codediff:: + :title: Linen, NNX + :sync: + + Encoder = lambda: nn.Dense(10) + Decoder = lambda: nn.Dense(2) + + class AutoEncoder(nn.Module): + def setup(self): + self.encoder = Encoder() + self.decoder = Decoder() + + def __call__(self, x) -> jax.Array: + return self.decoder(self.encoder(x)) + + def encode(self, x) -> jax.Array: + return self.encoder(x) + + x = jnp.ones((1, 2)) + model = AutoEncoder() + params = model.init(random.key(0), x)['params'] + + y = model.apply({'params': params}, x) + z = model.apply({'params': params}, x, method='encode') + y = Decoder().apply({'params': params['decoder']}, z) + + --- + + Encoder = lambda rngs: nnx.Linear(2, 10, rngs=rngs) + Decoder = lambda rngs: nnx.Linear(10, 2, rngs=rngs) + + class AutoEncoder(nnx.Module): + def __init__(self, rngs): + self.encoder = Encoder(rngs) + self.decoder = Decoder(rngs) + + def __call__(self, x) -> jax.Array: + return self.decoder(self.encoder(x)) + + def encode(self, x) -> jax.Array: + return self.encoder(x) + + x = jnp.ones((1, 2)) + model = AutoEncoder(nnx.Rngs(0)) + + + y = model(x) + z = model.encode(x) + y = model.decoder(z) + +Note that in Linen, calling submodules directly is not possible as they are not initialized. +So you must construct a new instance and provide proper parameter structure. In NNX +you can call submodules directly without any issues. + +State Handling +^^^^^^^^^^^^^^ +One of the areas where Flax Linen is notoriously complex is in handling state. When you either use a +Dropout layer or a BatchNorm layer, or both, you suddenly have to handle the new state and use it to +configure the ``apply`` method. In Flax NNX, state is kept inside the Module and is mutable, so it can +just be called directly. + +.. codediff:: + :title: Linen, NNX + :sync: + + class Block(nn.Module): + train: bool + + def setup(self): + self.linear = nn.Dense(10) + self.bn = nn.BatchNorm(use_running_average=not self.train) + self.dropout = nn.Dropout(0.1, deterministic=not self.train) + + def __call__(self, x): + return nn.relu(self.dropout(self.bn(self.linear(x)))) + + x = jnp.ones((1, 5)) + model = Block(train=True) + vs = model.init(random.key(0), x) + params, batch_stats = vs['params'], vs['batch_stats'] + + y, updates = model.apply( + {'params': params, 'batch_stats': batch_stats}, + x, + rngs={'dropout': random.key(1)}, + mutable=['batch_stats'], + ) + batch_stats = updates['batch_stats'] + + --- + + class Block(nnx.Module): + + + def __init__(self, rngs): + self.linear = nnx.Linear(5, 10, rngs=rngs) + self.bn = nnx.BatchNorm(10, rngs=rngs) + self.dropout = nnx.Dropout(0.1, rngs=rngs) + + def __call__(self, x): + return nnx.relu(self.dropout(self.bn(self.linear(x)))) + + x = jnp.ones((1, 5)) + model = Block(nnx.Rngs(0)) + + + + y = model(x) + + + + + + ... + +The main benefit is that this usually means you don't have to change the training code when you add +a new stateful layers. Layers that handle state are also very easy to implement in Flax NNX, below +is a simplified version of a BatchNorm layer that updates the mean and variance every time it's called. + +.. testcode:: NNX + + class BatchNorm(nnx.Module): + def __init__(self, features: int, mu: float = 0.95): + # Variables + self.scale = nnx.Param(jax.numpy.ones((features,))) + self.bias = nnx.Param(jax.numpy.zeros((features,))) + self.mean = nnx.BatchStat(jax.numpy.zeros((features,))) + self.var = nnx.BatchStat(jax.numpy.ones((features,))) + self.mu = mu # static + + def __call__(self, x): + mean = jax.numpy.mean(x, axis=-1) + var = jax.numpy.var(x, axis=-1) + # ema updates + self.mean.value = self.mu * self.mean + (1 - self.mu) * mean + self.var.value = self.mu * self.var + (1 - self.mu) * var + # normalize and scale + x = (x - mean) / jax.numpy.sqrt(var + 1e-5) + return x * self.scale + self.bias + + +Surgery +^^^^^^^ +Model surgery historically has been a difficult problem in Flax Linen because of two reasons: +1. Due to lazy initialization, its not guaranteed you can replace a submodule with new one. +2. The parameter structure is separate from the module structure, so you manually have to keep + them in sync. + +In Flax NNX, you can replace submodules directly per Python semantics. Since the parameters are +part of the Module structre, they are never out of sync. Below is an example of how you can +implement a LoRA layer and replace a Linear layer of an existing model with it. + +.. codediff:: + :title: Linen, NNX + :sync: + + class LoraLinear(nn.Module): + linear: nn.Dense + rank: int + + @nn.compact + def __call__(self, x: jax.Array): + A = self.param(random.normal, (x.shape[-1], self.rank)) + B = self.param(random.normal, (self.rank, self.linear.features)) + + return self.linear(x) + x @ A @ B + + try: + model = Block(train=True) + model.linear = LoraLinear(model.linear, rank=5) # <-- ERROR + + lora_params = model.linear.init(random.key(1), x) + lora_params['linear'] = params['linear'] + params['linear'] = lora_params + + except AttributeError as e: + pass + + --- + + class LoraParam(nnx.Param): pass + + class LoraLinear(nnx.Module): + def __init__(self, linear, rank, rngs): + self.linear = linear + self.A = LoraParam(random.normal(rngs(), (linear.in_features, rank))) + self.B = LoraParam(random.normal(rngs(), (rank, linear.out_features))) + + def __call__(self, x: jax.Array): + return self.linear(x) + x @ self.A @ self.B + + rngs = nnx.Rngs(0) + model = Block(rngs) + model.linear = LoraLinear(model.linear, rank=5, rngs=rngs) + + + + + + + ... + +As should above, in Linen this doesn't really work in this case because the ``.linear`` submodule +is not available, however the rest of the code gives an idea how the ``params`` structure must be +manually updated. + +Performing arbitrary model surgery is not very easy in Flax Linen, currently the +`intercept_methods `_ +API is the only was to do generic patching of methods but it's not very ergonomic. In NNX, using ``iter_graph`` its very easy +to do generic model surgery, below is an example of replacing all Linear layers in a model with LoRA layers. + +.. testcode:: NNX + + rngs = nnx.Rngs(0) + model = Block(rngs) + + for path, module in nnx.iter_graph(model): + if isinstance(module, nnx.Module): + for name, value in vars(module).items(): + if isinstance(value, nnx.Linear): + setattr(module, name, LoraLinear(value, rank=5, rngs=rngs)) + +Transforms +^^^^^^^^^^ +Flax Linen transforms are very powerful in that they allow fine-grained control over the model's state, +however Linen transforms have the following drawbacks: +1. They expose additional APIs that are not part of JAX. +2. They work on functions with very specific signatures: + * A Module must be the first argument. + * They accepts other Modules as arguments but not as return values. +3. They can only be used inside ``apply``. + +`Flax NNX transforms `_ on the other hand +are intented to be equivalent to JAX transforms with the exception that they can be used on Modules. This +means they have the same API as JAX transforms, can accepts Modules on any argument and Modules can be +returned from them, and they can be used anywhere including the training loop. + +Here is an example of using ``vmap`` with Flax NNX to both create a stack of weights by transforming the +``create_weights`` function which returns some ``Weights``, and to apply the stack of weights to a batch +of inputs individually by transforming the ``vector_dot`` function which takes a ``Weights`` as the first +argument and a batch of inputs as the second argument. + +.. testcode:: NNX + + class Weights(nnx.Module): + def __init__(self, kernel: jax.Array, bias: jax.Array): + self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias) + + def create_weights(seed: jax.Array): + return Weights( + kernel=random.uniform(random.key(seed), (2, 3)), + bias=jnp.zeros((3,)), + ) + + def vector_dot(weights: Weights, x: jax.Array): + assert weights.kernel.ndim == 2, 'Batch dimensions not allowed' + assert x.ndim == 1, 'Batch dimensions not allowed' + return x @ weights.kernel + weights.bias + + seeds = jnp.arange(10) + weights = nnx.vmap(create_weights, in_axes=0, out_axes=0)(seeds) + + x = jax.random.normal(random.key(1), (10, 2)) + y = nnx.vmap(vector_dot, in_axes=(0, 0), out_axes=1)(weights, x) + +Contrary to Linen transforms, the arguments ``in_axes`` and other APIs do affect how the Module state is transformed. + +Flax NNX transforms can also be used as method decorators, as Module methods are simply +functions that take a Module as the first argument. This means that the previous example can be +rewritten as follows: + +.. testcode:: NNX + + class WeightStack(nnx.Module): + @nnx.vmap(in_axes=(0, 0), out_axes=0) + def __init__(self, seed: jax.Array): + self.kernel = nnx.Param(random.uniform(random.key(seed), (2, 3))) + self.bias = nnx.Param(jnp.zeros((3,))) + + @nnx.vmap(in_axes=(0, 0), out_axes=1) + def __call__(self, x: jax.Array): + assert self.kernel.ndim == 2, 'Batch dimensions not allowed' + assert x.ndim == 1, 'Batch dimensions not allowed' + return x @ self.kernel + self.bias + + weights = WeightStack(jnp.arange(10)) + + x = jax.random.normal(random.key(1), (10, 2)) + y = weights(x) \ No newline at end of file