Skip to content

Commit

Permalink
[nnx] experimental transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Aug 16, 2024
1 parent 71b5a46 commit faf9f67
Show file tree
Hide file tree
Showing 53 changed files with 6,935 additions and 3,047 deletions.
8 changes: 8 additions & 0 deletions docs/api_reference/flax.nnx/experimental.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
experimental
------------------------

.. automodule:: flax.nnx.experimental
.. currentmodule:: flax.nnx.experimental

.. autoclass:: StateAxes
.. autofunction:: vmap
1 change: 1 addition & 0 deletions docs/api_reference/flax.nnx/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@ Experimental API. See the `NNX page <https://flax.readthedocs.io/en/latest/nnx/i
helpers
visualization
filterlib
experimental

3 changes: 1 addition & 2 deletions docs/nnx/haiku_linen_vs_nnx.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1209,8 +1209,7 @@ in ``__init__`` to scan over the sequence.
scan_fn = lambda carry, cell, x: cell(carry, x)
carry = self.cell.initial_state(x.shape[0])
carry, y = nnx.scan(
scan_fn, state_axes={},
in_axes=1, out_axes=1
scan_fn, in_axes=(nnx.Carry, None, 1), out_axes=1
)(carry, self.cell, x)

return y
Expand Down
5 changes: 0 additions & 5 deletions flax/nnx/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,6 @@ building and experimenting with neural networks as easy and intuitive as possibl
* **Compatible**: NNX allows functionalizing Module state, making it possible to directly use JAX
transformations when needed.

> [!NOTE]
> NNX is currently in an experimental state and is subject to change. Linen is still the
recommended option for large-scale projects. Feedback and contributions are welcome!


## What does NNX look like?

NNX removes most of the friction from building and training neural networks in JAX. It provides
Expand Down
38 changes: 26 additions & 12 deletions flax/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
from .nnx import bridge as bridge
from .nnx import traversals as traversals
from .nnx import filterlib as filterlib
from .nnx import transforms as transforms
from .nnx import extract as extract
from .nnx.filterlib import WithTag as WithTag
from .nnx.filterlib import PathContains as PathContains
from .nnx.filterlib import OfType as OfType
Expand Down Expand Up @@ -55,6 +57,10 @@
from .nnx.graph import graphdef as graphdef
from .nnx.graph import iter_graph as iter_graph
from .nnx.graph import call as call
from .nnx.graph import SplitContext as SplitContext
from .nnx.graph import split_context as split_context
from .nnx.graph import MergeContext as MergeContext
from .nnx.graph import merge_context as merge_context
from .nnx.nn import initializers as initializers
from .nnx.nn.activations import celu as celu
from .nnx.nn.activations import elu as elu
Expand Down Expand Up @@ -106,6 +112,8 @@
from .nnx.rnglib import ForkStates as ForkStates
from .nnx.rnglib import fork as fork
from .nnx.rnglib import reseed as reseed
from .nnx.rnglib import split_rngs as split_rngs
from .nnx.rnglib import restore_rngs as restore_rngs
from .nnx.spmd import PARTITION_NAME as PARTITION_NAME
from .nnx.spmd import get_partition_spec as get_partition_spec
from .nnx.spmd import get_named_sharding as get_named_sharding
Expand All @@ -121,20 +129,25 @@
from .nnx.training.metrics import Metric as Metric
from .nnx.training.metrics import MultiMetric as MultiMetric
from .nnx.training.optimizer import Optimizer as Optimizer
from .nnx.transforms.transforms import Jit as Jit
from .nnx.transforms.transforms import Remat as Remat
from .nnx.transforms.looping import Scan as Scan
from .nnx.transforms.parallelization import Vmap as Vmap
from .nnx.transforms.parallelization import Pmap as Pmap
from .nnx.transforms.transforms import grad as grad
from .nnx.transforms.transforms import jit as jit
from .nnx.transforms.transforms import remat as remat
from .nnx.transforms.looping import scan as scan
from .nnx.transforms.transforms import value_and_grad as value_and_grad
from .nnx.transforms.parallelization import vmap as vmap
from .nnx.transforms.parallelization import pmap as pmap
from .nnx.transforms.deprecated import Jit as Jit
from .nnx.transforms.deprecated import Remat as Remat
from .nnx.transforms.deprecated import Scan as Scan
from .nnx.transforms.deprecated import Vmap as Vmap
from .nnx.transforms.deprecated import Pmap as Pmap
from .nnx.transforms.autodiff import DiffState as DiffState
from .nnx.transforms.autodiff import grad as grad
from .nnx.transforms.autodiff import value_and_grad as value_and_grad
from .nnx.transforms.autodiff import custom_vjp as custom_vjp
from .nnx.transforms.autodiff import remat as remat
from .nnx.transforms.compilation import jit as jit
from .nnx.transforms.compilation import StateSharding as StateSharding
from .nnx.transforms.iteration import Carry as Carry
from .nnx.transforms.iteration import scan as scan
from .nnx.transforms.iteration import vmap as vmap
from .nnx.transforms.iteration import pmap as pmap
from .nnx.transforms.transforms import eval_shape as eval_shape
from .nnx.transforms.transforms import cond as cond
from .nnx.transforms.iteration import StateAxes as StateAxes
from .nnx.variables import EMPTY as EMPTY
from .nnx.variables import A as A
from .nnx.variables import BatchStat as BatchStat
Expand All @@ -146,3 +159,4 @@
from .nnx.variables import VariableMetadata as VariableMetadata
from .nnx.variables import with_metadata as with_metadata
from .nnx.visualization import display as display
from .nnx.extract import to_tree, from_tree, TreeNode
8 changes: 4 additions & 4 deletions flax/nnx/docs/quick_start.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@
}
],
"source": [
"jax.tree_util.tree_map(jnp.shape, model.extract(nnx.Param))"
"jax.tree.map(jnp.shape, model.extract(nnx.Param))"
]
},
{
Expand All @@ -279,7 +279,7 @@
"\n",
"For pedagogical purposes, we first train the model in eager mode. This will be uselful to take a look at some of NNX's features, its be more approachable for new users, and great for debugging, but it is not the recommended way to train models in JAX.\n",
"\n",
"Here we will run a simple `for` loop for just 10 iterations, at each step we will sample a batch of data, define a `loss_fn` to compute the loss, and use `nnx.value_and_grad` to compute the gradients of the loss with respect to the model parameters. Using the gradients we will update the parameters using stochastic gradient descent (SGD) via a simple `tree_map` operation. Finally, we will update the model's parameters using the `.update_state` method."
"Here we will run a simple `for` loop for just 10 iterations, at each step we will sample a batch of data, define a `loss_fn` to compute the loss, and use `nnx.value_and_grad` to compute the gradients of the loss with respect to the model parameters. Using the gradients we will update the parameters using stochastic gradient descent (SGD) via a simple `tree.map` operation. Finally, we will update the model's parameters using the `.update_state` method."
]
},
{
Expand Down Expand Up @@ -318,7 +318,7 @@
"\n",
" loss, grads = nnx.value_and_grad(loss_fn, wrt=\"params\")(model)\n",
" params = model.extract(\"params\")\n",
" params = jax.tree_util.tree_map(lambda w, g: w - 0.001 * g, params, grads)\n",
" params = jax.tree.map(lambda w, g: w - 0.001 * g, params, grads)\n",
"\n",
" model.update(params)\n",
" print(f\"Step {step}: loss={loss:.4f}\")"
Expand Down Expand Up @@ -354,7 +354,7 @@
" return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()\n",
"\n",
" loss, grads = jax.value_and_grad(loss_fn)(params)\n",
" params = jax.tree_util.tree_map(lambda w, g: w - 0.001 * g, params, grads)\n",
" params = jax.tree.map(lambda w, g: w - 0.001 * g, params, grads)\n",
"\n",
" return loss, params"
]
Expand Down
6 changes: 3 additions & 3 deletions flax/nnx/docs/tiny_nnx.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@
"y = module(x, train=True, rngs=Rngs(random.key(1)))\n",
"\n",
"state, graphdef = module.split()\n",
"print(\"state =\", jax.tree_util.tree_map(jnp.shape, state))\n",
"print(\"state =\", jax.tree.map(jnp.shape, state))\n",
"print(\"graphdef =\", graphdef)"
]
},
Expand Down Expand Up @@ -442,8 +442,8 @@
"# merge\n",
"state = State({**params, **batch_stats})\n",
"\n",
"print(\"params =\", jax.tree_util.tree_map(jnp.shape, params))\n",
"print(\"batch_stats =\", jax.tree_util.tree_map(jnp.shape, batch_stats))"
"print(\"params =\", jax.tree.map(jnp.shape, params))\n",
"print(\"batch_stats =\", jax.tree.map(jnp.shape, batch_stats))"
]
}
],
Expand Down
4 changes: 2 additions & 2 deletions flax/nnx/docs/why.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@
"\n",
"print(f'{y.shape = }')\n",
"print(f'{ensemble.models.count = }')\n",
"print(f'state = {jax.tree_util.tree_map(jnp.shape, ensemble.get_state())}')"
"print(f'state = {jax.tree.map(jnp.shape, ensemble.get_state())}')"
]
},
{
Expand Down Expand Up @@ -752,7 +752,7 @@
" rngs=nnx.Rngs(0))\n",
"\n",
"graphdef, state = model.split()\n",
"jax.tree_util.tree_map(jnp.shape, state)"
"jax.tree.map(jnp.shape, state)"
]
}
],
Expand Down
4 changes: 2 additions & 2 deletions flax/nnx/docs/why.md
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ y = ensemble(x)
print(f'{y.shape = }')
print(f'{ensemble.models.count = }')
print(f'state = {jax.tree_util.tree_map(jnp.shape, ensemble.get_state())}')
print(f'state = {jax.tree.map(jnp.shape, ensemble.get_state())}')
```

