Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[nnx] fix transforms guide #4223

Merged
merged 1 commit into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 22 additions & 22 deletions docs_nnx/guides/transforms.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@
" def __init__(self, kernel: jax.Array, bias: jax.Array):\n",
" self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)\n",
"\n",
"self = Weights(\n",
"weights = Weights(\n",
" kernel=random.uniform(random.key(0), (10, 2, 3)),\n",
" bias=jnp.zeros((10, 3)),\n",
")\n",
Expand All @@ -101,10 +101,10 @@
" assert x.ndim == 1, 'Batch dimensions not allowed'\n",
" return x @ weights.kernel + weights.bias\n",
"\n",
"y = nnx.vmap(vector_dot, in_axes=0, out_axes=1)(self, x)\n",
"y = nnx.vmap(vector_dot, in_axes=0, out_axes=1)(weights, x)\n",
"\n",
"print(f'{y.shape = }')\n",
"nnx.display(self)"
"nnx.display(weights)"
]
},
{
Expand Down Expand Up @@ -158,8 +158,8 @@
" )\n",
"\n",
"seeds = jnp.arange(10)\n",
"self = nnx.vmap(create_weights)(seeds)\n",
"nnx.display(self)"
"weights = nnx.vmap(create_weights)(seeds)\n",
"nnx.display(weights)"
]
},
{
Expand Down Expand Up @@ -276,7 +276,7 @@
" self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)\n",
" self.count = Count(count)\n",
"\n",
"self = Weights(\n",
"weights = Weights(\n",
" kernel=random.uniform(random.key(0), (10, 2, 3)),\n",
" bias=jnp.zeros((10, 3)),\n",
" count=jnp.arange(10),\n",
Expand All @@ -290,9 +290,9 @@
" return x @ weights.kernel + weights.bias\n",
"\n",
"\n",
"y = nnx.vmap(stateful_vector_dot, in_axes=0, out_axes=1)(self, x)\n",
"y = nnx.vmap(stateful_vector_dot, in_axes=0, out_axes=1)(weights, x)\n",
"\n",
"self.count"
"weights.count"
]
},
{
Expand Down Expand Up @@ -353,7 +353,7 @@
" self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)\n",
" self.count = Count(count)\n",
"\n",
"self = Weights(\n",
"weights = Weights(\n",
" kernel=random.uniform(random.key(0), (10, 2, 3)),\n",
" bias=jnp.zeros((10, 3)),\n",
" count=jnp.arange(10),\n",
Expand All @@ -370,9 +370,9 @@
" weights.new_param = weights.kernel # share reference\n",
" return y\n",
"\n",
"y = nnx.vmap(crazy_vector_dot, in_axes=0, out_axes=1)(self, x)\n",
"y = nnx.vmap(crazy_vector_dot, in_axes=0, out_axes=1)(weights, x)\n",
"\n",
"nnx.display(self)"
"nnx.display(weights)"
]
},
{
Expand Down Expand Up @@ -433,7 +433,7 @@
" self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)\n",
" self.count = Count(count)\n",
"\n",
"self = Weights(\n",
"weights = Weights(\n",
" kernel=random.uniform(random.key(0), (10, 2, 3)),\n",
" bias=jnp.zeros((10, 3)),\n",
" count=jnp.array(0),\n",
Expand All @@ -448,9 +448,9 @@
" return x @ weights.kernel + weights.bias\n",
"\n",
"state_axes = nnx.StateAxes({nnx.Param: 0, Count: None}) # broadcast Count\n",
"y = nnx.vmap(stateful_vector_dot, in_axes=(state_axes, 0), out_axes=1)(self, x)\n",
"y = nnx.vmap(stateful_vector_dot, in_axes=(state_axes, 0), out_axes=1)(weights, x)\n",
"\n",
"self.count"
"weights.count"
]
},
{
Expand Down Expand Up @@ -517,7 +517,7 @@
" self.count = Count(count)\n",
" self.rngs = nnx.Rngs(noise=seed)\n",
"\n",
"self = Weights(\n",
"weights = Weights(\n",
" kernel=random.uniform(random.key(0), (2, 3)),\n",
" bias=jnp.zeros((3,)),\n",
" count=jnp.array(0),\n",
Expand All @@ -533,11 +533,11 @@
" return y + random.normal(weights.rngs.noise(), y.shape)\n",
"\n",
"state_axes = nnx.StateAxes({nnx.RngState: 0, (nnx.Param, Count): None})\n",
"y1 = nnx.vmap(noisy_vector_dot, in_axes=(state_axes, 0))(self, x)\n",
"y2 = nnx.vmap(noisy_vector_dot, in_axes=(state_axes, 0))(self, x)\n",
"y1 = nnx.vmap(noisy_vector_dot, in_axes=(state_axes, 0))(weights, x)\n",
"y2 = nnx.vmap(noisy_vector_dot, in_axes=(state_axes, 0))(weights, x)\n",
"\n",
"print(jnp.allclose(y1, y2))\n",
"nnx.display(self)"
"nnx.display(weights)"
]
},
{
Expand Down Expand Up @@ -589,7 +589,7 @@
}
],
"source": [
"self = Weights(\n",
"weights = Weights(\n",
" kernel=random.uniform(random.key(0), (2, 3)),\n",
" bias=jnp.zeros((3,)),\n",
" count=jnp.array(0),\n",
Expand All @@ -608,11 +608,11 @@
" y = x @ weights.kernel + weights.bias\n",
" return y + random.normal(weights.rngs.noise(), y.shape)\n",
"\n",
"y1 = noisy_vector_dot(self, x)\n",
"y2 = noisy_vector_dot(self, x)\n",
"y1 = noisy_vector_dot(weights, x)\n",
"y2 = noisy_vector_dot(weights, x)\n",
"\n",
"print(jnp.allclose(y1, y2))\n",
"nnx.display(self)"
"nnx.display(weights)"
]
},
{
Expand Down
44 changes: 22 additions & 22 deletions docs_nnx/guides/transforms.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class Weights(nnx.Module):
def __init__(self, kernel: jax.Array, bias: jax.Array):
self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)

