Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[nnx] fix transforms guide #4421

Merged
merged 1 commit into from
Dec 10, 2024
Merged

[nnx] fix transforms guide #4421

merged 1 commit into from
Dec 10, 2024

Conversation

cgarciae
Copy link
Collaborator

@cgarciae cgarciae commented Dec 6, 2024

What does this PR do?

Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@@ -345,7 +345,7 @@ To solve this issue pass all Module as arguments to the functions being transfor

### Consistent aliasing

The main issue with allowing for reference semantics in transforms is that references can be shared across inputs and outputs. This can be problematic if it is not taken care of because it would lead to ill-defined or inconsistent behavior. In the example below you have a single `Weights` `nnx.Module` - `m` ` whose reference appears in multiple places in `arg1` and `arg2`. The problem here is that you also specify that you want to vectorize `arg1` in axis `0` and `arg2` in axis `1`. This would be fine in JAX because of referential transparency of pytrees. But this would be problematic in Flax NNX because you are trying to vectorize `m` in two different ways. Flax NNX will enforce consistency by raising an error.
The main issue with allowing for reference semantics in transforms is that references can be shared across inputs and outputs. This can be problematic if it is not taken care of because it would lead to ill-defined or inconsistent behavior. In the example below you have a single `Weights` Module `m` whose reference appears in multiple places in `arg1` and `arg2`. The problem here is that you also specify that you want to vectorize `arg1` in axis `0` and `arg2` in axis `1`. This would be fine in JAX because of referential transparency of pytrees. But this would be problematic in Flax NNX because you are trying to vectorize `m` in two different ways. Flax NNX will enforce consistency by raising an error.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be "module" instead of "Module"?

Copy link
Collaborator Author

@cgarciae cgarciae Dec 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we use Module in other places so its fine

@cgarciae cgarciae force-pushed the nnx-fix-transforms-guide branch from 9a255ab to 10f3460 Compare December 6, 2024 13:50
@copybara-service copybara-service bot merged commit f09d105 into main Dec 10, 2024
12 of 19 checks passed
@copybara-service copybara-service bot deleted the nnx-fix-transforms-guide branch December 10, 2024 23:21
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants