diff --git a/docs_nnx/guides/transforms.ipynb b/docs_nnx/guides/transforms.ipynb index 4ad1e48af3..d6c758ad37 100644 --- a/docs_nnx/guides/transforms.ipynb +++ b/docs_nnx/guides/transforms.ipynb @@ -684,7 +684,7 @@ "source": [ "### Consistent aliasing\n", "\n", - "The main issue with allowing for reference semantics in transforms is that references can be shared across inputs and outputs. This can be problematic if it is not taken care of because it would lead to ill-defined or inconsistent behavior. In the example below you have a single `Weights` `nnx.Module` - `m` ` whose reference appears in multiple places in `arg1` and `arg2`. The problem here is that you also specify that you want to vectorize `arg1` in axis `0` and `arg2` in axis `1`. This would be fine in JAX because of referential transparency of pytrees. But this would be problematic in Flax NNX because you are trying to vectorize `m` in two different ways. Flax NNX will enforce consistency by raising an error." + "The main issue with allowing for reference semantics in transforms is that references can be shared across inputs and outputs. This can be problematic if it is not taken care of because it would lead to ill-defined or inconsistent behavior. In the example below you have a single `Weights` Module `m` whose reference appears in multiple places in `arg1` and `arg2`. The problem here is that you also specify that you want to vectorize `arg1` in axis `0` and `arg2` in axis `1`. This would be fine in JAX because of referential transparency of pytrees. But this would be problematic in Flax NNX because you are trying to vectorize `m` in two different ways. Flax NNX will enforce consistency by raising an error." ] }, { @@ -863,7 +863,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.11.9" } }, "nbformat": 4, diff --git a/docs_nnx/guides/transforms.md b/docs_nnx/guides/transforms.md index 0cd7046f3e..95961940e8 100644 --- a/docs_nnx/guides/transforms.md +++ b/docs_nnx/guides/transforms.md @@ -345,7 +345,7 @@ To solve this issue pass all Module as arguments to the functions being transfor ### Consistent aliasing -The main issue with allowing for reference semantics in transforms is that references can be shared across inputs and outputs. This can be problematic if it is not taken care of because it would lead to ill-defined or inconsistent behavior. In the example below you have a single `Weights` `nnx.Module` - `m` ` whose reference appears in multiple places in `arg1` and `arg2`. The problem here is that you also specify that you want to vectorize `arg1` in axis `0` and `arg2` in axis `1`. This would be fine in JAX because of referential transparency of pytrees. But this would be problematic in Flax NNX because you are trying to vectorize `m` in two different ways. Flax NNX will enforce consistency by raising an error. +The main issue with allowing for reference semantics in transforms is that references can be shared across inputs and outputs. This can be problematic if it is not taken care of because it would lead to ill-defined or inconsistent behavior. In the example below you have a single `Weights` Module `m` whose reference appears in multiple places in `arg1` and `arg2`. The problem here is that you also specify that you want to vectorize `arg1` in axis `0` and `arg2` in axis `1`. This would be fine in JAX because of referential transparency of pytrees. But this would be problematic in Flax NNX because you are trying to vectorize `m` in two different ways. Flax NNX will enforce consistency by raising an error. ```{code-cell} ipython3 class Weights(nnx.Module):