self = Weights(
weights = Weights(
kernel=random.uniform(random.key(0), (10, 2, 3)),
bias=jnp.zeros((10, 3)),
)
Expand All @@ -55,10 +55,10 @@ def vector_dot(weights: Weights, x: jax.Array):
assert x.ndim == 1, 'Batch dimensions not allowed'
return x @ weights.kernel + weights.bias

y = nnx.vmap(vector_dot, in_axes=0, out_axes=1)(self, x)
y = nnx.vmap(vector_dot, in_axes=0, out_axes=1)(weights, x)

print(f'{y.shape = }')
nnx.display(self)
nnx.display(weights)
```

Notice that `in_axes` interacts naturally with the `Weights` Module, treating it as if it where a Pytree of arrays. Prefix patterns are also allowed, `in_axes=(0, 0)` would've also worked in this case.
Expand All @@ -75,8 +75,8 @@ def create_weights(seed: jax.Array):
)

seeds = jnp.arange(10)
self = nnx.vmap(create_weights)(seeds)
nnx.display(self)
weights = nnx.vmap(create_weights)(seeds)
nnx.display(weights)
```

## Transforming Methods
Expand Down Expand Up @@ -120,7 +120,7 @@ class Weights(nnx.Module):
self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)
self.count = Count(count)