#### Convenience lifted transforms
Expand Down Expand Up @@ -405,5 +405,5 @@ model = Example(in_filters=3,
rngs=nnx.Rngs(0))
graphdef, state = model.split()
jax.tree_util.tree_map(jnp.shape, state)
jax.tree.map(jnp.shape, state)
```
2 changes: 1 addition & 1 deletion flax/nnx/examples/gemma/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
def load_and_format_params(path: str) -> Params:
"""Loads parameters and formats them for compatibility."""
params = load_params(path)
param_state = jax.tree_util.tree_map(jnp.array, params)
param_state = jax.tree.map(jnp.array, params)
remapped_params = param_remapper(param_state)
nested_params = nest_params(remapped_params)
return nested_params
Expand Down
24 changes: 10 additions & 14 deletions flax/nnx/examples/lm1b/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,12 +303,10 @@ def per_host_sum_pmap(in_tree):
host_psum = jax.pmap(lambda x: jax.lax.psum(x, 'i'), 'i', devices=devices)

def pre_pmap(xs):
return jax.tree_util.tree_map(
lambda x: jnp.broadcast_to(x, (1,) + x.shape), xs
)
return jax.tree.map(lambda x: jnp.broadcast_to(x, (1,) + x.shape), xs)

def post_pmap(xs):
return jax.tree_util.tree_map(lambda x: x[0], xs)
return jax.tree.map(lambda x: x[0], xs)

return post_pmap(host_psum(pre_pmap(in_tree)))

Expand All @@ -331,13 +329,13 @@ def evaluate(
eval_metrics = []
eval_iter = iter(eval_ds) # pytype: disable=wrong-arg-types
for _, eval_batch in zip(range(num_eval_steps), eval_iter):
eval_batch = jax.tree_util.tree_map(lambda x: x._numpy(), eval_batch) # pylint: disable=protected-access
eval_batch = jax.tree.map(lambda x: x._numpy(), eval_batch) # pylint: disable=protected-access
metrics = jit_eval_step(state.params, eval_batch, state.graphdef)
eval_metrics.append(metrics)
eval_metrics = common_utils.stack_forest(eval_metrics)
eval_metrics_sums = jax.tree_util.tree_map(jnp.sum, eval_metrics)
eval_metrics_sums = jax.tree.map(jnp.sum, eval_metrics)
eval_denominator = eval_metrics_sums.pop('denominator')
eval_summary = jax.tree_util.tree_map(
eval_summary = jax.tree.map(
lambda x: x / eval_denominator, # pylint: disable=cell-var-from-loop
eval_metrics_sums,
)
Expand Down Expand Up @@ -368,7 +366,7 @@ def generate_prediction(
cur_pred_batch_size = pred_batch.shape[0]
if cur_pred_batch_size % n_devices:
padded_size = int(np.ceil(cur_pred_batch_size / n_devices) * n_devices)
pred_batch = jax.tree_util.tree_map(
pred_batch = jax.tree.map(
lambda x: pad_examples(x, padded_size), pred_batch
) # pylint: disable=cell-var-from-loop
pred_batch = common_utils.shard(pred_batch)
Expand Down Expand Up @@ -538,7 +536,7 @@ def constructor(config: models.TransformerConfig, key: jax.Array):
predict_step,
in_axes=(
0,
jax.tree_util.tree_map(lambda x: None, state.params),
jax.tree.map(lambda x: None, state.params),
0,
None,
None,
Expand Down Expand Up @@ -582,7 +580,7 @@ def constructor(config: models.TransformerConfig, key: jax.Array):
# Shard data to devices and do a training step.
with jax.profiler.StepTraceAnnotation('train', step_num=step):
batch = next(train_iter)
batch = jax.tree_util.tree_map(lambda x: jnp.asarray(x), batch)
batch = jax.tree.map(lambda x: jnp.asarray(x), batch)
state, metrics = jit_train_step(
state, batch, learning_rate_fn, 0.0, dropout_rngs
)
Expand All @@ -599,11 +597,9 @@ def constructor(config: models.TransformerConfig, key: jax.Array):
logging.info('Gathering training metrics.')
train_metrics = common_utils.stack_forest(train_metrics)
lr = train_metrics.pop('learning_rate').mean()
metrics_sums = jax.tree_util.tree_map(jnp.sum, train_metrics)
metrics_sums = jax.tree.map(jnp.sum, train_metrics)
denominator = metrics_sums.pop('denominator')
summary = jax.tree_util.tree_map(
lambda x: x / denominator, metrics_sums
) # pylint: disable=cell-var-from-loop
summary = jax.tree.map(lambda x: x / denominator, metrics_sums) # pylint: disable=cell-var-from-loop
summary['learning_rate'] = lr
summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), max=1.0e4)
summary = {'train_' + k: v for k, v in summary.items()}
Expand Down
2 changes: 1 addition & 1 deletion flax/nnx/examples/lm1b/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def setup_initial_state(
state = TrainState.create(
apply_fn=graphdef.apply, params=params, tx=tx, graphdef=graphdef
)
state = jax.tree_util.tree_map(_to_array, state)
state = jax.tree.map(_to_array, state)
state_spec = nnx.get_partition_spec(state)
state = jax.lax.with_sharding_constraint(state, state_spec)

Expand Down
2 changes: 1 addition & 1 deletion flax/nnx/examples/toy_examples/01_functional_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def loss_fn(params):

grad, counts = jax.grad(loss_fn, has_aux=True)(params)
# |-------- sgd ---------|
params = jax.tree_util.tree_map(lambda w, g: w - 0.1 * g, params, grad)
params = jax.tree.map(lambda w, g: w - 0.1 * g, params, grad)

return params, counts

Expand Down
4 changes: 1 addition & 3 deletions flax/nnx/examples/toy_examples/02_lifted_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,7 @@ def loss_fn(model: MLP):
y_pred = model(x)
return jnp.mean((y - y_pred) ** 2)

# |--default--|
grads: nnx.State = nnx.grad(loss_fn, wrt=nnx.Param)(model)
# sgd update
grads: nnx.State = nnx.grad(loss_fn)(model)
optimizer.update(grads)


Expand Down
7 changes: 4 additions & 3 deletions flax/nnx/examples/toy_examples/06_scan_over_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -40,13 +39,15 @@ class ScanMLP(nnx.Module):
def __init__(self, dim: int, *, n_layers: int, rngs: nnx.Rngs):
self.n_layers = n_layers

@partial(nnx.vmap, axis_size=n_layers)
@nnx.split_rngs(splits=n_layers)
@nnx.vmap(axis_size=n_layers)
def create_block(rngs: nnx.Rngs):
return Block(dim, rngs=rngs)

self.layers = create_block(rngs)

def __call__(self, x: jax.Array) -> jax.Array:
@nnx.split_rngs(splits=self.n_layers)
@nnx.scan
def scan_fn(x: jax.Array, block: Block):
x = block(x)
Expand All @@ -62,5 +63,5 @@ def scan_fn(x: jax.Array, block: Block):
x = jnp.ones((3, 10))
y = model(x)

print(jax.tree_util.tree_map(jnp.shape, nnx.state(model)))
print(jax.tree.map(jnp.shape, nnx.state(model)))
print(y.shape)
6 changes: 2 additions & 4 deletions flax/nnx/examples/toy_examples/09_parameter_surgery.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,6 @@ def __call__(self, x):

print(
'trainable_params =',
jax.tree_util.tree_map(jax.numpy.shape, trainable_params),
)
print(
'non_trainable = ', jax.tree_util.tree_map(jax.numpy.shape, non_trainable)
jax.tree.map(jax.numpy.shape, trainable_params),
)
print('non_trainable = ', jax.tree.map(jax.numpy.shape, non_trainable))
2 changes: 1 addition & 1 deletion flax/nnx/examples/toy_examples/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
matplotlib>=3.7.1
datasets>=2.12.0"
datasets>=2.12.0
2 changes: 1 addition & 1 deletion flax/nnx/nnx/bridge/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ class ToLinen(linen.Module):
>>> variables.keys()
dict_keys(['nnx', 'params'])
>>> type(variables['nnx']['graphdef'])
<class 'flax.nnx.nnx.graph.GraphDef'>
<class 'flax.nnx.nnx.graph.NodeDef'>
Args:
nnx_class: The NNX Module class (not instance!).
Expand Down
Loading

0 comments on commit faf9f67

Please sign in to comment.