[nnx] add transforms guide #4197
f0d5f35 to f46c9ea
Thanks for making this guide! Super helpful and cool. Just a few nits on wording.
docs_nnx/guides/transforms.md
Outdated
+++

### Graph updates propagation
JAX models inputs to transformations as trees, Flax NNX models inputs as graphs to allow for sharing references. However, to express most of Python's object model Flax NNX's state propagation machinery can track arbitrary updates to the objects as long as they're local (updates to globals inside transforms are not supported). This means that you can modify graph structure as needed, including updating existing attributes, adding/deleting attributes, swapping attributes, sharing (new) references between objects, sharing Variables between objects, etc. The sky is the limit!
> JAX models inputs to transformations as trees, Flax NNX models inputs as graphs to allow for sharing references.

A bit hard to read, maybe:

> JAX transformations see inputs as trees of arrays, and Flax NNX sees inputs as graphs of Python references.

> However, to express most of Python's object model Flax NNX's state propagation machinery can track arbitrary updates to the objects as long as they're local (updates to globals inside transforms are not supported).

This line is also a bit verbose? Maybe just:

> Flax NNX's state propagation machinery can track arbitrary updates to the objects as long as they're local to the input graph (updates to globals inside transforms are not supported).
)
x = jax.random.normal(random.key(1), (10, 2))

def crazy_vector_dot(weights: Weights, x: jax.Array):
Not here, but I was hoping to see an example of transforming and using an `nnx.Module` method, to showcase that it works and can be a natural pattern for users to adopt, since most transforms happen not at the top level but in between two layer definitions.
It's a good point. I'll add a variation of the first example using `vmap` over `__call__` so users know that it's possible.
docs_nnx/guides/transforms.md
Outdated
> With great power comes great responsibility.
> <br> \- Uncle Ben

While this feature is very powerful, it must be used with care as it can clash with JAX's underlying assumptions for certain transformations. For example, `jit` expects the structure of the inputs to be stable in order to cache the compiled function, changing the graph structure inside a `nnx.jit`-ed function cause continuous recompilations and performance degradation, `scan` on the other hand only allows a fixed `carry` structure, so adding/removing substates declared as carry will cause an error.
nit:

> For example, `jit` expects the structure of the inputs to be stable in order to cache the compiled function, changing the graph structure inside a `nnx.jit`-ed function cause continuous recompilations and performance degradation, `scan` on the other hand only allows a fixed `carry` structure, so adding/removing substates declared as carry will cause an error.

For example, `jit` expects the structure of the inputs to be stable in order to cache the compiled function, so changing the graph structure inside an `nnx.jit`-ed function causes continuous recompilations and performance degradation. On the other hand, `scan` only allows a fixed `carry` structure, so adding/removing substates declared as carry will cause an error.
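
The caching behaviour this sentence relies on can be observed in plain JAX, without NNX. The sketch below counts how often a `jax.jit`-ed function is traced; `trace_log` and `total` are hypothetical names. It shows only the underlying JAX behaviour and says nothing about how `nnx.jit` is implemented.

```python
import jax
import jax.numpy as jnp

trace_log = []  # grows only while JAX traces the function


@jax.jit
def total(tree):
  # Python side effect: runs during tracing, not on cached calls.
  trace_log.append('traced')
  return sum(jnp.sum(leaf) for leaf in jax.tree_util.tree_leaves(tree))


params = {'w': jnp.ones(3)}
total(params)
total(params)  # same tree structure and shapes: cache hit, no retrace
traces_same_structure = len(trace_log)

# Adding a key changes the tree structure, which forces a new trace.
total({'w': jnp.ones(3), 'extra': jnp.ones(2)})
traces_after_change = len(trace_log)
```

A function that alters the structure of its inputs on every call would therefore hit the slow tracing path each time.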
)
x = jax.random.normal(random.key(1), (10, 2))
It's probably better to only call vmap once when only one call is needed, to avoid confusion. Same for the example below.
state_axes = nnx.StateAxes({nnx.Param: 0, Count: None}) # broadcast Count

@nnx.vmap(in_axes=(state_axes, 0), out_axes=1)
def stateful_vector_dot(weights: Weights, x: jax.Array):
  assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
  assert x.ndim == 1, 'Batch dimensions not allowed'
  weights.count += 1
  return x @ weights.kernel + weights.bias

y = stateful_vector_dot(weights, x)
docs_nnx/guides/transforms.md
Outdated
+++

### Random State
In Flax NNX random state is just regular state. This means that its stored inside Modules that need it and its treated as any other type of state. This is a simplification over Flax Linen where random state was handled by a separate mechanism. In practice Modules usually keep that need random state simply need a references to a `Rngs` object that is passed to them during initialization, and
> In practice Modules usually keep that need random state simply need a references to a `Rngs` object that is passed to them during initialization, and use it to generate a unique key for each random operation.

What about:

> In practice Modules simply need to keep a reference to a `Rngs` object that is passed to them during initialization, and use it to generate a unique key for each random operation.
f46c9ea to 59acf38
59acf38 to fb1a9cc
Thanks @IvyZX for the detailed feedback. I've integrated all the suggestions.
What does this PR do?
Adds the `Transforms` guide.
Preview