self = Weights(
weights = Weights(
kernel=random.uniform(random.key(0), (10, 2, 3)),
bias=jnp.zeros((10, 3)),
count=jnp.arange(10),
Expand All @@ -134,9 +134,9 @@ def stateful_vector_dot(weights: Weights, x: jax.Array):
return x @ weights.kernel + weights.bias


y = nnx.vmap(stateful_vector_dot, in_axes=0, out_axes=1)(self, x)
y = nnx.vmap(stateful_vector_dot, in_axes=0, out_axes=1)(weights, x)

self.count
weights.count
```

After running `stateful_vector_dot` once we verify that the `count` attribute was correctly updated. Because Weights is vectorized, `count` was initialized as an `arange(10)`, and all of its elements were incremented by 1 inside the transformation. The most important part is that updates were propagated to the original `Weights` object outside the transformation. Nice!
Expand All @@ -156,7 +156,7 @@ class Weights(nnx.Module):
self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)
self.count = Count(count)

self = Weights(
weights = Weights(
kernel=random.uniform(random.key(0), (10, 2, 3)),
bias=jnp.zeros((10, 3)),
count=jnp.arange(10),
Expand All @@ -173,9 +173,9 @@ def crazy_vector_dot(weights: Weights, x: jax.Array):
weights.new_param = weights.kernel # share reference
return y

y = nnx.vmap(crazy_vector_dot, in_axes=0, out_axes=1)(self, x)
y = nnx.vmap(crazy_vector_dot, in_axes=0, out_axes=1)(weights, x)

nnx.display(self)
nnx.display(weights)
```

> With great power comes great responsibility.
Expand Down Expand Up @@ -207,7 +207,7 @@ class Weights(nnx.Module):
self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)
self.count = Count(count)

self = Weights(
weights = Weights(
kernel=random.uniform(random.key(0), (10, 2, 3)),
bias=jnp.zeros((10, 3)),
count=jnp.array(0),
Expand All @@ -222,9 +222,9 @@ def stateful_vector_dot(weights: Weights, x: jax.Array):
return x @ weights.kernel + weights.bias

state_axes = nnx.StateAxes({nnx.Param: 0, Count: None}) # broadcast Count
y = nnx.vmap(stateful_vector_dot, in_axes=(state_axes, 0), out_axes=1)(self, x)
y = nnx.vmap(stateful_vector_dot, in_axes=(state_axes, 0), out_axes=1)(weights, x)

self.count
weights.count
```

Here count is now a scalar since its not being vectorized. Also, note that `StateAxes` can only be used directly on Flax NNX objects, it cannot be used as a prefix for a pytree of objects.
Expand All @@ -243,7 +243,7 @@ class Weights(nnx.Module):
self.count = Count(count)
self.rngs = nnx.Rngs(noise=seed)

self = Weights(
weights = Weights(
kernel=random.uniform(random.key(0), (2, 3)),
bias=jnp.zeros((3,)),
count=jnp.array(0),
Expand All @@ -259,19 +259,19 @@ def noisy_vector_dot(weights: Weights, x: jax.Array):
return y + random.normal(weights.rngs.noise(), y.shape)

state_axes = nnx.StateAxes({nnx.RngState: 0, (nnx.Param, Count): None})
y1 = nnx.vmap(noisy_vector_dot, in_axes=(state_axes, 0))(self, x)
y2 = nnx.vmap(noisy_vector_dot, in_axes=(state_axes, 0))(self, x)
y1 = nnx.vmap(noisy_vector_dot, in_axes=(state_axes, 0))(weights, x)
y2 = nnx.vmap(noisy_vector_dot, in_axes=(state_axes, 0))(weights, x)

print(jnp.allclose(y1, y2))
nnx.display(self)
nnx.display(weights)
```

Because `Rngs`'s state is updated in place and automatically propagated by `nnx.vmap`, we will get a different result every time that `noisy_vector_dot` is called.

In the example above we manually split the random state during construction, this is fine as it makes the intention clear but it also doesn't let us use `Rngs` outside of `vmap` since its state is always split. To solve this we pass an unplit seed and use the `nnx.split_rngs` decorator before `vmap` to split the `RngState` right before each call to the function and then "lower" it back so its usable.

```{code-cell} ipython3
self = Weights(
weights = Weights(
kernel=random.uniform(random.key(0), (2, 3)),
bias=jnp.zeros((3,)),
count=jnp.array(0),
Expand All @@ -290,11 +290,11 @@ def noisy_vector_dot(weights: Weights, x: jax.Array):
y = x @ weights.kernel + weights.bias
return y + random.normal(weights.rngs.noise(), y.shape)

y1 = noisy_vector_dot(self, x)
y2 = noisy_vector_dot(self, x)
y1 = noisy_vector_dot(weights, x)
y2 = noisy_vector_dot(weights, x)

print(jnp.allclose(y1, y2))
nnx.display(self)
nnx.display(weights)
```

## Consistent aliasing
Expand Down
Loading