diff --git a/docs_nnx/guides/transforms.ipynb b/docs_nnx/guides/transforms.ipynb
index 60f6d3568a..edf8ef8e0f 100644
--- a/docs_nnx/guides/transforms.ipynb
+++ b/docs_nnx/guides/transforms.ipynb
@@ -90,7 +90,7 @@
     "  def __init__(self, kernel: jax.Array, bias: jax.Array):\n",
     "    self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)\n",
     "\n",
-    "self = Weights(\n",
+    "weights = Weights(\n",
     "  kernel=random.uniform(random.key(0), (10, 2, 3)),\n",
     "  bias=jnp.zeros((10, 3)),\n",
     ")\n",
@@ -101,10 +101,10 @@
     "  assert x.ndim == 1, 'Batch dimensions not allowed'\n",
     "  return x @ weights.kernel + weights.bias\n",
     "\n",
-    "y = nnx.vmap(vector_dot, in_axes=0, out_axes=1)(self, x)\n",
+    "y = nnx.vmap(vector_dot, in_axes=0, out_axes=1)(weights, x)\n",
     "\n",
     "print(f'{y.shape = }')\n",
-    "nnx.display(self)"
+    "nnx.display(weights)"
    ]
   },
   {
@@ -158,8 +158,8 @@
     "  )\n",
     "\n",
     "seeds = jnp.arange(10)\n",
-    "self = nnx.vmap(create_weights)(seeds)\n",
-    "nnx.display(self)"
+    "weights = nnx.vmap(create_weights)(seeds)\n",
+    "nnx.display(weights)"
    ]
   },
   {
@@ -276,7 +276,7 @@
     "    self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)\n",
     "    self.count = Count(count)\n",
     "\n",
-    "self = Weights(\n",
+    "weights = Weights(\n",
     "  kernel=random.uniform(random.key(0), (10, 2, 3)),\n",
     "  bias=jnp.zeros((10, 3)),\n",
     "  count=jnp.arange(10),\n",
@@ -290,9 +290,9 @@
     "  return x @ weights.kernel + weights.bias\n",
     "\n",
     "\n",
-    "y = nnx.vmap(stateful_vector_dot, in_axes=0, out_axes=1)(self, x)\n",
+    "y = nnx.vmap(stateful_vector_dot, in_axes=0, out_axes=1)(weights, x)\n",
     "\n",
-    "self.count"
+    "weights.count"
    ]
   },
   {
@@ -353,7 +353,7 @@
     "    self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)\n",
     "    self.count = Count(count)\n",
     "\n",
-    "self = Weights(\n",
+    "weights = Weights(\n",
     "  kernel=random.uniform(random.key(0), (10, 2, 3)),\n",
     "  bias=jnp.zeros((10, 3)),\n",
     "  count=jnp.arange(10),\n",
@@ -370,9 +370,9 @@
     "  weights.new_param = weights.kernel # share reference\n",
     "  return y\n",
     "\n",
-    "y = nnx.vmap(crazy_vector_dot, in_axes=0, out_axes=1)(self, x)\n",
+    "y = nnx.vmap(crazy_vector_dot, in_axes=0, out_axes=1)(weights, x)\n",
     "\n",
-    "nnx.display(self)"
+    "nnx.display(weights)"
    ]
   },
   {
@@ -433,7 +433,7 @@
     "    self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)\n",
     "    self.count = Count(count)\n",
     "\n",
-    "self = Weights(\n",
+    "weights = Weights(\n",
     "  kernel=random.uniform(random.key(0), (10, 2, 3)),\n",
     "  bias=jnp.zeros((10, 3)),\n",
     "  count=jnp.array(0),\n",
@@ -448,9 +448,9 @@
     "  return x @ weights.kernel + weights.bias\n",
     "\n",
     "state_axes = nnx.StateAxes({nnx.Param: 0, Count: None}) # broadcast Count\n",
-    "y = nnx.vmap(stateful_vector_dot, in_axes=(state_axes, 0), out_axes=1)(self, x)\n",
+    "y = nnx.vmap(stateful_vector_dot, in_axes=(state_axes, 0), out_axes=1)(weights, x)\n",
     "\n",
-    "self.count"
+    "weights.count"
    ]
   },
   {
@@ -517,7 +517,7 @@
     "    self.count = Count(count)\n",
     "    self.rngs = nnx.Rngs(noise=seed)\n",
     "\n",
-    "self = Weights(\n",
+    "weights = Weights(\n",
     "  kernel=random.uniform(random.key(0), (2, 3)),\n",
     "  bias=jnp.zeros((3,)),\n",
     "  count=jnp.array(0),\n",
@@ -533,11 +533,11 @@
     "  return y + random.normal(weights.rngs.noise(), y.shape)\n",
     "\n",
     "state_axes = nnx.StateAxes({nnx.RngState: 0, (nnx.Param, Count): None})\n",
-    "y1 = nnx.vmap(noisy_vector_dot, in_axes=(state_axes, 0))(self, x)\n",
-    "y2 = nnx.vmap(noisy_vector_dot, in_axes=(state_axes, 0))(self, x)\n",
+    "y1 = nnx.vmap(noisy_vector_dot, in_axes=(state_axes, 0))(weights, x)\n",
+    "y2 = nnx.vmap(noisy_vector_dot, in_axes=(state_axes, 0))(weights, x)\n",
     "\n",
     "print(jnp.allclose(y1, y2))\n",
-    "nnx.display(self)"
+    "nnx.display(weights)"
    ]
   },
   {
@@ -589,7 +589,7 @@
    }
   ],
   "source": [
-    "self = Weights(\n",
+    "weights = Weights(\n",
     "  kernel=random.uniform(random.key(0), (2, 3)),\n",
     "  bias=jnp.zeros((3,)),\n",
     "  count=jnp.array(0),\n",
@@ -608,11 +608,11 @@
     "  y = x @ weights.kernel + weights.bias\n",
     "  return y + random.normal(weights.rngs.noise(), y.shape)\n",
     "\n",
-    "y1 = noisy_vector_dot(self, x)\n",
-    "y2 = noisy_vector_dot(self, x)\n",
+    "y1 = noisy_vector_dot(weights, x)\n",
+    "y2 = noisy_vector_dot(weights, x)\n",
     "\n",
     "print(jnp.allclose(y1, y2))\n",
-    "nnx.display(self)"
+    "nnx.display(weights)"
    ]
   },
   {
diff --git a/docs_nnx/guides/transforms.md b/docs_nnx/guides/transforms.md
index 89946be97f..1b185e5006 100644
--- a/docs_nnx/guides/transforms.md
+++ b/docs_nnx/guides/transforms.md
@@ -44,7 +44,7 @@ class Weights(nnx.Module):
   def __init__(self, kernel: jax.Array, bias: jax.Array):
     self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)
 
-self = Weights(
+weights = Weights(
   kernel=random.uniform(random.key(0), (10, 2, 3)),
   bias=jnp.zeros((10, 3)),
 )
@@ -55,10 +55,10 @@ def vector_dot(weights: Weights, x: jax.Array):
   assert x.ndim == 1, 'Batch dimensions not allowed'
   return x @ weights.kernel + weights.bias
 
-y = nnx.vmap(vector_dot, in_axes=0, out_axes=1)(self, x)
+y = nnx.vmap(vector_dot, in_axes=0, out_axes=1)(weights, x)
 
 print(f'{y.shape = }')
-nnx.display(self)
+nnx.display(weights)
 ```
 
 Notice that `in_axes` interacts naturally with the `Weights` Module, treating it as if it were a Pytree of arrays. Prefix patterns are also allowed; `in_axes=(0, 0)` would've also worked in this case.
@@ -75,8 +75,8 @@ def create_weights(seed: jax.Array):
   )
 
 seeds = jnp.arange(10)
-self = nnx.vmap(create_weights)(seeds)
-nnx.display(self)
+weights = nnx.vmap(create_weights)(seeds)
+nnx.display(weights)
 ```
 
 ## Transforming Methods
@@ -120,7 +120,7 @@ class Weights(nnx.Module):
     self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)
     self.count = Count(count)
 
-self = Weights(
+weights = Weights(
   kernel=random.uniform(random.key(0), (10, 2, 3)),
   bias=jnp.zeros((10, 3)),
   count=jnp.arange(10),
@@ -134,9 +134,9 @@ def stateful_vector_dot(weights: Weights, x: jax.Array):
   return x @ weights.kernel + weights.bias
 
 
-y = nnx.vmap(stateful_vector_dot, in_axes=0, out_axes=1)(self, x)
+y = nnx.vmap(stateful_vector_dot, in_axes=0, out_axes=1)(weights, x)
 
-self.count
+weights.count
 ```
 
 After running `stateful_vector_dot` once, we verify that the `count` attribute was correctly updated. Because `Weights` is vectorized, `count` was initialized as an `arange(10)`, and all of its elements were incremented by 1 inside the transformation. The most important part is that updates were propagated to the original `Weights` object outside the transformation. Nice!
@@ -156,7 +156,7 @@ class Weights(nnx.Module):
     self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)
     self.count = Count(count)
 
-self = Weights(
+weights = Weights(
   kernel=random.uniform(random.key(0), (10, 2, 3)),
   bias=jnp.zeros((10, 3)),
   count=jnp.arange(10),
@@ -173,9 +173,9 @@ def crazy_vector_dot(weights: Weights, x: jax.Array):
   weights.new_param = weights.kernel # share reference
   return y
 
-y = nnx.vmap(crazy_vector_dot, in_axes=0, out_axes=1)(self, x)
+y = nnx.vmap(crazy_vector_dot, in_axes=0, out_axes=1)(weights, x)
 
-nnx.display(self)
+nnx.display(weights)
 ```
 
 > With great power comes great responsibility.
@@ -207,7 +207,7 @@ class Weights(nnx.Module):
     self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)
     self.count = Count(count)
 
-self = Weights(
+weights = Weights(
   kernel=random.uniform(random.key(0), (10, 2, 3)),
   bias=jnp.zeros((10, 3)),
   count=jnp.array(0),
@@ -222,9 +222,9 @@ def stateful_vector_dot(weights: Weights, x: jax.Array):
   return x @ weights.kernel + weights.bias
 
 state_axes = nnx.StateAxes({nnx.Param: 0, Count: None}) # broadcast Count
-y = nnx.vmap(stateful_vector_dot, in_axes=(state_axes, 0), out_axes=1)(self, x)
+y = nnx.vmap(stateful_vector_dot, in_axes=(state_axes, 0), out_axes=1)(weights, x)
 
-self.count
+weights.count
 ```
 
 Here, `count` is now a scalar since it's not being vectorized. Also, note that `StateAxes` can only be used directly on Flax NNX objects; it cannot be used as a prefix for a pytree of objects.
@@ -243,7 +243,7 @@ class Weights(nnx.Module):
     self.count = Count(count)
     self.rngs = nnx.Rngs(noise=seed)
 
-self = Weights(
+weights = Weights(
   kernel=random.uniform(random.key(0), (2, 3)),
   bias=jnp.zeros((3,)),
   count=jnp.array(0),
@@ -259,11 +259,11 @@ def noisy_vector_dot(weights: Weights, x: jax.Array):
   return y + random.normal(weights.rngs.noise(), y.shape)
 
 state_axes = nnx.StateAxes({nnx.RngState: 0, (nnx.Param, Count): None})
-y1 = nnx.vmap(noisy_vector_dot, in_axes=(state_axes, 0))(self, x)
-y2 = nnx.vmap(noisy_vector_dot, in_axes=(state_axes, 0))(self, x)
+y1 = nnx.vmap(noisy_vector_dot, in_axes=(state_axes, 0))(weights, x)
+y2 = nnx.vmap(noisy_vector_dot, in_axes=(state_axes, 0))(weights, x)
 
 print(jnp.allclose(y1, y2))
-nnx.display(self)
+nnx.display(weights)
 ```
 
 Because `Rngs`'s state is updated in place and automatically propagated by `nnx.vmap`, we will get a different result every time that `noisy_vector_dot` is called.
@@ -271,7 +271,7 @@ Because `Rngs`'s state is updated in place and automatically propagated by `nnx.
 In the example above we manually split the random state during construction. This is fine, as it makes the intention clear, but it also doesn't let us use `Rngs` outside of `vmap` since its state is always split. To solve this, we pass an unsplit seed and use the `nnx.split_rngs` decorator before `vmap` to split the `RngState` right before each call to the function, and then "lower" it back so it's usable.
 
 ```{code-cell} ipython3
-self = Weights(
+weights = Weights(
   kernel=random.uniform(random.key(0), (2, 3)),
   bias=jnp.zeros((3,)),
   count=jnp.array(0),
@@ -290,11 +290,11 @@ def noisy_vector_dot(weights: Weights, x: jax.Array):
   y = x @ weights.kernel + weights.bias
   return y + random.normal(weights.rngs.noise(), y.shape)
 
-y1 = noisy_vector_dot(self, x)
-y2 = noisy_vector_dot(self, x)
+y1 = noisy_vector_dot(weights, x)
+y2 = noisy_vector_dot(weights, x)
 
 print(jnp.allclose(y1, y2))
-nnx.display(self)
+nnx.display(weights)
 ```
 
 ## Consistent aliasing
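Reviewer note: to sanity-check the rename end to end, here is a minimal, self-contained sketch (not part of the diff) that combines the three patterns the guide exercises with the new `weights` name: `nnx.vmap` over a Module, `nnx.StateAxes` to broadcast non-vectorized state, and the `nnx.split_rngs` decorator applied before `vmap`. It uses only APIs that appear in the guide; the shapes, the integer `seed=0`, and the printed output shape are illustrative assumptions, not taken from the notebook.

```python
# Minimal sketch under assumptions noted above; shapes and seed are illustrative.
import jax
import jax.numpy as jnp
from jax import random
from flax import nnx

class Weights(nnx.Module):
  def __init__(self, kernel: jax.Array, bias: jax.Array, seed: int):
    self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)
    # Unsplit seed; nnx.split_rngs forks it per vmap lane at call time.
    self.rngs = nnx.Rngs(noise=seed)

weights = Weights(  # renamed from `self` per this change
  kernel=random.uniform(random.key(0), (2, 3)),
  bias=jnp.zeros((3,)),
  seed=0,
)
x = jnp.ones((10, 2))

# Vectorize the RNG state along axis 0; broadcast the unbatched parameters.
state_axes = nnx.StateAxes({nnx.RngState: 0, nnx.Param: None})

@nnx.split_rngs(splits=10)
@nnx.vmap(in_axes=(state_axes, 0))
def noisy_vector_dot(weights: Weights, x: jax.Array):
  y = x @ weights.kernel + weights.bias
  return y + random.normal(weights.rngs.noise(), y.shape)

print(noisy_vector_dot(weights, x).shape)  # expected: (10, 3)
```

Calling `noisy_vector_dot(weights, x)` twice yields different noise, since `nnx.split_rngs` re-splits the `Rngs` state on each call and then lowers it back, matching the behavior the guide describes.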