diff --git a/docs/api_reference/flax.nnx/experimental.rst b/docs/api_reference/flax.nnx/experimental.rst new file mode 100644 index 0000000000..73140d9bd2 --- /dev/null +++ b/docs/api_reference/flax.nnx/experimental.rst @@ -0,0 +1,8 @@ +experimental +------------------------ + +.. automodule:: flax.nnx.experimental +.. currentmodule:: flax.nnx.experimental + +.. autoclass:: StateAxes +.. autofunction:: vmap \ No newline at end of file diff --git a/docs/api_reference/flax.nnx/index.rst b/docs/api_reference/flax.nnx/index.rst index 98b05093f7..8378cd7db5 100644 --- a/docs/api_reference/flax.nnx/index.rst +++ b/docs/api_reference/flax.nnx/index.rst @@ -18,4 +18,5 @@ Experimental API. See the `NNX page [!NOTE] -> NNX is currently in an experimental state and is subject to change. Linen is still the - recommended option for large-scale projects. Feedback and contributions are welcome! - - ## What does NNX look like? NNX removes most of the friction from building and training neural networks in JAX. It provides diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py index 523093ddee..973e8ed6d8 100644 --- a/flax/nnx/__init__.py +++ b/flax/nnx/__init__.py @@ -25,6 +25,8 @@ from .nnx import bridge as bridge from .nnx import traversals as traversals from .nnx import filterlib as filterlib +from .nnx import transforms as transforms +from .nnx import extract as extract from .nnx.filterlib import WithTag as WithTag from .nnx.filterlib import PathContains as PathContains from .nnx.filterlib import OfType as OfType @@ -55,6 +57,10 @@ from .nnx.graph import graphdef as graphdef from .nnx.graph import iter_graph as iter_graph from .nnx.graph import call as call +from .nnx.graph import SplitContext as SplitContext +from .nnx.graph import split_context as split_context +from .nnx.graph import MergeContext as MergeContext +from .nnx.graph import merge_context as merge_context from .nnx.nn import initializers as initializers from .nnx.nn.activations import celu as celu from .nnx.nn.activations import elu as elu @@ -106,6 +112,8 @@ from .nnx.rnglib import ForkStates as ForkStates from .nnx.rnglib import fork as fork from .nnx.rnglib import reseed as reseed +from .nnx.rnglib import split_rngs as split_rngs +from .nnx.rnglib import restore_rngs as restore_rngs from .nnx.spmd import PARTITION_NAME as PARTITION_NAME from .nnx.spmd import get_partition_spec as get_partition_spec from .nnx.spmd import get_named_sharding as get_named_sharding @@ -120,20 +128,25 @@ from .nnx.training.metrics import Metric as Metric from .nnx.training.metrics import MultiMetric as MultiMetric from .nnx.training.optimizer import Optimizer as Optimizer -from .nnx.transforms.transforms import Jit as Jit -from .nnx.transforms.transforms import Remat as Remat -from .nnx.transforms.looping import Scan as Scan -from .nnx.transforms.parallelization import Vmap as Vmap -from .nnx.transforms.parallelization import Pmap as Pmap -from .nnx.transforms.transforms import grad as grad -from .nnx.transforms.transforms import jit as jit -from .nnx.transforms.transforms import remat as remat -from .nnx.transforms.looping import scan as scan -from .nnx.transforms.transforms import value_and_grad as value_and_grad -from .nnx.transforms.parallelization import vmap as vmap -from .nnx.transforms.parallelization import pmap as pmap +from .nnx.transforms.deprecated import Jit as Jit +from .nnx.transforms.deprecated import Remat as Remat +from .nnx.transforms.deprecated import Scan as Scan +from .nnx.transforms.deprecated import Vmap as Vmap +from 
.nnx.transforms.deprecated import Pmap as Pmap +from .nnx.transforms.autodiff import DiffState as DiffState +from .nnx.transforms.autodiff import grad as grad +from .nnx.transforms.autodiff import value_and_grad as value_and_grad +from .nnx.transforms.autodiff import custom_vjp as custom_vjp +from .nnx.transforms.autodiff import remat as remat +from .nnx.transforms.compilation import jit as jit +from .nnx.transforms.compilation import StateSharding as StateSharding +from .nnx.transforms.iteration import Carry as Carry +from .nnx.transforms.iteration import scan as scan +from .nnx.transforms.iteration import vmap as vmap +from .nnx.transforms.iteration import pmap as pmap from .nnx.transforms.transforms import eval_shape as eval_shape from .nnx.transforms.transforms import cond as cond +from .nnx.transforms.iteration import StateAxes as StateAxes from .nnx.variables import EMPTY as EMPTY from .nnx.variables import A as A from .nnx.variables import BatchStat as BatchStat @@ -145,3 +158,4 @@ from .nnx.variables import VariableMetadata as VariableMetadata from .nnx.variables import with_metadata as with_metadata from .nnx.visualization import display as display +from .nnx.extract import to_tree, from_tree, TreeNode diff --git a/flax/nnx/docs/quick_start.ipynb b/flax/nnx/docs/quick_start.ipynb index df64361b43..1c8f297726 100644 --- a/flax/nnx/docs/quick_start.ipynb +++ b/flax/nnx/docs/quick_start.ipynb @@ -267,7 +267,7 @@ } ], "source": [ - "jax.tree_util.tree_map(jnp.shape, model.extract(nnx.Param))" + "jax.tree.map(jnp.shape, model.extract(nnx.Param))" ] }, { @@ -279,7 +279,7 @@ "\n", "For pedagogical purposes, we first train the model in eager mode. This will be uselful to take a look at some of NNX's features, its be more approachable for new users, and great for debugging, but it is not the recommended way to train models in JAX.\n", "\n", - "Here we will run a simple `for` loop for just 10 iterations, at each step we will sample a batch of data, define a `loss_fn` to compute the loss, and use `nnx.value_and_grad` to compute the gradients of the loss with respect to the model parameters. Using the gradients we will update the parameters using stochastic gradient descent (SGD) via a simple `tree_map` operation. Finally, we will update the model's parameters using the `.update_state` method." + "Here we will run a simple `for` loop for just 10 iterations, at each step we will sample a batch of data, define a `loss_fn` to compute the loss, and use `nnx.value_and_grad` to compute the gradients of the loss with respect to the model parameters. Using the gradients we will update the parameters using stochastic gradient descent (SGD) via a simple `tree.map` operation. Finally, we will update the model's parameters using the `.update_state` method." 
] }, { @@ -318,7 +318,7 @@ "\n", " loss, grads = nnx.value_and_grad(loss_fn, wrt=\"params\")(model)\n", " params = model.extract(\"params\")\n", - " params = jax.tree_util.tree_map(lambda w, g: w - 0.001 * g, params, grads)\n", + " params = jax.tree.map(lambda w, g: w - 0.001 * g, params, grads)\n", "\n", " model.update(params)\n", " print(f\"Step {step}: loss={loss:.4f}\")" @@ -354,7 +354,7 @@ " return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()\n", "\n", " loss, grads = jax.value_and_grad(loss_fn)(params)\n", - " params = jax.tree_util.tree_map(lambda w, g: w - 0.001 * g, params, grads)\n", + " params = jax.tree.map(lambda w, g: w - 0.001 * g, params, grads)\n", "\n", " return loss, params" ] diff --git a/flax/nnx/docs/tiny_nnx.ipynb b/flax/nnx/docs/tiny_nnx.ipynb index 3fe5bf611a..387b95b8d1 100644 --- a/flax/nnx/docs/tiny_nnx.ipynb +++ b/flax/nnx/docs/tiny_nnx.ipynb @@ -399,7 +399,7 @@ "y = module(x, train=True, rngs=Rngs(random.key(1)))\n", "\n", "state, graphdef = module.split()\n", - "print(\"state =\", jax.tree_util.tree_map(jnp.shape, state))\n", + "print(\"state =\", jax.tree.map(jnp.shape, state))\n", "print(\"graphdef =\", graphdef)" ] }, @@ -442,8 +442,8 @@ "# merge\n", "state = State({**params, **batch_stats})\n", "\n", - "print(\"params =\", jax.tree_util.tree_map(jnp.shape, params))\n", - "print(\"batch_stats =\", jax.tree_util.tree_map(jnp.shape, batch_stats))" + "print(\"params =\", jax.tree.map(jnp.shape, params))\n", + "print(\"batch_stats =\", jax.tree.map(jnp.shape, batch_stats))" ] } ], diff --git a/flax/nnx/docs/why.ipynb b/flax/nnx/docs/why.ipynb index 46caf8c4e8..d38fe6c809 100644 --- a/flax/nnx/docs/why.ipynb +++ b/flax/nnx/docs/why.ipynb @@ -439,7 +439,7 @@ "\n", "print(f'{y.shape = }')\n", "print(f'{ensemble.models.count = }')\n", - "print(f'state = {jax.tree_util.tree_map(jnp.shape, ensemble.get_state())}')" + "print(f'state = {jax.tree.map(jnp.shape, ensemble.get_state())}')" ] }, { @@ -752,7 +752,7 @@ " rngs=nnx.Rngs(0))\n", "\n", "graphdef, state = model.split()\n", - "jax.tree_util.tree_map(jnp.shape, state)" + "jax.tree.map(jnp.shape, state)" ] } ], diff --git a/flax/nnx/docs/why.md b/flax/nnx/docs/why.md index 07142c0f49..b080319be7 100644 --- a/flax/nnx/docs/why.md +++ b/flax/nnx/docs/why.md @@ -243,7 +243,7 @@ y = ensemble(x) print(f'{y.shape = }') print(f'{ensemble.models.count = }') -print(f'state = {jax.tree_util.tree_map(jnp.shape, ensemble.get_state())}') +print(f'state = {jax.tree.map(jnp.shape, ensemble.get_state())}') ``` #### Convenience lifted transforms @@ -405,5 +405,5 @@ model = Example(in_filters=3, rngs=nnx.Rngs(0)) graphdef, state = model.split() -jax.tree_util.tree_map(jnp.shape, state) +jax.tree.map(jnp.shape, state) ``` diff --git a/flax/nnx/examples/gemma/params.py b/flax/nnx/examples/gemma/params.py index 4072dd3e5e..da3b6ec5fb 100644 --- a/flax/nnx/examples/gemma/params.py +++ b/flax/nnx/examples/gemma/params.py @@ -42,7 +42,7 @@ def load_and_format_params(path: str) -> Params: """Loads parameters and formats them for compatibility.""" params = load_params(path) - param_state = jax.tree_util.tree_map(jnp.array, params) + param_state = jax.tree.map(jnp.array, params) remapped_params = param_remapper(param_state) nested_params = nest_params(remapped_params) return nested_params diff --git a/flax/nnx/examples/lm1b/train.py b/flax/nnx/examples/lm1b/train.py index a137b9da12..5510b8de16 100644 --- a/flax/nnx/examples/lm1b/train.py +++ b/flax/nnx/examples/lm1b/train.py @@ -303,12 +303,10 @@ def 
per_host_sum_pmap(in_tree): host_psum = jax.pmap(lambda x: jax.lax.psum(x, 'i'), 'i', devices=devices) def pre_pmap(xs): - return jax.tree_util.tree_map( - lambda x: jnp.broadcast_to(x, (1,) + x.shape), xs - ) + return jax.tree.map(lambda x: jnp.broadcast_to(x, (1,) + x.shape), xs) def post_pmap(xs): - return jax.tree_util.tree_map(lambda x: x[0], xs) + return jax.tree.map(lambda x: x[0], xs) return post_pmap(host_psum(pre_pmap(in_tree))) @@ -331,13 +329,13 @@ def evaluate( eval_metrics = [] eval_iter = iter(eval_ds) # pytype: disable=wrong-arg-types for _, eval_batch in zip(range(num_eval_steps), eval_iter): - eval_batch = jax.tree_util.tree_map(lambda x: x._numpy(), eval_batch) # pylint: disable=protected-access + eval_batch = jax.tree.map(lambda x: x._numpy(), eval_batch) # pylint: disable=protected-access metrics = jit_eval_step(state.params, eval_batch, state.graphdef) eval_metrics.append(metrics) eval_metrics = common_utils.stack_forest(eval_metrics) - eval_metrics_sums = jax.tree_util.tree_map(jnp.sum, eval_metrics) + eval_metrics_sums = jax.tree.map(jnp.sum, eval_metrics) eval_denominator = eval_metrics_sums.pop('denominator') - eval_summary = jax.tree_util.tree_map( + eval_summary = jax.tree.map( lambda x: x / eval_denominator, # pylint: disable=cell-var-from-loop eval_metrics_sums, ) @@ -368,7 +366,7 @@ def generate_prediction( cur_pred_batch_size = pred_batch.shape[0] if cur_pred_batch_size % n_devices: padded_size = int(np.ceil(cur_pred_batch_size / n_devices) * n_devices) - pred_batch = jax.tree_util.tree_map( + pred_batch = jax.tree.map( lambda x: pad_examples(x, padded_size), pred_batch ) # pylint: disable=cell-var-from-loop pred_batch = common_utils.shard(pred_batch) @@ -538,7 +536,7 @@ def constructor(config: models.TransformerConfig, key: jax.Array): predict_step, in_axes=( 0, - jax.tree_util.tree_map(lambda x: None, state.params), + jax.tree.map(lambda x: None, state.params), 0, None, None, @@ -582,7 +580,7 @@ def constructor(config: models.TransformerConfig, key: jax.Array): # Shard data to devices and do a training step. 
with jax.profiler.StepTraceAnnotation('train', step_num=step): batch = next(train_iter) - batch = jax.tree_util.tree_map(lambda x: jnp.asarray(x), batch) + batch = jax.tree.map(lambda x: jnp.asarray(x), batch) state, metrics = jit_train_step( state, batch, learning_rate_fn, 0.0, dropout_rngs ) @@ -599,11 +597,9 @@ def constructor(config: models.TransformerConfig, key: jax.Array): logging.info('Gathering training metrics.') train_metrics = common_utils.stack_forest(train_metrics) lr = train_metrics.pop('learning_rate').mean() - metrics_sums = jax.tree_util.tree_map(jnp.sum, train_metrics) + metrics_sums = jax.tree.map(jnp.sum, train_metrics) denominator = metrics_sums.pop('denominator') - summary = jax.tree_util.tree_map( - lambda x: x / denominator, metrics_sums - ) # pylint: disable=cell-var-from-loop + summary = jax.tree.map(lambda x: x / denominator, metrics_sums) # pylint: disable=cell-var-from-loop summary['learning_rate'] = lr summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), max=1.0e4) summary = {'train_' + k: v for k, v in summary.items()} diff --git a/flax/nnx/examples/lm1b/utils.py b/flax/nnx/examples/lm1b/utils.py index b18b6e4691..a9a3f8ce5c 100644 --- a/flax/nnx/examples/lm1b/utils.py +++ b/flax/nnx/examples/lm1b/utils.py @@ -161,7 +161,7 @@ def setup_initial_state( state = TrainState.create( apply_fn=graphdef.apply, params=params, tx=tx, graphdef=graphdef ) - state = jax.tree_util.tree_map(_to_array, state) + state = jax.tree.map(_to_array, state) state_spec = nnx.get_partition_spec(state) state = jax.lax.with_sharding_constraint(state, state_spec) diff --git a/flax/nnx/examples/toy_examples/01_functional_api.py b/flax/nnx/examples/toy_examples/01_functional_api.py index 8f90a24ef6..e790da73d6 100644 --- a/flax/nnx/examples/toy_examples/01_functional_api.py +++ b/flax/nnx/examples/toy_examples/01_functional_api.py @@ -75,7 +75,7 @@ def loss_fn(params): grad, counts = jax.grad(loss_fn, has_aux=True)(params) # |-------- sgd ---------| - params = jax.tree_util.tree_map(lambda w, g: w - 0.1 * g, params, grad) + params = jax.tree.map(lambda w, g: w - 0.1 * g, params, grad) return params, counts diff --git a/flax/nnx/examples/toy_examples/02_lifted_transforms.py b/flax/nnx/examples/toy_examples/02_lifted_transforms.py index bb2238f7a2..9fef3adf26 100644 --- a/flax/nnx/examples/toy_examples/02_lifted_transforms.py +++ b/flax/nnx/examples/toy_examples/02_lifted_transforms.py @@ -71,9 +71,7 @@ def loss_fn(model: MLP): y_pred = model(x) return jnp.mean((y - y_pred) ** 2) - # |--default--| - grads: nnx.State = nnx.grad(loss_fn, wrt=nnx.Param)(model) - # sgd update + grads: nnx.State = nnx.grad(loss_fn)(model) optimizer.update(grads) diff --git a/flax/nnx/examples/toy_examples/06_scan_over_layers.py b/flax/nnx/examples/toy_examples/06_scan_over_layers.py index ad2b2edcea..d8daa71c2c 100644 --- a/flax/nnx/examples/toy_examples/06_scan_over_layers.py +++ b/flax/nnx/examples/toy_examples/06_scan_over_layers.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from functools import partial import jax import jax.numpy as jnp @@ -40,13 +39,15 @@ class ScanMLP(nnx.Module): def __init__(self, dim: int, *, n_layers: int, rngs: nnx.Rngs): self.n_layers = n_layers - @partial(nnx.vmap, axis_size=n_layers) + @nnx.split_rngs(splits=n_layers) + @nnx.vmap(axis_size=n_layers) def create_block(rngs: nnx.Rngs): return Block(dim, rngs=rngs) self.layers = create_block(rngs) def __call__(self, x: jax.Array) -> jax.Array: + @nnx.split_rngs(splits=self.n_layers) @nnx.scan def scan_fn(x: jax.Array, block: Block): x = block(x) @@ -62,5 +63,5 @@ def scan_fn(x: jax.Array, block: Block): x = jnp.ones((3, 10)) y = model(x) -print(jax.tree_util.tree_map(jnp.shape, nnx.state(model))) +print(jax.tree.map(jnp.shape, nnx.state(model))) print(y.shape) diff --git a/flax/nnx/examples/toy_examples/09_parameter_surgery.py b/flax/nnx/examples/toy_examples/09_parameter_surgery.py index 11a785aaa6..fd636ec073 100644 --- a/flax/nnx/examples/toy_examples/09_parameter_surgery.py +++ b/flax/nnx/examples/toy_examples/09_parameter_surgery.py @@ -54,8 +54,6 @@ def __call__(self, x): print( 'trainable_params =', - jax.tree_util.tree_map(jax.numpy.shape, trainable_params), -) -print( - 'non_trainable = ', jax.tree_util.tree_map(jax.numpy.shape, non_trainable) + jax.tree.map(jax.numpy.shape, trainable_params), ) +print('non_trainable = ', jax.tree.map(jax.numpy.shape, non_trainable)) diff --git a/flax/nnx/examples/toy_examples/requirements.txt b/flax/nnx/examples/toy_examples/requirements.txt index e44f155e47..64dd014d6c 100644 --- a/flax/nnx/examples/toy_examples/requirements.txt +++ b/flax/nnx/examples/toy_examples/requirements.txt @@ -1,2 +1,2 @@ matplotlib>=3.7.1 -datasets>=2.12.0" \ No newline at end of file +datasets>=2.12.0 \ No newline at end of file diff --git a/flax/nnx/nnx/extract.py b/flax/nnx/nnx/extract.py index b19e036f4d..890ed60798 100644 --- a/flax/nnx/nnx/extract.py +++ b/flax/nnx/nnx/extract.py @@ -12,18 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import abc import contextlib import dataclasses import threading import typing as tp import jax -from jax._src.tree_util import broadcast_prefix +# from jax._src.tree_util import broadcast_prefix from flax import struct from flax.nnx.nnx.object import Object -from flax.nnx.nnx.state import State from flax.typing import PathParts from flax.nnx.nnx import graph @@ -34,7 +32,6 @@ class Missing: MISSING = Missing() A = tp.TypeVar('A') -E = tp.TypeVar('E', bound='Extractable') Index = int KeyEntry = tp.TypeVar('KeyEntry', bound=tp.Hashable) KeyPath = tuple[KeyEntry, ...] @@ -42,30 +39,10 @@ class Missing: Leaf = tp.Any -class Extractable(abc.ABC): - @property - @abc.abstractmethod - def index(self) -> Index: ... - - -class ExtractableStates(Extractable): - @property - @abc.abstractmethod - def states(self) -> tp.Iterable[State]: ... - - @property - @abc.abstractmethod - def graphdef(self) -> graph.GraphDef[tp.Any]: ... - - -class ExtractionIndex(struct.PyTreeNode, Extractable): +class ExtractionIndex(struct.PyTreeNode): """Index of a graph node in a Pytree structure.""" - _index: Index = struct.field(pytree_node=False) - - @property - def index(self) -> Index: - return self._index + index: Index = struct.field(pytree_node=False) @tp.overload @@ -75,8 +52,6 @@ def extract_graph_nodes( *, validate_fn: tp.Callable[[KeyPath, Prefix, Leaf], None] | None = None, ) -> tuple[A, tuple[tp.Any, ...]]: ... 
- - @tp.overload def extract_graph_nodes( pytree: A, @@ -85,8 +60,6 @@ def extract_graph_nodes( prefix: tp.Any, validate_fn: tp.Callable[[KeyPath, Prefix, Leaf], None] | None = None, ) -> tuple[A, tuple[tp.Any, ...], tuple[tp.Any, ...]]: ... - - def extract_graph_nodes( pytree: A, /, @@ -105,7 +78,7 @@ def extract_graph_nodes( prefix_leaves = broadcast_prefix( prefix, pytree, - is_leaf=lambda x: x is None, + prefix_is_leaf=lambda x: x is None, ) key_leaves, treedef = jax.tree_util.tree_flatten_with_path(pytree) @@ -143,71 +116,64 @@ def insert_graph_nodes(pytree: A, nodes: tuple[tp.Any, ...], /) -> A: """Inserts graph nodes into a pytree.""" def _maybe_insert(x): - if isinstance(x, Extractable): + if isinstance(x, ExtractionIndex): return nodes[x.index] return x - return jax.tree_util.tree_map( - _maybe_insert, pytree, is_leaf=lambda x: isinstance(x, Extractable) + return jax.tree.map( + _maybe_insert, pytree, is_leaf=lambda x: isinstance(x, ExtractionIndex) ) -def extract_indexes( - pytree, - /, - types: tuple[type[E], ...] | type[E] = Extractable, # type: ignore[assignment] -) -> tuple[E, ...]: - """Extracts all indexes from a pytree.""" - indexes: list[E] = [] - for x in jax.tree.leaves( - pytree, is_leaf=lambda x: isinstance(x, Extractable) - ): - if isinstance(x, Extractable): - if not isinstance(x, types): - raise ValueError(f'Expected Extractable of type {types}, got {type(x)}') - indexes.append(x) # type: ignore[arg-type] - return tuple(indexes) - - -def replace_indexes( - pytree: A, - replace_fn: tp.Callable[[Extractable], tp.Any], +def check_consistent_aliasing( + node: tuple[tp.Any, ...], + prefix: tuple[tp.Any, ...], /, - clear: bool = False, -) -> A: - def _replace_map_fn(x): - if isinstance(x, Extractable): - return replace_fn(x) - elif clear: - return None - return x - - return jax.tree_util.tree_map( - _replace_map_fn, pytree, is_leaf=lambda x: isinstance(x, Extractable) - ) - - -def merge_extractable_states( - extractable_states: tp.Sequence[ExtractableStates], / + *, + node_prefixes: graph.RefMap[tp.Any, list[tuple[PathParts, tp.Any]]] + | None = None, ): - if len(extractable_states) == 0: - raise ValueError('Expected at least one ExtractableStates object') + if node_prefixes is None: + node_prefixes = graph.RefMap() - graphdef = extractable_states[0].graphdef - flat_state: list[tuple[PathParts, tp.Any]] = [] + # collect all paths and prefixes for each node + for path, value in graph.iter_graph(node): + if graph.is_graph_node(value) or isinstance(value, graph.Variable): + if isinstance(value, Object): + value.check_valid_context( + lambda: f'Trying to extract graph node from different trace level, got {value!r}' + ) + if isinstance(value, graph.Variable): + if not value._trace_state.is_valid(): + raise ValueError( + f'Cannot extract graph node from different trace level, got {value!r}' + ) + if value in node_prefixes: + paths_prefixes = node_prefixes[value] + paths_prefixes.append((path, prefix)) + else: + node_prefixes[value] = [(path, prefix)] - for extractable_state in extractable_states: - flat_state.extend( - ((extractable_state.index, *path), value) - for state in extractable_state.states - for path, value in state.flat_state().items() - ) + # check for inconsistent aliasing + node_msgs = [] + for node, paths_prefixes in node_prefixes.items(): + unique_prefixes = {prefix for _, prefix in paths_prefixes} + if len(unique_prefixes) > 1: + path_prefix_repr = '\n'.join( + f' {"/".join(map(str,path)) if path else ""}: {prefix}' + for path, prefix in paths_prefixes + ) + 
nodes_msg = f'Node: {type(node)}\n{path_prefix_repr}' + node_msgs.append(nodes_msg) - state = State.from_flat_path(flat_state) - return graphdef, state + if node_msgs: + raise ValueError( + 'Inconsistent aliasing detected. The following nodes have different prefixes:\n' + + '\n'.join(node_msgs) + ) -def check_consistent_aliasing( +def check_all_consistent_aliasing( nodes: tuple[tp.Any, ...], prefixes: tuple[tp.Any, ...] ): node_prefixes = graph.RefMap[tp.Any, list[tuple[PathParts, tp.Any]]]() @@ -218,7 +184,7 @@ def check_consistent_aliasing( if graph.is_graph_node(value): if isinstance(value, Object): value.check_valid_context( - f'Trying to extract graph node from different trace level, got {value!r}' + lambda: f'Trying to extract graph node from different trace level, got {value!r}' ) if value in node_prefixes: paths_prefixes = node_prefixes[value] @@ -286,3 +252,210 @@ def get_broadcast_state(tag: str) -> tp.Any: ) return stack[-1] + +# ----------------------------- +# to_tree/from_tree +# ----------------------------- + +def broadcast_prefix( + prefix_tree: tp.Any, + full_tree: tp.Any, + prefix_is_leaf: tp.Callable[[tp.Any], bool] | None = None, + tree_is_leaf: tp.Callable[[tp.Any], bool] | None = None, +) -> list[tp.Any]: + # If prefix_tree is not a tree prefix of full_tree, this code can raise a + # ValueError; use prefix_errors to find disagreements and raise more precise + # error messages. + result = [] + num_leaves = lambda t: jax.tree_util.tree_structure( + t, is_leaf=tree_is_leaf + ).num_leaves + add_leaves = lambda x, subtree: result.extend([x] * num_leaves(subtree)) + jax.tree.map(add_leaves, prefix_tree, full_tree, is_leaf=prefix_is_leaf) + return result + + +class GraphDefState(struct.PyTreeNode): + graphdef: graph.GraphDef[tp.Any] = struct.field(pytree_node=False) + state: graph.GraphState = struct.field(pytree_node=True) + +class StateOnly(struct.PyTreeNode): + state: graph.GraphState = struct.field(pytree_node=True) + + @property + def graphdef(self) -> graph.GraphDef[tp.Any]: + raise ValueError('No graphdef available in StateOnly') + + +@dataclasses.dataclass(frozen=True) +class StateSequence(tp.Sequence[graph.GraphState]): + graphdef_states: tuple[GraphDefState | StateOnly, ...] + + @tp.overload + def __getitem__(self, index: int) -> graph.GraphState: ... + @tp.overload + def __getitem__(self, index: slice) -> 'StateSequence': ... + def __getitem__(self, index): + if isinstance(index, slice): + return StateSequence(self.graphdef_states[index]) + elif isinstance(index, int): + return self.graphdef_states[index].state + else: + raise TypeError(f'Invalid index type: {type(index)}') + + def __len__(self): + return len(self.graphdef_states) + + def __iter__(self): + return (s.state for s in self.graphdef_states) + + +class TreeNode(struct.PyTreeNode): + metatata: tp.Any = struct.field(pytree_node=False) + graphdef_states: tuple[GraphDefState | StateOnly, ...] 
= struct.field( + pytree_node=True + ) + + @property + def graphdef(self) -> graph.GraphDef[tp.Any]: + return self.graphdef_states[0].graphdef + + @property + def state(self) -> graph.GraphState: + if len(self.graphdef_states) != 1: + raise ValueError( + f'Expected exactly one GraphDefState, got {len(self.graphdef_states)}' + ) + return self.graphdef_states[0].state + + @property + def states(self) -> tp.Sequence[graph.GraphState]: + return StateSequence(self.graphdef_states) + + @classmethod + def from_split( + cls, + graphdef: graph.GraphDef[tp.Any], + state: graph.GraphState, + /, + *states: graph.GraphState, + metadata: tp.Any = None, + ): + states = (state, *states) + return cls( + metadata, tuple(GraphDefState(graphdef, state) for state in states) + ) + + @classmethod + def from_states(cls, state: graph.GraphState, *states: graph.GraphState): + states = (state, *states) + return cls(None, tuple(StateOnly(state) for state in states)) + + @classmethod + def from_prefixes( + cls, + prefixes: tp.Iterable[tp.Any], + /, + *, + metadata: tp.Any = None, + ): + return cls(metadata, tuple(prefixes)) + + +def default_split_fn( + ctx: graph.SplitContext, path: KeyPath, prefix: Prefix, leaf: Leaf +) -> tp.Any: + return TreeNode.from_split(*ctx.split(leaf)) + + +def to_tree( + tree, + /, + *, + prefix: tp.Any = MISSING, + split_fn: tp.Callable[ + [graph.SplitContext, KeyPath, Prefix, Leaf], tp.Any + ] = default_split_fn, + map_non_graph_nodes: bool = False, + ctxtag: str | None = None, +) -> tp.Any: + leaf_prefixes = broadcast_prefix( + prefix, + tree, + prefix_is_leaf=lambda x: x is None, + ) + leaf_keys, treedef = jax.tree_util.tree_flatten_with_path(tree) + + assert len(leaf_keys) == len(leaf_prefixes) + leaves_out = [] + node_prefixes = graph.RefMap[tp.Any, list[tuple[PathParts, tp.Any]]]() + + with graph.split_context(ctxtag) as split_ctx: + for (keypath, leaf), leaf_prefix in zip(leaf_keys, leaf_prefixes): + if graph.is_graph_node(leaf): + check_consistent_aliasing( + leaf, leaf_prefix, node_prefixes=node_prefixes + ) + tree_node = split_fn(split_ctx, keypath, leaf_prefix, leaf) + leaves_out.append(tree_node) + else: + if map_non_graph_nodes: + leaf = split_fn(split_ctx, keypath, leaf_prefix, leaf) + leaves_out.append(leaf) + + pytree_out = jax.tree.unflatten(treedef, leaves_out) + return pytree_out + + +def merge_tree_node( + ctx: graph.MergeContext, path: KeyPath, prefix: Prefix, leaf: Leaf +) -> tp.Any: + if not isinstance(leaf, TreeNode): + raise ValueError(f'Expected TreeNode, got {type(leaf)} at path {path}') + return ctx.merge(leaf.graphdef, *leaf.states) + + +def is_tree_node(x): + return isinstance(x, TreeNode) + + +def from_tree( + tree: tp.Any, + /, + *, + prefix: tp.Any = MISSING, + merge_fn: tp.Callable[ + [graph.MergeContext, KeyPath, Prefix, Leaf], tp.Any + ] = merge_tree_node, + is_node_leaf: tp.Callable[[Leaf], bool] = is_tree_node, + is_leaf: tp.Callable[[Leaf], bool] = is_tree_node, + map_non_graph_nodes: bool = False, + ctxtag: str | None = None, +) -> tp.Any: + leaf_prefixes = broadcast_prefix( + prefix, + tree, + prefix_is_leaf=lambda x: x is None or is_leaf(x), + tree_is_leaf=is_leaf, + ) + leaf_keys, treedef = jax.tree_util.tree_flatten_with_path( + tree, is_leaf=is_leaf + ) + assert len(leaf_keys) == len(leaf_prefixes) + leaves_out = [] + + with graph.merge_context(ctxtag) as merge_ctx: + for (keypath, leaf), leaf_prefix in zip(leaf_keys, leaf_prefixes): + if is_node_leaf(leaf): + leaf_out = merge_fn(merge_ctx, keypath, leaf_prefix, leaf) + 
leaves_out.append(leaf_out) + else: + if map_non_graph_nodes: + leaf = merge_fn(merge_ctx, keypath, leaf_prefix, leaf) + leaves_out.append(leaf) + + pytree_out = jax.tree.unflatten(treedef, leaves_out) + return pytree_out + +def clear_non_graph_nodes(tree): + return jax.tree.map(lambda x: x if graph.is_graph_node(x) else None, tree) \ No newline at end of file diff --git a/flax/nnx/nnx/filterlib.py b/flax/nnx/nnx/filterlib.py index 9113f12a7f..2e4de1a178 100644 --- a/flax/nnx/nnx/filterlib.py +++ b/flax/nnx/nnx/filterlib.py @@ -28,14 +28,6 @@ Filter = tp.Union[FilterLiteral, tuple['Filter', ...], list['Filter']] -@tp.runtime_checkable -class _HasTag(tp.Protocol): - tag: str - -@tp.runtime_checkable -class _HasType(tp.Protocol): - type: type - def to_predicate(filter: Filter) -> Predicate: """Converts a Filter to a predicate function. @@ -68,7 +60,7 @@ class WithTag: tag: str def __call__(self, path: PathParts, x: tp.Any): - return isinstance(x, _HasTag) and x.tag == self.tag + return hasattr(x, 'tag') and x.tag == self.tag def __repr__(self): return f'WithTag({self.tag!r})' @@ -91,7 +83,7 @@ class OfType: def __call__(self, path: PathParts, x: tp.Any): return isinstance(x, self.type) or ( - isinstance(x, _HasType) and issubclass(x.type, self.type) + hasattr(x, 'type') and issubclass(x.type, self.type) ) def __repr__(self): diff --git a/flax/nnx/nnx/graph.py b/flax/nnx/nnx/graph.py index fa4d58ec9c..4d7b33b36c 100644 --- a/flax/nnx/nnx/graph.py +++ b/flax/nnx/nnx/graph.py @@ -14,17 +14,18 @@ from __future__ import annotations +import contextlib import dataclasses import enum import functools import threading import typing as tp -from copy import deepcopy import jax import numpy as np import typing_extensions as tpe +from flax.core.frozen_dict import FrozenDict from flax.nnx.nnx import filterlib, reprlib from flax.nnx.nnx.proxy_caller import ( ApplyCaller, @@ -58,16 +59,6 @@ def is_state_leaf(x: tp.Any) -> tpe.TypeGuard[StateLeaf]: return isinstance(x, (VariableState, np.ndarray, jax.Array)) -@dataclasses.dataclass -class GraphContext(threading.local): - update_context_stacks: dict[str, list[UpdateContext]] = dataclasses.field( - default_factory=dict - ) - - -GRAPH_CONTEXT = GraphContext() - - def is_node_leaf(x: tp.Any) -> tpe.TypeGuard[StateLeaf]: return isinstance(x, (Variable, np.ndarray, jax.Array)) @@ -240,15 +231,49 @@ def __repr__(self) -> str: return repr(self._mapping) +class GraphDef(tp.Generic[Node]): + type: type[Node] + index: int + + +@dataclasses.dataclass(frozen=True, repr=False) +class NodeRef(GraphDef[Node], reprlib.Representable): + type: type[Node] + index: int + + def __nnx_repr__(self): + yield reprlib.Object(type=type(self)) + yield reprlib.Attr('type', self.type.__name__) + yield reprlib.Attr('index', self.index) + + def __penzai_repr__(self, path, subtree_renderer): + from penzai.treescope import repr_lib as pz_repr_lib # type: ignore[import-not-found,import-untyped] + + return pz_repr_lib.render_object_constructor( + object_type=type(self), + attributes={'type': self.type, 'index': self.index}, + path=path, + subtree_renderer=subtree_renderer, + ) + + +jax.tree_util.register_static(NodeRef) + + @dataclasses.dataclass(frozen=True, repr=False) -class NodeDef(tp.Generic[Node], reprlib.Representable): +class NodeDef(GraphDef[Node], reprlib.Representable): + """A dataclass that denotes the tree structure of a + :class:`Module`. 
A ``GraphDef`` can be generated by either + calling :func:`split` or :func:`graphdef` on the :class:`Module`.""" + type: tp.Type[Node] index: int attributes: tuple[Key, ...] - subgraphs: _HashableMapping[Key, tp.Union[NodeDef[tp.Any], Index]] + subgraphs: _HashableMapping[Key, NodeDef[tp.Any] | NodeRef[tp.Any]] static_fields: _HashableMapping[Key, tp.Any] - leaves: _HashableMapping[Key, Index | None] + leaves: _HashableMapping[Key, NodeRef[tp.Any] | None] metadata: tp.Any + index_mapping: FrozenDict[Index, Index] | None @classmethod def create( @@ -256,10 +281,11 @@ def create( type: tp.Type[Node], index: int, attributes: tuple[Key, ...], - subgraphs: tp.Iterable[tuple[Key, tp.Union[NodeDef[tp.Any], Index]]], + subgraphs: tp.Iterable[tuple[Key, NodeDef[tp.Any] | NodeRef[tp.Any]]], static_fields: tp.Iterable[tuple[Key, tp.Any]], - leaves: tp.Iterable[tuple[Key, Index | None]], + leaves: tp.Iterable[tuple[Key, NodeRef[tp.Any] | None]], metadata: tp.Any, + index_mapping: tp.Mapping[Index, Index] | None, ): return cls( type=type, @@ -269,6 +295,9 @@ def create( static_fields=_HashableMapping(static_fields), leaves=_HashableMapping(leaves), metadata=metadata, + index_mapping=FrozenDict(index_mapping) + if index_mapping is not None + else None, ) def __nnx_repr__(self): @@ -283,63 +312,34 @@ def __nnx_repr__(self): ) yield reprlib.Attr('leaves', reprlib.PrettyMapping(self.leaves)) yield reprlib.Attr('metadata', self.metadata) - - def __penzai_repr__(self, path, subtree_renderer): - from penzai.treescope import repr_lib as pz_repr_lib # type: ignore[import-not-found,import-untyped] - return pz_repr_lib.render_object_constructor( - object_type=type(self), - attributes={ - 'type': self.type, - 'index': self.index, - 'attributes': self.attributes, - 'subgraphs': dict(self.subgraphs), - 'static_fields': dict(self.static_fields), - 'leaves': dict(self.leaves), - 'metadata': self.metadata, - }, - path=path, - subtree_renderer=subtree_renderer, + yield reprlib.Attr( + 'index_mapping', + reprlib.PrettyMapping(self.index_mapping) + if self.index_mapping is not None + else None, ) - -@dataclasses.dataclass(frozen=True, repr=False) -class GraphDef(tp.Generic[Node], reprlib.Representable): - """A dataclass that denotes the tree structure of a - :class:`Module`. 
A ``GraphDef`` can be generated by either - calling :func:`split` or :func:`graphdef` on the :class:`Module`.""" - - nodedef: NodeDef[Node] - index_mapping: dict[Index, Index] | None - - def __nnx_repr__(self): - yield reprlib.Object(type=type(self)) - - yield reprlib.Attr('nodedef', self.nodedef) - yield reprlib.Attr('index_mapping', self.index_mapping) - def __penzai_repr__(self, path, subtree_renderer): from penzai.treescope import repr_lib as pz_repr_lib # type: ignore[import-not-found,import-untyped] + return pz_repr_lib.render_object_constructor( - object_type=type(self), - attributes={ - 'nodedef': self.nodedef, - 'index_mapping': self.index_mapping, - }, - path=path, - subtree_renderer=subtree_renderer, + object_type=type(self), + attributes={ + 'type': self.type, + 'index': self.index, + 'attributes': self.attributes, + 'subgraphs': dict(self.subgraphs), + 'static_fields': dict(self.static_fields), + 'leaves': dict(self.leaves), + 'metadata': self.metadata, + 'index_mapping': dict(self.index_mapping) + if self.index_mapping is not None + else None, + }, + path=path, + subtree_renderer=subtree_renderer, ) - def __deepcopy__(self, memo=None): - nodedef = deepcopy(self.nodedef, memo) - index_mapping = deepcopy(self.index_mapping, memo) - return GraphDef(nodedef, index_mapping) - - def __hash__(self): - return hash(self.nodedef) - - def __eq__(self, other): - return isinstance(other, GraphDef) and self.nodedef == other.nodedef - def apply( self, state: GraphState, *states: GraphState ) -> ApplyCaller[tuple[GraphDef[Node], GraphState]]: @@ -351,87 +351,71 @@ def _apply( module = merge(self, state, *states) fn = accessor(module) out = fn(*args, **kwargs) - return out, flatten(module)[:2] + return out, flatten(module) return CallableProxy(_apply, accessor) # type: ignore -def _graphdef_flatten(graphdef: GraphDef[Node]): - # refmap is opaque, we don't propagate it - static = (graphdef.nodedef, graphdef.index_mapping) - return (), static - - -def _graphdef_unflatten( - static: tuple[NodeDef[Node], dict[Index, Index] | None], _nodes: tuple[()] -): - nodedef, index_mapping = static - return GraphDef(nodedef, index_mapping) -jax.tree_util.register_pytree_node( - GraphDef, - _graphdef_flatten, - _graphdef_unflatten, -) +jax.tree_util.register_static(NodeDef) GraphDefState = tuple[GraphDef[A], GraphState] def flatten( - x: Node, - /, - *, - idxmap: dict[Index, tp.Any] | None = None, -) -> tuple[GraphDef[Node], GraphState, RefMap[tp.Any, Index]]: - refmap = RefMap[tp.Any, Index]() + node: Node, /, ref_index: RefMap[tp.Any, Index] | None = None +) -> tuple[GraphDef[Node], GraphState]: + """Flattens a graph node into a (graphdef, state) pair. + + Args: + node: A graph node. + ref_index: A mapping from nodes to indexes, defaults to None. If not provided, a new + empty dictionary is created. This argument can be used to flatten a sequence of graph + nodes that share references. 
+ """ + if ref_index is None: + ref_index = RefMap() flat_state: dict[PathParts, StateLeaf] = {} - nodedef = _graph_flatten((), refmap, flat_state, x) - assert not isinstance(nodedef, int) - if idxmap is not None: - index_to_index = compose_mapping(idxmap, refmap) - else: - index_to_index = None - graphdef = GraphDef(nodedef, index_to_index) - return graphdef, GraphState.from_flat_path(flat_state), refmap + graphdef = _graph_flatten((), ref_index, flat_state, node) + return graphdef, GraphState.from_flat_path(flat_state) def _graph_flatten( path: PathParts, - refmap: RefMap[tp.Any, Index], + ref_index: RefMap[tp.Any, Index], flat_state: dict[PathParts, StateLeaf], node: Node, -) -> NodeDef[Node] | int: +) -> GraphDef[Node] | NodeRef: if not is_node(node): raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.') - if node in refmap: - return refmap[node] + if node in ref_index: + return NodeRef(type(node), ref_index[node]) node_impl = get_node_impl(node) # only cache graph nodes if isinstance(node_impl, GraphNodeImpl): - index = len(refmap) - refmap[node] = index + index = len(ref_index) + ref_index[node] = index else: index = -1 - subgraphs: list[tuple[Key, tp.Union[NodeDef[Node], Index]]] = [] + subgraphs: list[tuple[Key, GraphDef[Node] | NodeRef]] = [] static_fields: list[tuple[Key, tp.Any]] = [] - leaves: list[tuple[Key, Index | None]] = [] + leaves: list[tuple[Key, NodeRef | None]] = [] values, metadata = node_impl.flatten(node) for key, value in values: if is_node(value): - nodedef = _graph_flatten((*path, key), refmap, flat_state, value) + nodedef = _graph_flatten((*path, key), ref_index, flat_state, value) subgraphs.append((key, nodedef)) elif isinstance(value, Variable): - if value in refmap: - leaves.append((key, refmap[value])) + if value in ref_index: + leaves.append((key, NodeRef(type(value), ref_index[value]))) else: flat_state[(*path, key)] = value.to_state() - variable_index = refmap[value] = len(refmap) - leaves.append((key, variable_index)) + variable_index = ref_index[value] = len(ref_index) + leaves.append((key, NodeRef(type(value), variable_index))) elif is_state_leaf(value): flat_state[(*path, key)] = value leaves.append((key, None)) @@ -446,6 +430,7 @@ def _graph_flatten( static_fields=static_fields, leaves=leaves, metadata=metadata, + index_mapping=None, ) return nodedef @@ -455,53 +440,59 @@ def unflatten( state: GraphState, /, *, - idxmap: dict[Index, tp.Any] | None = None, -) -> tuple[Node, dict[Index, tp.Any]]: + index_ref: dict[Index, tp.Any] | None = None, + index_ref_cache: dict[Index, tp.Any] | None = None, +) -> Node: """Unflattens a graphdef into a node with the given state. Args: - graphdef: A NodeDef instance. + graphdef: A GraphDef instance. state: A State instance. - ref_cache: A mapping from indexes to existing nodes that can be reused. + index_ref: A mapping from indexes to nodes references found during the graph + traversal, defaults to None. If not provided, a new empty dictionary is + created. This argument can be used to unflatten a sequence of (graphdef, state) + pairs that share the same index space. + index_ref_cache: A mapping from indexes to existing nodes that can be reused. When an reference is reused, ``GraphNodeImpl.clear`` is called to leave the object in an empty state and then filled by the unflatten process, as a result existing graph nodes are mutated to have the new content/topology specified by the graphdef. 
""" - index_to_ref: dict[Index, tp.Any] = {} + if index_ref is None: + index_ref = {} + assert isinstance(graphdef, (NodeDef, NodeRef)) node = _graph_unflatten( - graphdef.nodedef, state.raw_mapping, index_to_ref, idxmap + graphdef, state.raw_mapping, index_ref, index_ref_cache ) - return node, index_to_ref - + return node def _graph_unflatten( - nodedef: tp.Union[NodeDef[Node], int], + nodedef: NodeDef[Node] | NodeRef[Node], state: tp.Mapping[Key, StateLeaf | tp.Mapping[Key, tp.Any]], - index_to_ref: dict[Index, tp.Any], - idxmap: dict[Index, tp.Any] | None, + index_ref: dict[Index, tp.Any], + index_ref_cache: dict[Index, tp.Any] | None, ) -> Node: """Recursive helper for graph_unflatten. Args: - nodedef: A NodeDef instance or an index to a node in the cache. + nodedef: A GraphDef instance or an index to a node in the cache. state: A mapping from attribute names to variables or subgraphs. index_to_ref: A mapping from indexes to nodes that have been traversed. If a node is already in the cache, it won't be traversed again. - ref_cache: A mapping from indexes to existing nodes that can be reused. + index_ref_cache: A mapping from indexes to existing nodes that can be reused. When an reference is reused, ``GraphNodeImpl.clear`` is called to leave the object in an empty state and then filled by the unflatten process, as a result existing graph nodes are mutated to have the new content/topology specified by the nodedef. """ - if isinstance(nodedef, int): - return index_to_ref[nodedef] + if isinstance(nodedef, NodeRef): + return index_ref[nodedef.index] if not is_node_type(nodedef.type): raise RuntimeError(f'Unsupported type: {nodedef.type}, this is a bug.') - if nodedef.index in index_to_ref: - raise RuntimeError(f'NodeDef index {nodedef.index} already used.') + if nodedef.index in index_ref: + raise RuntimeError(f'GraphDef index {nodedef.index} already used.') node_impl = get_node_impl_for_type(nodedef.type) @@ -524,9 +515,9 @@ def _get_children(): elif key in nodedef.subgraphs: # if the key is a subgraph we create an empty node subgraphdef = nodedef.subgraphs[key] - if isinstance(subgraphdef, int): + if isinstance(subgraphdef, NodeRef): # subgraph exists, take it from the cache - children[key] = index_to_ref[subgraphdef] + children[key] = index_ref[subgraphdef.index] else: # create a node from an empty state, reasoning: # * its a node with no state @@ -534,13 +525,13 @@ def _get_children(): # created nodes substate = {} children[key] = _graph_unflatten( - subgraphdef, substate, index_to_ref, idxmap + subgraphdef, substate, index_ref, index_ref_cache ) elif key in nodedef.leaves: - leaf_index = nodedef.leaves[key] - if leaf_index is not None and leaf_index in index_to_ref: + noderef = nodedef.leaves[key] + if noderef is not None and noderef.index in index_ref: # variable exists, take it from the cache - children[key] = index_to_ref[leaf_index] + children[key] = index_ref[noderef.index] else: # key for a variable is missing, raise an error raise ValueError( @@ -564,29 +555,29 @@ def _get_children(): assert isinstance(value, dict) subgraphdef = nodedef.subgraphs[key] - if isinstance(subgraphdef, int): - children[key] = index_to_ref[subgraphdef] + if isinstance(subgraphdef, NodeRef): + children[key] = index_ref[subgraphdef.index] else: children[key] = _graph_unflatten( - subgraphdef, value, index_to_ref, idxmap + subgraphdef, value, index_ref, index_ref_cache ) elif key in nodedef.leaves: if not is_state_leaf(value): raise ValueError(f'Expected a leaf for {key!r}, but got {value!r}') - 
leaf_index = nodedef.leaves[key] + noderef = nodedef.leaves[key] - if leaf_index is None: + if noderef is None: # if the leaf is None, it means that the value was originally # a non-VariableState leaf, however we allow providing a # VariableState presumbly created by modifying the State if isinstance(value, VariableState): value = value.to_variable() children[key] = value - elif leaf_index in index_to_ref: + elif noderef.index in index_ref: # add an existing variable - children[key] = index_to_ref[leaf_index] + children[key] = index_ref[noderef.index] else: # its a unseen variable, create a new one if not isinstance(value, VariableState): @@ -595,8 +586,8 @@ def _get_children(): ) # when idxmap is present, check if the Varable exists there # and update existing variables if it does - if idxmap is not None and leaf_index in idxmap: - variable = idxmap[leaf_index] + if index_ref_cache is not None and noderef.index in index_ref_cache: + variable = index_ref_cache[noderef.index] if not isinstance(variable, Variable): raise ValueError( f'Expected a Variable type for {key!r}, but got {type(variable)}.' @@ -606,7 +597,7 @@ def _get_children(): assert isinstance(value, VariableState) variable = value.to_variable() children[key] = variable - index_to_ref[leaf_index] = variable + index_ref[noderef.index] = variable else: raise RuntimeError(f'Unknown key: {key!r}, this is a bug.') @@ -615,8 +606,8 @@ def _get_children(): if isinstance(node_impl, GraphNodeImpl): # we create an empty node first and add it to the index # this avoids infinite recursion when there is a reference cycle - if idxmap is not None and nodedef.index in idxmap: - node = idxmap[nodedef.index] + if index_ref_cache is not None and nodedef.index in index_ref_cache: + node = index_ref_cache[nodedef.index] if type(node) != nodedef.type: raise ValueError( f'Expected a node of type {nodedef.type} for index ' @@ -625,7 +616,7 @@ def _get_children(): node_impl.clear(node) else: node = node_impl.create_empty(nodedef.metadata) - index_to_ref[nodedef.index] = node + index_ref[nodedef.index] = node children = _get_children() node_impl.init(node, tuple(children.items())) else: @@ -858,10 +849,124 @@ def _graph_update_static( # UpdateContext # -------------------------------------------------------- +@dataclasses.dataclass +class GraphContext(threading.local): + update_context_stacks: dict[str, list[UpdateContext]] = dataclasses.field( + default_factory=dict + ) + ref_index_stack: list[SplitContext] = dataclasses.field(default_factory=list) + index_ref_stack: list[MergeContext] = dataclasses.field(default_factory=list) + + +GRAPH_CONTEXT = GraphContext() -# -------------------------------------------------------- -# UpdateContext -# -------------------------------------------------------- + +@dataclasses.dataclass +class SplitContext: + ctxtag: str | None + ref_index: RefMap[tp.Any, Index] + + @tp.overload + def split(self, graph_node: A, /) -> tuple[GraphDef[A], GraphState]: ... + @tp.overload + def split( + self, graph_node: A, first: filterlib.Filter, / + ) -> tuple[GraphDef[A], GraphState]: ... + @tp.overload + def split( + self, + graph_node: A, + first: filterlib.Filter, + second: filterlib.Filter, + /, + *filters: filterlib.Filter, + ) -> tuple[GraphDef[A], GraphState, tpe.Unpack[tuple[GraphState, ...]]]: ... 
+ def split( + self, node: A, *filters: filterlib.Filter + ) -> tuple[GraphDef[A], tpe.Unpack[tuple[GraphState, ...]]]: + ctx = ( + current_update_context(self.ctxtag) if self.ctxtag is not None else None + ) + graphdef, state = flatten(node, self.ref_index) + states = _split_state(state, filters) + if ctx is not None: + if ctx.index_ref is not None and isinstance(graphdef, NodeDef): + index_to_index = compose_mapping(ctx.index_ref, self.ref_index) + graphdef = dataclasses.replace( + graphdef, index_mapping=FrozenDict(index_to_index) + ) + + return graphdef, *states + + +@contextlib.contextmanager +def split_context(ctxtag: str | None = None): + index_ref: RefMap[tp.Any, Index] = RefMap() + flatten_ctx = SplitContext(ctxtag, index_ref) + GRAPH_CONTEXT.ref_index_stack.append(flatten_ctx) + + try: + yield flatten_ctx + finally: + GRAPH_CONTEXT.ref_index_stack.pop() + if ctxtag is not None: + ctx = current_update_context(ctxtag) + ctx.flatten_end(index_ref) + del flatten_ctx.ref_index + del flatten_ctx.ctxtag + + +@dataclasses.dataclass +class MergeContext: + ctxtag: str | None + index_ref: dict[Index, tp.Any] + + def merge( + self, graphdef: GraphDef[A], state: GraphState, /, *states: GraphState + ) -> A: + ctx = ( + current_update_context(self.ctxtag) if self.ctxtag is not None else None + ) + if ( + ctx is not None + and isinstance(graphdef, NodeDef) + and graphdef.index_mapping is not None + ): + # outer merge (4), create index_ref_cache + assert ctx.ref_index is not None + index_ref_cache = compose_mapping_reversed( + ctx.ref_index, graphdef.index_mapping + ) + else: + # inner merge (2) + index_ref_cache = None + + state = State.merge(state, *states) + node = unflatten( + graphdef, + state, + index_ref=self.index_ref, + index_ref_cache=index_ref_cache, + ) + return node + + +@contextlib.contextmanager +def merge_context(ctxtag: str | None = None): + index_ref: dict[Index, tp.Any] = {} + + unflatten_ctx = MergeContext(ctxtag, index_ref) + GRAPH_CONTEXT.index_ref_stack.append(unflatten_ctx) + + try: + yield unflatten_ctx + finally: + GRAPH_CONTEXT.index_ref_stack.pop() + if ctxtag is not None: + ctx = current_update_context(ctxtag) + ctx.unflatten_end(index_ref) + del unflatten_ctx.index_ref + del unflatten_ctx.ctxtag @dataclasses.dataclass @@ -869,8 +974,8 @@ class UpdateContext: """A context manager for handling complex state updates.""" tag: str - refmap: RefMap[tp.Any, Index] | None - idxmap: dict[Index, tp.Any] | None + ref_index: RefMap[tp.Any, Index] | None + index_ref: dict[Index, tp.Any] | None # define hash and eq to make this an opaque object def __hash__(self): @@ -879,14 +984,23 @@ def __hash__(self): def __eq__(self, other): return isinstance(other, UpdateContext) + def flatten_end(self, ref_index: RefMap[tp.Any, Index]): + if self.ref_index is None: + # outer split (1), store the references + self.ref_index = ref_index + else: + # inner split (3), clear index_ref + self.index_ref = None + + def unflatten_end(self, index_ref: dict[Index, tp.Any]): + self.index_ref = index_ref + @tp.overload def split(self, graph_node: A, /) -> tuple[GraphDef[A], GraphState]: ... - @tp.overload def split( self, graph_node: A, first: filterlib.Filter, / ) -> tuple[GraphDef[A], GraphState]: ... - @tp.overload def split( self, @@ -896,7 +1010,6 @@ def split( /, *filters: filterlib.Filter, ) -> tuple[GraphDef[A], GraphState, tpe.Unpack[tuple[GraphState, ...]]]: ... 
- def split( self, node: A, *filters: filterlib.Filter ) -> tuple[GraphDef[A], GraphState, tpe.Unpack[tuple[GraphState, ...]]]: @@ -908,7 +1021,7 @@ def split( Example usage:: - >>> from flax.experimental import nnx + >>> from flax import nnx >>> import jax, jax.numpy as jnp ... >>> class Foo(nnx.Module): @@ -963,24 +1076,19 @@ def split( :class:`GraphDef` and one or more :class:`State`'s equal to the number of filters passed. If no filters are passed, a single :class:`State` is returned. """ - graphdef, state, refmap = flatten(node, idxmap=self.idxmap) - - states: GraphState | tuple[GraphState, ...] - if len(filters) == 0: - states = (state,) - elif len(filters) == 1: - states = (state.split(filters[0]),) - else: - states = state.split(filters[0], filters[1], *filters[2:]) + ref_index: RefMap[tp.Any, Index] = RefMap() + graphdef, state = flatten(node, ref_index) + states = _split_state(state, filters) + + if self.index_ref is not None and isinstance(graphdef, NodeDef): + index_to_index = compose_mapping(self.index_ref, ref_index) + graphdef = dataclasses.replace( + graphdef, index_mapping=FrozenDict(index_to_index) + ) - if self.refmap is None: - self.refmap = refmap + self.flatten_end(ref_index) - if graphdef.index_mapping is not None: - # clear idxmap to remove any references to tracers - self.idxmap = None - - return graphdef, states[0], *states[1:] + return graphdef, *states def merge( self, @@ -989,22 +1097,30 @@ def merge( *states: GraphState, ) -> A: """merge""" - if self.refmap is None: - raise ValueError('Cannot update a graphdef without refmap.') - - if states: - state = GraphState.merge(state, *states) + if not isinstance(graphdef, NodeDef): + raise ValueError( + f'Expected a NodeDef instance, but got {type(graphdef)}.' + ) + if self.ref_index is None: + raise ValueError('Cannot merge without ref_index.') - if graphdef.index_mapping is None: - node, self.idxmap = unflatten(graphdef, state) - else: - index_to_ref = compose_mapping_reversed( - self.refmap, graphdef.index_mapping + if graphdef.index_mapping is not None: + # outer merge (4), create index_ref_cache + assert self.ref_index is not None + index_ref_cache = compose_mapping_reversed( + self.ref_index, graphdef.index_mapping ) - node, _idxmap = unflatten(graphdef, state, idxmap=index_to_ref) - # clear references - self.refmap = None - self.idxmap = None + else: + # inner merge (2) + index_ref_cache = None + + state = State.merge(state, *states) + index_ref: dict[Index, tp.Any] = {} + node = unflatten( + graphdef, state, index_ref=index_ref, index_ref_cache=index_ref_cache + ) + + self.unflatten_end(index_ref) return node @@ -1033,8 +1149,8 @@ def __exit__(self, *args): ctx = stack.pop() # clear references - ctx.refmap = None - ctx.idxmap = None + del ctx.ref_index + del ctx.index_ref if not stack: del GRAPH_CONTEXT.update_context_stacks[self.tag] @@ -1154,19 +1270,24 @@ def current_update_context(tag: str) -> UpdateContext: # Functional API # -------------------------------------------------------- +def _split_state( + state: GraphState, + filters: tuple[filterlib.Filter, ...], +) -> tuple[GraphState, tpe.Unpack[tuple[GraphState, ...]]]: + if not filters: + return (state,) + states = state.split(*filters) + if isinstance(states, State): + return (states,) + assert len(states) > 0 + return states @tp.overload def split(graph_node: A, /) -> tuple[GraphDef[A], GraphState]: ... 
- - @tp.overload def split( - graph_node: A, - first: filterlib.Filter, - /, + graph_node: A, first: filterlib.Filter, / ) -> tuple[GraphDef[A], GraphState]: ... - - @tp.overload def split( graph_node: A, @@ -1175,8 +1296,6 @@ def split( /, *filters: filterlib.Filter, ) -> tuple[GraphDef[A], GraphState, tpe.Unpack[tuple[GraphState, ...]]]: ... - - def split( node: A, *filters: filterlib.Filter ) -> tuple[GraphDef[A], GraphState, tpe.Unpack[tuple[GraphState, ...]]]: @@ -1188,7 +1307,7 @@ def split( Example usage:: - >>> from flax.experimental import nnx + >>> from flax import nnx >>> import jax, jax.numpy as jnp ... >>> class Foo(nnx.Module): @@ -1248,17 +1367,9 @@ def split( ``GraphDef`` and one or more ``States`` equal to the number of filters passed. If no filters are passed, a single ``State`` is returned. """ - graphdef, state, _ = flatten(node) - - states: GraphState | tuple[GraphState, ...] - if len(filters) == 0: - states = (state,) - elif len(filters) == 1: - states = (state.split(filters[0]),) - else: - states = state.split(filters[0], filters[1], *filters[2:]) - - return graphdef, states[0], *states[1:] + graphdef, state = flatten(node) + states = _split_state(state, filters) + return graphdef, *states def merge( @@ -1274,7 +1385,7 @@ def merge( Example usage:: - >>> from flax.experimental import nnx + >>> from flax import nnx >>> import jax, jax.numpy as jnp ... >>> class Foo(nnx.Module): @@ -1302,10 +1413,8 @@ def merge( Returns: The merged :class:`Module`. """ - if states: - state = GraphState.merge(state, *states) - - node, _ = unflatten(graphdef, state) + state = GraphState.merge(state, *states) + node = unflatten(graphdef, state) return node @@ -1392,7 +1501,7 @@ def state( Returns: One or more :class:`State` mappings. """ - state = flatten(node)[1] + _, state = flatten(node) states: GraphState | tuple[GraphState, ...] if len(filters) == 0: @@ -1421,7 +1530,7 @@ def graphdef(node: tp.Any, /) -> GraphDef[tp.Any]: Returns: The :class:`GraphDef` of the :class:`Module` object. 
""" - graphdef, _, _ = flatten(node) + graphdef, _ = flatten(node) return graphdef diff --git a/flax/nnx/nnx/nn/linear.py b/flax/nnx/nnx/nn/linear.py index ad7f646b61..3acd7afd0f 100644 --- a/flax/nnx/nnx/nn/linear.py +++ b/flax/nnx/nnx/nn/linear.py @@ -205,7 +205,7 @@ def kernel_init_wrap(rng, shape, dtype): * np.prod(shape[n_batch_axis : n_in_features + n_batch_axis]), np.prod(shape[-n_out_features:]), ) - flat_shape = jax.tree_util.tree_map(int, flat_shape) + flat_shape = jax.tree.map(int, flat_shape) kernel = self.kernel_init(rng, flat_shape, dtype) if isinstance(kernel, variables.VariableMetadata): kernel.raw_value = jnp.reshape(kernel.raw_value, shape) diff --git a/flax/nnx/nnx/object.py b/flax/nnx/nnx/object.py index 41cf8fee28..676a61a748 100644 --- a/flax/nnx/nnx/object.py +++ b/flax/nnx/nnx/object.py @@ -122,13 +122,13 @@ def __setattr__(self, name: str, value: Any) -> None: def _setattr(self, name: str, value: tp.Any) -> None: self.check_valid_context( - f"Cannot mutate '{type(self).__name__}' from different trace level" + lambda: f"Cannot mutate '{type(self).__name__}' from different trace level" ) object.__setattr__(self, name, value) - def check_valid_context(self, error_msg: str) -> None: + def check_valid_context(self, error_msg: tp.Callable[[], str]) -> None: if not self._object__state.trace_state.is_valid(): - raise errors.TraceContextError(error_msg) + raise errors.TraceContextError(error_msg()) def __deepcopy__(self: G, memo=None) -> G: graphdef, state = graph.split(self) diff --git a/flax/nnx/nnx/rnglib.py b/flax/nnx/nnx/rnglib.py index 3097da2a65..3d903ba5e3 100644 --- a/flax/nnx/nnx/rnglib.py +++ b/flax/nnx/nnx/rnglib.py @@ -28,11 +28,13 @@ from __future__ import annotations import dataclasses +import functools import typing as tp import jax import jax.numpy as jnp +from flax import struct from flax.nnx.nnx import graph from flax.nnx.nnx.state import State from flax.nnx.nnx.variables import Variable @@ -40,6 +42,7 @@ from flax.nnx.nnx.filterlib import All from flax.nnx.nnx.object import Object +F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any]) Counts = list[int] AxesValue = tp.Union[int, None] SplitPattern = tp.Union[AxesValue, tuple[AxesValue, ...]] @@ -53,15 +56,13 @@ class Missing: class RngState(Variable[jax.Array]): - pass + tag: str -class RngCount(RngState): - tag: str +class RngCount(RngState): ... -class RngKey(RngState): - tag: str +class RngKey(RngState): ... 
NotKey = filterlib.All(RngState, filterlib.Not(RngKey)) @@ -84,7 +85,7 @@ def __post_init__(self): def __call__(self) -> jax.Array: self.check_valid_context( - 'Cannot call RngStream from a different trace level' + lambda: 'Cannot call RngStream from a different trace level' ) key = jax.random.fold_in(self.key.value, self.count.value) self.count.value += 1 @@ -200,10 +201,11 @@ def __init__( rngs['default'] = default for name, value in rngs.items(): + key = jax.random.key(value) if isinstance(value, int) else value stream = RngStream( tag=name, - key=jax.random.key(value) if isinstance(value, int) else value, - count=jnp.array(0, dtype=jnp.uint32), + key=key, + count=jnp.zeros(key.shape, dtype=jnp.uint32), ) setattr(self, name, stream) @@ -277,20 +279,171 @@ def split_key(key: tp.Any) -> jax.Array: return ForkStates(split_keys, split_counts, broadcast_keys, broadcast_counts) +StreamBackup = ( + tuple[RngStream, jax.Array, jax.Array] | tuple[RngStream, jax.Array] +) + +class SplitBackups(struct.PyTreeNode, tp.Iterable[StreamBackup]): + backups: list[StreamBackup] + + def __iter__(self) -> tp.Iterator[StreamBackup]: + return iter(self.backups) + + def __enter__(self): + return self + + def __exit__(self, *args): + restore_rngs(self) + + +@tp.overload +def split_rngs( + node: tp.Any, + /, + *, + splits: int | tuple[int, ...], + only: filterlib.Filter = ..., +) -> SplitBackups: ... +@tp.overload +def split_rngs( + *, + splits: int | tuple[int, ...], + only: filterlib.Filter = ..., +) -> tp.Callable[[F], F]: ... +def split_rngs( + node: tp.Any = MISSING, + /, + *, + splits: int | tuple[int, ...], + only: filterlib.Filter = ..., +) -> SplitBackups | tp.Callable[[F], F]: + """Splits the (nested) Rng states of the given node. + + Args: + node: the base node containing the rng states to split. + splits: an integer or tuple of integers specifying the + shape of the split rng keys. + only: a Filter selecting which rng states to split. + + Returns: + A SplitBackups iterable if ``node`` is provided, otherwise a + decorator that splits the rng states of the inputs to the + decorated function. + + Example:: + + >>> from flax import nnx + ... + >>> rngs = nnx.Rngs(params=0, dropout=1) + >>> _ = nnx.split_rngs(rngs, splits=5) + >>> rngs.params.key.shape, rngs.dropout.key.shape + ((5,), (5,)) + + >>> rngs = nnx.Rngs(params=0, dropout=1) + >>> _ = nnx.split_rngs(rngs, splits=(2, 5)) + >>> rngs.params.key.shape, rngs.dropout.key.shape + ((2, 5), (2, 5)) + + + >>> rngs = nnx.Rngs(params=0, dropout=1) + >>> _ = nnx.split_rngs(rngs, splits=5, only='params') + >>> rngs.params.key.shape, rngs.dropout.key.shape + ((5,), ()) + + Once split, random state can be used with transforms like :func:`nnx.vmap`:: + + >>> class Model(nnx.Module): + ... def __init__(self, rngs): + ... self.linear = nnx.Linear(2, 3, rngs=rngs) + ... self.dropout = nnx.Dropout(0.5, rngs=rngs) + ... + >>> rngs = nnx.Rngs(params=0, dropout=1) + >>> _ = nnx.split_rngs(rngs, splits=5, only='params') + ... + >>> state_axes = nnx.StateAxes({(nnx.Param, 'params'): 0, ...: None}) + ... + >>> @nnx.vmap(in_axes=(state_axes,), out_axes=state_axes) + ... def create_model(rngs): + ... return Model(rngs) + ... + >>> model = create_model(rngs) + >>> model.dropout.rngs.params.key.shape + (5,) + + ``split_rngs`` returns a SplitBackups object that can be used to restore the + original unsplit rng states using :func:`nnx.restore_rngs`, this is useful + when you only want to split the rng states temporarily:: + + >>> rngs = nnx.Rngs(params=0, dropout=1) + ... 
+ >>> backups = nnx.split_rngs(rngs, splits=5, only='params') + >>> model = create_model(rngs) + >>> nnx.restore_rngs(backups) + ... + >>> model.dropout.rngs.params.key.shape + () + + SplitBackups can also be used as a context manager to automatically restore + the rng states when exiting the context:: + + >>> rngs = nnx.Rngs(params=0, dropout=1) + ... + >>> with nnx.split_rngs(rngs, splits=5, only='params'): + ... model = create_model(rngs) + ... + >>> model.dropout.rngs.params.key.shape + () + + >>> state_axes = nnx.StateAxes({(nnx.Param, 'params'): 0, ...: None}) + ... + >>> @nnx.split_rngs(splits=5, only='params') + ... @nnx.vmap(in_axes=(state_axes,), out_axes=state_axes) + ... def create_model(rngs): + ... return Model(rngs) + ... + >>> rngs = nnx.Rngs(params=0, dropout=1) + >>> model = create_model(rngs) + >>> model.dropout.rngs.params.key.shape + () + + + """ + if isinstance(node, Missing): + + def split_rngs_decorator(f: F) -> F: + @functools.wraps(f) + def split_rngs_wrapper(*args, **kwargs): + with split_rngs((args, kwargs), splits=splits, only=only): + return f(*args, **kwargs) + + return tp.cast(F, split_rngs_wrapper) + + return split_rngs_decorator + + predicate = filterlib.to_predicate(only) + backups: list[StreamBackup] = [] + for path, stream in graph.iter_graph(node): + if ( + isinstance(stream, RngStream) + and predicate((*path, 'key'), stream.key) + and predicate((*path, 'count'), stream.count) + ): + key = stream() + backups.append((stream, stream.key.value, stream.count.value)) + stream.key.value = jax.random.split(key, splits) + stream.count.value = jnp.zeros(stream.key.value.shape, dtype=jnp.uint32) + + return SplitBackups(backups) + def backup_keys(node: tp.Any, /): - backups: list[tuple[RngStream, jax.Array]] = [] + backups: list[StreamBackup] = [] for _, stream in graph.iter_graph(node): if isinstance(stream, RngStream): backups.append((stream, stream.key.value)) return backups -def restore_keys(backups: list[tuple[RngStream, jax.Array]], /): - for stream, key in backups: - stream.key.value = key - - def reseed(node, /, **stream_keys: RngValue): """Update the keys of the specified RNG streams with new keys. 
@@ -340,3 +493,9 @@ def reseed(node, /, **stream_keys: RngValue): key = jax.random.key(key) stream.key.value = key stream.count.value = jnp.array(0, dtype=jnp.uint32) +def restore_rngs(backups: tp.Iterable[StreamBackup], /): + for backup in backups: + stream = backup[0] + stream.key.value = backup[1] # key + if len(backup) == 3: + stream.count.value = backup[2] # count diff --git a/flax/nnx/nnx/spmd.py b/flax/nnx/nnx/spmd.py index fd7067c0ae..075186e652 100644 --- a/flax/nnx/nnx/spmd.py +++ b/flax/nnx/nnx/spmd.py @@ -20,7 +20,6 @@ from jax.sharding import Mesh, PartitionSpec from flax.nnx.nnx import variables -from flax.nnx.nnx.state import State from flax.typing import ( Array, ArrayPytree, # pylint: disable=invalid-name @@ -33,50 +32,41 @@ PARTITION_NAME = 'partition_name' -@tp.runtime_checkable -class HasSharding(tp.Protocol): - sharding: tp.Optional[Sharding] - - -def add_axis( - state: State, index: int, params: tp.Mapping[tp.Any, tp.Any] -) -> State: +def add_axis(tree: A, index: int, params: tp.Mapping[tp.Any, tp.Any]) -> A: axis_name = _get_partition_name(params) def _add_axis(x: tp.Any): if isinstance(x, variables.VariableState): - if isinstance(x, HasSharding) and x.sharding is not None: - sharding = list(x.sharding) + if hasattr(x, 'sharding') and x.sharding is not None: + sharding: list[str | None] = list(x.sharding) while len(sharding) < index: sharding.append(None) sharding.insert(index, axis_name) - x.sharding = tuple(sharding) + x.sharding = tuple(sharding) # type: ignore x.add_axis(axis_name, index) return x - return jax.tree_util.tree_map( - _add_axis, state, is_leaf=lambda x: isinstance(x, variables.VariableState) + return jax.tree.map( + _add_axis, tree, is_leaf=lambda x: isinstance(x, variables.VariableState) ) -def remove_axis( - state: State, index: int, params: tp.Mapping[tp.Any, tp.Any] -) -> State: +def remove_axis(tree: A, index: int, params: tp.Mapping[tp.Any, tp.Any]) -> A: axis_name = _get_partition_name(params) def _remove_axis(x: tp.Any): if isinstance(x, variables.VariableState): - if isinstance(x, HasSharding) and x.sharding is not None: + if hasattr(x, 'sharding') and x.sharding is not None: sharding = list(x.sharding) assert sharding.pop(index) == axis_name x.sharding = tuple(sharding) x.remove_axis(axis_name, index) return x - return jax.tree_util.tree_map( + return jax.tree.map( _remove_axis, - state, + tree, is_leaf=lambda x: isinstance(x, variables.VariableState), ) @@ -101,21 +91,21 @@ def _maybe_replicate(x): def f(x): if isinstance(x, (variables.VariableState, variables.Variable)): - if isinstance(x, HasSharding) and x.sharding: + if hasattr(x, 'sharding') and x.sharding: return x.replace(PartitionSpec(*x.sharding)) else: return x.replace(_maybe_replicate(x.value)) return _maybe_replicate(x) - return jax.tree_util.tree_map( + return jax.tree.map( f, tree, is_leaf=lambda x: isinstance(x, variables.VariableState) ) def get_named_sharding(tree: A, mesh: jax.sharding.Mesh) -> A: spec = get_partition_spec(tree) - sharding = jax.tree_util.tree_map( + sharding = jax.tree.map( lambda p: jax.sharding.NamedSharding(mesh, p), spec ) return sharding @@ -161,7 +151,7 @@ def with_sharding_constraint( if axis_resources is None: return x # Translate logical names to mesh assignments. 
- return jax.tree_util.tree_map( + return jax.tree.map( functools.partial(_with_sharding_constraint, mesh=mesh), x, axis_resources, diff --git a/flax/nnx/nnx/state.py b/flax/nnx/nnx/state.py index ed4f8c2caa..6453b5acd6 100644 --- a/flax/nnx/nnx/state.py +++ b/flax/nnx/nnx/state.py @@ -327,10 +327,10 @@ def merge(state: State[K, V], /, *states: State[K, V]) -> State[K, V]: Returns: The merged ``State``. """ - states = (state, *states) + if not states: + return state - if len(states) == 1: - return states[0] + states = (state, *states) new_state: FlatState[V] = {} diff --git a/flax/nnx/nnx/tracers.py b/flax/nnx/nnx/tracers.py index 897c98b8bd..a50a0ae867 100644 --- a/flax/nnx/nnx/tracers.py +++ b/flax/nnx/nnx/tracers.py @@ -23,14 +23,9 @@ from flax.nnx.nnx import reprlib -@tp.runtime_checkable -class Tracer(tp.Protocol): - _trace: jax.core.Trace - - -def get_top_trace(pytree: tp.Union[tp.Any, Tracer]) -> MainTrace: +def get_top_trace(pytree: tp.Any) -> MainTrace: """Returns the main top trace of a sequence of tracers.""" - if isinstance(pytree, Tracer): + if hasattr(pytree, '_trace'): return pytree._trace.main return jax.core.find_top_trace(jax.tree_util.tree_leaves(pytree)).main diff --git a/flax/nnx/nnx/transforms/__init__.py b/flax/nnx/nnx/transforms/__init__.py index af25fe62d9..1e2650788c 100644 --- a/flax/nnx/nnx/transforms/__init__.py +++ b/flax/nnx/nnx/transforms/__init__.py @@ -12,3 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. +from . import deprecated \ No newline at end of file diff --git a/flax/nnx/nnx/transforms/autodiff.py b/flax/nnx/nnx/transforms/autodiff.py new file mode 100644 index 0000000000..220d6efd3c --- /dev/null +++ b/flax/nnx/nnx/transforms/autodiff.py @@ -0,0 +1,782 @@ +from collections import deque +import dataclasses +import functools +import typing as tp + + +from flax import struct +from flax.nnx.nnx import ( + extract, + filterlib, + graph, + variables, +) +from flax.nnx.nnx.state import State +import jax +import jax.core +import jax.stages + +from flax.nnx.nnx.transforms import general +from flax.nnx.nnx.transforms.transforms import resolve_kwargs + + +A = tp.TypeVar('A') +# C = tp.TypeVar('C') +# B = tp.TypeVar('B') +F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any]) +# G = tp.TypeVar('G', bound=tp.Callable[..., tp.Any]) +# M = tp.TypeVar('M', bound=Module) +# MA = tp.TypeVar('MA', bound=Module) +# N = tp.TypeVar('N', bound=Module) +# StrInt = tp.TypeVar('StrInt', str, int) +AxisName = tp.Hashable +# Leaves = tp.List[Leaf] +# Index = int + + +class Missing: + pass + + +MISSING = Missing() + + +# ------------------------------- +# grad +# ------------------------------- + + +@dataclasses.dataclass(frozen=True) +class DiffState: + argnum: int + filter: filterlib.Filter + + +@dataclasses.dataclass(eq=False) +class GradFn: + f: tp.Callable[..., tp.Any] + has_aux: bool + + def __post_init__(self): + functools.update_wrapper(self, self.f) + + def __call__(self, *pure_args): + # rebuild diff_state from substates in args + nondiff_states: deque[State | None] = extract.get_broadcast_state('grad') + + def _grad_merge_fn( + ctx: graph.MergeContext, path, prefix, value: extract.TreeNode + ): + nondiff = nondiff_states.popleft() + if nondiff is None: + return ctx.merge(value.graphdef, value.state) + else: + return ctx.merge(value.graphdef, value.state, nondiff) + + args = extract.from_tree(pure_args, merge_fn=_grad_merge_fn, ctxtag='grad') + + out = self.f(*args) + + args_out = 
extract.clear_non_graph_nodes(args) + pure_args_out, pure_out = extract.to_tree((args_out, out), ctxtag='grad') + + if self.has_aux: + loss, pure_aux = pure_out + fn_out = (loss, (pure_args_out, pure_aux)) + else: + loss = pure_out + fn_out = (loss, pure_args_out) + + return fn_out + + +def _grad_general( + f: tp.Callable[..., tp.Any], + argnums: int | DiffState | tp.Sequence[int | DiffState], + has_aux: bool, + holomorphic: bool, + allow_int: bool, + reduce_axes: tp.Sequence[AxisName], + return_value: bool, +) -> tp.Callable[..., tp.Any]: + transform = jax.value_and_grad if return_value else jax.grad + + if isinstance(argnums, (int, DiffState)): + jax_argnums = argnums.argnum if isinstance(argnums, DiffState) else argnums + else: + jax_argnums = tuple( + x.argnum if isinstance(x, DiffState) else x for x in argnums + ) + + _argnums = (argnums,) if isinstance(argnums, (int, DiffState)) else argnums + index_filter: dict[int, DiffState] = {} + for argnum in _argnums: + index = argnum.argnum if isinstance(argnum, DiffState) else argnum + if index in index_filter: + raise ValueError(f'argnum {index} is repeated in argnums') + index_filter[index] = ( + dataclasses.replace(argnum, argnum=-1) + if isinstance(argnum, DiffState) + else DiffState(-1, variables.Param) + ) + + gradded_fn = transform( + GradFn(f, has_aux), + argnums=jax_argnums, + has_aux=True, + holomorphic=holomorphic, + allow_int=allow_int, + reduce_axes=reduce_axes, + ) + + @graph.update_context('grad') + def grad_wrapper(*args, **kwargs): + args = resolve_kwargs(f, args, kwargs) + del kwargs + nondiff_states: deque[State | None] = deque() + + def _grad_split_fn( + ctx: graph.SplitContext, path, prefix: DiffState | None, value + ): + if prefix is None: + nondiff_states.append(None) + return extract.TreeNode.from_split(*ctx.split(value)) + else: + graphdef, diff, nondiff = ctx.split(value, prefix.filter, ...) + nondiff_states.append(nondiff) + return extract.TreeNode.from_split(graphdef, diff) + + arg_filters = tuple(index_filter.get(i) for i in range(len(args))) + pure_args = extract.to_tree( + args, prefix=arg_filters, split_fn=_grad_split_fn, ctxtag='grad' + ) + + with extract.broadcast_state('grad', nondiff_states): + fn_out = gradded_fn(*pure_args) + + def process_grads(grads): + return jax.tree.map( + lambda x: x.state if isinstance(x, extract.TreeNode) else x, + grads, + is_leaf=lambda x: isinstance(x, extract.TreeNode), + ) + + def process_out(pure_out: A, /) -> A: + return extract.from_tree(pure_out, ctxtag='grad') + + if return_value: + # unpack value_and_grad output + if has_aux: + (loss, (pure_args_out, pure_aux)), grads = fn_out + grads = process_grads(grads) + _args_out, aux = process_out((pure_args_out, pure_aux)) + return (loss, aux), grads + else: + (loss, pure_args_out), grads = fn_out + grads = process_grads(grads) + _args_out = process_out(pure_args_out) + return loss, grads + else: + # unpack grad output + if has_aux: + grads, (pure_args_out, pure_aux) = fn_out + grads = process_grads(grads) + _args_out, aux = process_out((pure_args_out, pure_aux)) + return grads, aux + else: + grads, pure_args_out = fn_out + grads = process_grads(grads) + _args_out = process_out(pure_args_out) + return grads + + return grad_wrapper + + +@tp.overload +def grad( + f: tp.Callable[..., tp.Any], + *, + argnums: int | DiffState | tp.Sequence[int | DiffState] = 0, + has_aux: bool = False, + holomorphic: bool = False, + allow_int: bool = False, + reduce_axes: tp.Sequence[AxisName] = (), +) -> tp.Callable[..., tp.Any]: ... 
+@tp.overload +def grad( + *, + argnums: int | DiffState | tp.Sequence[int | DiffState] = 0, + has_aux: bool = False, + holomorphic: bool = False, + allow_int: bool = False, + reduce_axes: tp.Sequence[AxisName] = (), +) -> tp.Callable[[tp.Callable[..., tp.Any]], tp.Callable[..., tp.Any]]: ... +def grad( + f: tp.Callable[..., tp.Any] | Missing = MISSING, + *, + argnums: int | DiffState | tp.Sequence[int | DiffState] = 0, + has_aux: bool = False, + holomorphic: bool = False, + allow_int: bool = False, + reduce_axes: tp.Sequence[AxisName] = (), +) -> ( + tp.Callable[..., tp.Any] + | tp.Callable[[tp.Callable[..., tp.Any]], tp.Callable[..., tp.Any]] +): + """Lifted version of ``jax.grad`` that can handle Modules / graph nodes as + arguments. + + The differentiable state of each graph node is defined by the `wrt` filter, + which by default is set to `nnx.Param`. Internally the ``State`` of + graph nodes is extracted, filtered according to `wrt` filter, and + passed to the underlying ``jax.grad`` function. The gradients + of graph nodes are of type ``State``. + + Example:: + + >>> from flax import nnx + >>> import jax.numpy as jnp + ... + >>> m = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) + >>> x = jnp.ones((1, 2)) + >>> y = jnp.ones((1, 3)) + ... + >>> loss_fn = lambda m, x, y: jnp.mean((m(x) - y) ** 2) + >>> grad_fn = nnx.grad(loss_fn) + ... + >>> grads = grad_fn(m, x, y) + >>> jax.tree.map(jnp.shape, grads) + State({ + 'bias': VariableState( + type=Param, + value=(3,) + ), + 'kernel': VariableState( + type=Param, + value=(2, 3) + ) + }) + + Args: + fun: Function to be differentiated. Its arguments at positions specified by + ``argnums`` should be arrays, scalars, graph nodes or standard Python + containers. Argument arrays in the positions specified by ``argnums`` must + be of inexact (i.e., floating-point or complex) type. It should return a + scalar (which includes arrays with shape ``()`` but not arrays with shape + ``(1,)`` etc.) + argnums: Optional, integer or sequence of integers. Specifies which + positional argument(s) to differentiate with respect to (default 0). + has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the + first element is considered the output of the mathematical function to be + differentiated and the second element is auxiliary data. Default False. + holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be + holomorphic. If True, inputs and outputs must be complex. Default False. + allow_int: Optional, bool. Whether to allow differentiating with + respect to integer valued inputs. The gradient of an integer input will + have a trivial vector-space dtype (float0). Default False. + reduce_axes: Optional, tuple of axis names. If an axis is listed here, and + ``fun`` implicitly broadcasts a value over that axis, the backward pass + will perform a ``psum`` of the corresponding gradient. Otherwise, the + gradient will be per-example over named axes. For example, if ``'batch'`` + is a named batch axis, ``grad(f, reduce_axes=('batch',))`` will create a + function that computes the total gradient while ``grad(f)`` will create + one that computes the per-example gradient. 
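+
+  ``DiffState`` can restrict differentiation to a subset of an argument's
+  state. A minimal sketch, reusing ``m``, ``x``, ``y`` and ``loss_fn`` from
+  the example above (``PathContains`` is just one possible filter choice)::
+
+    >>> diff_state = nnx.DiffState(0, nnx.PathContains('kernel'))
+    >>> grads = nnx.grad(loss_fn, argnums=diff_state)(m, x, y)
+    >>> 'kernel' in grads and 'bias' not in grads
+    True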
+ """ + + if isinstance(f, Missing): + return functools.partial( + grad, + argnums=argnums, + has_aux=has_aux, + holomorphic=holomorphic, + allow_int=allow_int, + reduce_axes=reduce_axes, + ) + return _grad_general( + f, + argnums, + has_aux, + holomorphic, + allow_int, + reduce_axes, + return_value=False, + ) + + +@tp.overload +def value_and_grad( + f: tp.Callable[..., tp.Any], + *, + argnums: int | DiffState | tp.Sequence[int | DiffState] = 0, + has_aux: bool = False, + holomorphic: bool = False, + allow_int: bool = False, + reduce_axes: tp.Sequence[AxisName] = (), +) -> tp.Callable[..., tp.Any]: ... +@tp.overload +def value_and_grad( + *, + argnums: int | DiffState | tp.Sequence[int | DiffState] = 0, + has_aux: bool = False, + holomorphic: bool = False, + allow_int: bool = False, + reduce_axes: tp.Sequence[AxisName] = (), +) -> tp.Callable[[tp.Callable[..., tp.Any]], tp.Callable[..., tp.Any]]: ... +def value_and_grad( + f: tp.Callable[..., tp.Any] | Missing = MISSING, + *, + argnums: int | DiffState | tp.Sequence[int | DiffState] = 0, + has_aux: bool = False, + holomorphic: bool = False, + allow_int: bool = False, + reduce_axes: tp.Sequence[AxisName] = (), +) -> ( + tp.Callable[..., tp.Any] + | tp.Callable[[tp.Callable[..., tp.Any]], tp.Callable[..., tp.Any]] +): + if isinstance(f, Missing): + return functools.partial( + value_and_grad, + argnums=argnums, + has_aux=has_aux, + holomorphic=holomorphic, + allow_int=allow_int, + reduce_axes=reduce_axes, + ) + return _grad_general( + f, + argnums, + has_aux, + holomorphic, + allow_int, + reduce_axes, + return_value=True, + ) + + +def _custom_vjp_merge_fn( + ctx: graph.MergeContext, + path, + prefix: bool | DiffState, + value: extract.TreeNode, + *, + nondiff_states: deque[extract.GraphDefState], +): + nondiff = nondiff_states.popleft() + return ctx.merge(nondiff.graphdef, value.state, nondiff.state) + + +def _custom_vjp_split_fn( + ctx: graph.SplitContext, + path, + prefix: bool | DiffState, + value, + *, + nondiff_states: deque[extract.GraphDefState], +): + if prefix is False: + # pure non-differentiable arg, we pass all the state through + # but we return TreeNode.from_split with a graphdef to we can call from_tree + # on the nondiff args during the backward pass + graphdef, passed = ctx.split(value) + broadcast = State({}) + nondiff_states.append(extract.GraphDefState(graphdef, broadcast)) + return extract.TreeNode.from_split(graphdef, passed) + elif prefix is True: + # pure differentiable arg, we pass all the state through + # but we reutrn a TreeNode.from_states which doesn't have a graphdef + # in order to keep the gradients clean from any metadata + graphdef, passed = ctx.split(value) + broadcast = State({}) + nondiff_states.append(extract.GraphDefState(graphdef, broadcast)) + return extract.TreeNode.from_states(passed) + else: + # differentiable arg with DiffState filter, we use the filter to split the state + # as before we return a TreeNode.from_states to keep the gradients clean + # from any metadata, the non-differentiable state is stored in a deque + # which is broadcasted during the forward pass + graphdef, passed, broadcast = ctx.split(value, prefix.filter, ...) + nondiff_states.append(extract.GraphDefState(graphdef, broadcast)) + return extract.TreeNode.from_states(passed) + + +class CustomVjpMetadata(struct.PyTreeNode): + tangent_tree_node_args: tuple[tp.Any, ...] 
= struct.field(pytree_node=False) + + +@dataclasses.dataclass(eq=False) +class CustomVjpFnWrapper: + f: tp.Callable[..., tp.Any] + ctxtag: str + + def __post_init__(self): + functools.update_wrapper(self, self.f) + + def __call__(self, *pure_args): + broadcast: tuple[CustomVjpMetadata, deque[extract.GraphDefState]] = ( + extract.get_broadcast_state(self.ctxtag) + ) + metadata, nondiff_states = broadcast + args = extract.from_tree( + pure_args, + merge_fn=functools.partial( + _custom_vjp_merge_fn, nondiff_states=nondiff_states + ), + ctxtag=self.ctxtag, + ) + + out = self.f(*args) + + args_out = extract.clear_non_graph_nodes(args) + pure_args_out, pure_out = extract.to_tree( + (args_out, out), ctxtag=self.ctxtag + ) + + return pure_args_out, pure_out + + +@dataclasses.dataclass(eq=False) +class FwdFn: + fwd: tp.Callable[..., tp.Any] + ctxtag: str + + def __post_init__(self): + functools.update_wrapper(self, self.fwd) + + def __call__(self, *pure_args): + broadcast: tuple[CustomVjpMetadata, deque[extract.GraphDefState]] = ( + extract.get_broadcast_state(self.ctxtag) + ) + metadata, nondiff_states = broadcast + args = extract.from_tree( + pure_args, + merge_fn=functools.partial( + _custom_vjp_merge_fn, nondiff_states=nondiff_states + ), + ctxtag=self.ctxtag, + ) + + out, residual = self.fwd(*args) + + args_out = extract.clear_non_graph_nodes(args) + pure_args_out, pure_out = extract.to_tree( + (args_out, out), ctxtag=self.ctxtag + ) + pure_residual = extract.to_tree(residual) + + return (pure_args_out, pure_out), (metadata, pure_residual) + + +@dataclasses.dataclass(eq=False) +class BwdFn: + bwd: tp.Callable[..., tp.Any] + + def __post_init__(self): + functools.update_wrapper(self, self.bwd) + + def __call__(self, *args): + res: tuple[CustomVjpMetadata, tp.Any] + pure_g: tuple[tp.Any, tp.Any] + *nondiff, res, pure_g = args + metadata, pure_residual = res + nondiff = extract.from_tree(nondiff) + residual = extract.from_tree(pure_residual) + pure_g = jax.tree.map( + lambda x: x.state if isinstance(x, extract.TreeNode) else x, + pure_g, + is_leaf=lambda x: isinstance(x, extract.TreeNode), + ) + + tangent = self.bwd(*nondiff, residual, pure_g) + + def state_to_tree_node(is_tree_node: bool, x): + if is_tree_node: + if not isinstance(x, State): + raise ValueError(f'Expected State, got {type(x)}') + return extract.TreeNode.from_states(x) + return x + + pure_tangent = jax.tree.map( + state_to_tree_node, + metadata.tangent_tree_node_args, + tangent, + is_leaf=lambda x: isinstance(x, State), + ) + return pure_tangent + + +class CustomVjp(tp.Generic[A]): + def __init__( + self, + fun: tp.Callable[..., A], + nondiff_argnums: tuple[int | DiffState, ...], + ): + functools.update_wrapper(self, fun) + jax_nondiff_argnums = tuple( + x.argnum if isinstance(x, DiffState) else x for x in nondiff_argnums + ) + self.ctxtag = f'custom_vjp_{fun.__name__}_{id(fun)}' + self.custom_vjp_fn = jax.custom_vjp( + CustomVjpFnWrapper(fun, self.ctxtag), + nondiff_argnums=jax_nondiff_argnums, + ) + self.nondiff_argnums = nondiff_argnums + self.diff_filter: dict[int, tp.Literal[False] | DiffState] = {} + for argnum in self.nondiff_argnums: + index = argnum.argnum if isinstance(argnum, DiffState) else argnum + if index in self.diff_filter: + raise ValueError(f'argnum {index} is repeated in nondiff_argnums') + self.diff_filter[index] = ( + dataclasses.replace(argnum, argnum=-1) + if isinstance(argnum, DiffState) + else False + ) + + def __getattr__(self, name: str) -> tp.Any: + return getattr(self.custom_vjp_fn, name) + + def 
__call__( + self, *args: tp.Any, **kwargs: tp.Any + ) -> A: # pytype: disable=invalid-annotation + with graph.update_context(self.ctxtag): + args = resolve_kwargs(self.custom_vjp_fn, args, kwargs) + del kwargs + nondiff_states: deque[extract.GraphDefState] = deque() + arg_filters = tuple( + self.diff_filter.get(i, True) for i in range(len(args)) + ) + pure_args = extract.to_tree( + args, + prefix=arg_filters, + split_fn=functools.partial( + _custom_vjp_split_fn, nondiff_states=nondiff_states + ), + ctxtag=self.ctxtag, + ) + tangent_args = tp.cast( + tuple[tp.Literal[True] | DiffState, ...], + tuple(x for x in arg_filters if x is not False), + ) + tree_node_args = jax.tree.map( + lambda x: isinstance(x, extract.TreeNode), + pure_args, + is_leaf=lambda x: isinstance(x, extract.TreeNode), + ) + tangent_tree_node_args = tuple( + arg + for arg, is_tree_node in zip(args, tree_node_args) + if is_tree_node is not False + ) + metadata = CustomVjpMetadata(tangent_args) + + with extract.broadcast_state(self.ctxtag, (metadata, nondiff_states)): + pure_args_out, pure_out = self.custom_vjp_fn(*pure_args) + + args_out, out = extract.from_tree( + (pure_args_out, pure_out), ctxtag=self.ctxtag + ) + + return out + + def defvjp( + self, + fwd: tp.Callable[..., tuple[A, tp.Any]], + bwd: tp.Callable[..., tuple[tp.Any, ...]], + symbolic_zeros: bool = False, + ) -> None: + """Define a custom VJP rule for the function represented by this instance. + + Args: + fwd: a Python callable representing the forward pass of the custom VJP + rule. When there are no ``nondiff_argnums``, the ``fwd`` function has + the same input signature as the underlying primal function. It should + return as output a pair, where the first element represents the primal + output and the second element represents any "residual" values to store + from the forward pass for use on the backward pass by the function + ``bwd``. Input arguments and elements of the output pair may be arrays + or nested tuples/lists/dicts thereof. + bwd: a Python callable representing the backward pass of the custom VJP + rule. When there are no ``nondiff_argnums``, the ``bwd`` function takes + two arguments, where the first is the "residual" values produced on the + forward pass by ``fwd``, and the second is the output cotangent with the + same structure as the primal function output. The output of ``bwd`` must + be a tuple of length equal to the number of arguments of the primal + function, and the tuple elements may be arrays or nested + tuples/lists/dicts thereof so as to match the structure of the primal + input arguments. + symbolic_zeros: boolean, determining whether to indicate symbolic zeros + to the ``fwd`` and ``bwd`` rules. Enabling this option allows custom + derivative rules to detect when certain inputs, and when certain + output cotangents, are not involved in differentiation. If ``True``: + + * ``fwd`` must accept, in place of each leaf value ``x`` in + the pytree comprising an argument to the original function, + an object (of type + ``jax.custom_derivatives.CustomVJPPrimal``) with two + attributes instead: ``value`` and ``perturbed``. The + ``value`` field is the original primal argument, and + ``perturbed`` is a boolean. The ``perturbed`` bit indicates + whether the argument is involved in differentiation (i.e., + if it is ``False``, then the corresponding Jacobian "column" + is zero). 
+ + * ``bwd`` will be passed objects representing static symbolic zeros in + its cotangent argument in correspondence with unperturbed values; + otherwise, only standard JAX types (e.g. array-likes) are passed. + + Setting this option to ``True`` allows these rules to detect whether + certain inputs and outputs are not involved in differentiation, but at + the cost of special handling. For instance: + + * The signature of ``fwd`` changes, and the objects it is passed cannot + be output from the rule directly. + + * The ``bwd`` rule is passed objects that are not entirely array-like, + and that cannot be passed to most ``jax.numpy`` functions. + + * Any custom pytree nodes involved in the primal function's arguments + must accept, in their unflattening functions, the two-field record + objects that are given as input leaves to the ``fwd`` rule. + + Default ``False``. + + Returns: + None. + + Examples: + + @jax.custom_vjp + def f(x, y): + return jnp.sin(x) * y + + def f_fwd(x, y): + return f(x, y), (jnp.cos(x), jnp.sin(x), y) + + def f_bwd(res, g): + cos_x, sin_x, y = res + return (cos_x * g * y, sin_x * g) + + f.defvjp(f_fwd, f_bwd) + """ + + self.custom_vjp_fn.defvjp( + FwdFn(fwd, self.ctxtag), + BwdFn(bwd), + symbolic_zeros=symbolic_zeros, + ) + + +@tp.overload +def custom_vjp( + fun: tp.Callable[..., A], + *, + nondiff_argnums: tuple[int | DiffState, ...] = (), +) -> CustomVjp[A]: ... +@tp.overload +def custom_vjp( + *, + nondiff_argnums: tuple[int | DiffState, ...] = (), +) -> tp.Callable[[tp.Callable[..., A]], CustomVjp[A]]: ... +def custom_vjp( + fun: tp.Callable[..., A] | Missing = MISSING, + *, + nondiff_argnums: tuple[int | DiffState, ...] = (), +) -> CustomVjp[A] | tp.Callable[[tp.Callable[..., A]], CustomVjp[A]]: + """Reference aware version of + `jax.custom_vjp `__. + + Example:: + + >>> import jax + >>> import jax.numpy as jnp + >>> from flax import nnx + ... + >>> class Foo(nnx.Module): + ... def __init__(self, x, y): + ... self.x = nnx.Param(x) + ... self.y = nnx.Param(y) + ... + >>> @nnx.custom_vjp + ... def f(m: Foo): + ... return jnp.sin(m.x) * m.y + ... + >>> def f_fwd(m: Foo): + ... return f(m), (jnp.cos(m.x), jnp.sin(m.x), m) + ... + >>> def f_bwd(res, g): + ... inputs_g, out_g = g + ... cos_x, sin_x, m = res + ... tangent_m = Foo(x=cos_x * out_g * m.y, y=sin_x * out_g) + ... return (tangent_m,) + ... + >>> f.defvjp(f_fwd, f_bwd) + ... + >>> m = Foo(x=jnp.array(1.), y=jnp.array(2.)) + >>> grads = nnx.grad(f)(m) + ... + >>> jax.tree.map(jnp.shape, grads) + State({ + 'x': VariableState( + type=Param, + value=() + ), + 'y': VariableState( + type=Param, + value=() + ) + }) + + """ + if isinstance(fun, Missing): + return functools.partial(custom_vjp, nondiff_argnums=nondiff_argnums) + return CustomVjp(fun, nondiff_argnums) + + +# ------------------------------- +# remat +# ------------------------------- + + +@tp.overload +def remat( + *, + prevent_cse: bool = True, + static_argnums: int | tuple[int, ...] = (), + policy: tp.Callable[..., bool] | None = None, +) -> tp.Callable[[F], F]: ... +@tp.overload +def remat( + f: F, + *, + prevent_cse: bool = True, + static_argnums: int | tuple[int, ...] = (), + policy: tp.Callable[..., bool] | None = None, +) -> F: ... +def remat( + f: F | Missing = MISSING, + *, + prevent_cse: bool = True, + static_argnums: int | tuple[int, ...] 
= (), + policy: tp.Callable[..., bool] | None = None, +) -> F | tp.Callable[[F], F]: + if isinstance(f, Missing): + return functools.partial( + remat, + prevent_cse=prevent_cse, + static_argnums=static_argnums, + policy=policy, + ) + + return resolve_kwargs()( + graph.update_context('remat')( + general.split_inputs( + jax.checkpoint( + general.merge_inputs(f, ctxtag='remat'), + prevent_cse=prevent_cse, + static_argnums=static_argnums, + policy=policy, + ), + ctxtag='remat', + ), + ) + ) \ No newline at end of file diff --git a/flax/nnx/nnx/transforms/compilation.py b/flax/nnx/nnx/transforms/compilation.py new file mode 100644 index 0000000000..ba4d864d2c --- /dev/null +++ b/flax/nnx/nnx/transforms/compilation.py @@ -0,0 +1,357 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
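+# This module implements the reference-aware ``nnx.jit`` transform together
+# with the ``StateSharding`` annotation, which maps ``nnx`` filters to JAX
+# shardings so that substates of graph-node arguments can be sharded per filter.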
+# pytype: skip-file + +import dataclasses +import functools +import typing as tp + +from flax.nnx.nnx import ( + extract, + filterlib, + graph, +) +import jax +import jax.core +import jax.stages + +F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any]) + + +class Missing: + pass + + +MISSING = Missing() + +# ------------------------------- +# jit +# ------------------------------- + + +class StateSharding: + def __init__( + self, + filter_sharding: tp.Mapping[filterlib.Filter, tp.Any] + | tp.Iterable[tuple[filterlib.Filter, tp.Any]], + /, + ): + iterable = tuple( + filter_sharding.items() + if isinstance(filter_sharding, tp.Mapping) + else filter_sharding + ) + self._filters = tuple(filter for filter, _ in iterable) + self._shardings = tuple(axis for _, axis in iterable) + + @property + def filters(self) -> tuple[filterlib.Filter, ...]: + return self._filters + + @property + def shardings(self) -> tuple[tp.Any, ...]: + return self._shardings + + def __repr__(self): + return f'StateSharding({dict(zip(self.filters, self.shardings))})' + + def __eq__(self, other): + return ( + isinstance(other, StateSharding) + and self.filters == other.filters + and self.shardings == other.shardings + ) + + def __hash__(self): + return hash((self.filters, self.shardings)) + + +def _jit_split_fn(ctx: graph.SplitContext, path, prefix, x): + if isinstance(prefix, StateSharding): + return extract.TreeNode.from_split( + *ctx.split(x, *prefix.filters), metadata=prefix + ) + return extract.TreeNode.from_split(*ctx.split(x)) + + +@dataclasses.dataclass(eq=False) +class JitFn: + f: tp.Callable[..., tp.Any] + in_shardings: tp.Any + out_shardings: tp.Any + kwarg_shardings: tp.Any + + def __post_init__(self): + functools.update_wrapper(self, self.f) + + def __call__(self, *pure_args, **pure_kwargs): + args, kwargs = extract.from_tree((pure_args, pure_kwargs), ctxtag='jit') + + out = self.f(*args, **kwargs) + + args_out, kwargs_out = extract.clear_non_graph_nodes((args, kwargs)) + pure_args_out, pure_kwargs_out, pure_out = extract.to_tree( + (args_out, kwargs_out, out), + prefix=(self.in_shardings, self.kwarg_shardings, self.out_shardings), + ctxtag='jit', + split_fn=_jit_split_fn, + ) + + return pure_args_out, pure_kwargs_out, pure_out + + +@tp.overload +def jit( + *, + in_shardings: tp.Any = None, + out_shardings: tp.Any = None, + static_argnums: int | tp.Sequence[int] | None = None, + static_argnames: str | tp.Iterable[str] | None = None, + donate_argnums: int | tp.Sequence[int] | None = None, + donate_argnames: str | tp.Iterable[str] | None = None, + keep_unused: bool = False, + device: tp.Optional[jax.Device] = None, + backend: tp.Optional[str] = None, + inline: bool = False, + abstracted_axes: tp.Optional[tp.Any] = None, +) -> tp.Callable[[F], F]: ... +@tp.overload +def jit( + fun: F, + *, + in_shardings: tp.Any = None, + out_shardings: tp.Any = None, + static_argnums: int | tp.Sequence[int] | None = None, + static_argnames: str | tp.Iterable[str] | None = None, + donate_argnums: int | tp.Sequence[int] | None = None, + donate_argnames: str | tp.Iterable[str] | None = None, + keep_unused: bool = False, + device: tp.Optional[jax.Device] = None, + backend: tp.Optional[str] = None, + inline: bool = False, + abstracted_axes: tp.Optional[tp.Any] = None, +) -> F: ... 
+def jit( + fun: F | Missing = MISSING, + *, + in_shardings: tp.Any = None, + out_shardings: tp.Any = None, + static_argnums: int | tp.Sequence[int] | None = None, + static_argnames: str | tp.Iterable[str] | None = None, + donate_argnums: int | tp.Sequence[int] | None = None, + donate_argnames: str | tp.Iterable[str] | None = None, + keep_unused: bool = False, + device: tp.Optional[jax.Device] = None, + backend: tp.Optional[str] = None, + inline: bool = False, + abstracted_axes: tp.Optional[tp.Any] = None, +) -> F | tp.Callable[[F], F]: + """ + Lifted version of ``jax.jit`` that can handle Modules / graph nodes as + arguments. + + Args: + fun: Function to be jitted. ``fun`` should be a pure function, as + side-effects may only be executed once. + + The arguments and return value of ``fun`` should be arrays, + scalars, or (nested) standard Python containers (tuple/list/dict) thereof. + Positional arguments indicated by ``static_argnums`` can be anything at + all, provided they are hashable and have an equality operation defined. + Static arguments are included as part of a compilation cache key, which is + why hash and equality operators must be defined. + + JAX keeps a weak reference to ``fun`` for use as a compilation cache key, + so the object ``fun`` must be weakly-referenceable. Most :class:`Callable` + objects will already satisfy this requirement. + in_shardings: Pytree of structure matching that of arguments to ``fun``, + with all actual arguments replaced by resource assignment specifications. + It is also valid to specify a pytree prefix (e.g. one value in place of a + whole subtree), in which case the leaves get broadcast to all values in + that subtree. + + The ``in_shardings`` argument is optional. JAX will infer the shardings + from the input :py:class:`jax.Array`'s and defaults to replicating the input + if the sharding cannot be inferred. + + The valid resource assignment specifications are: + - :py:class:`Sharding`, which will decide how the value + will be partitioned. With this, using a mesh context manager is not + required. + - :py:obj:`None`, which will give JAX the freedom to choose whatever sharding + it wants. + For in_shardings, JAX will mark it as replicated but this behavior + can change in the future. + For out_shardings, we will rely on the XLA GSPMD partitioner to + determine the output shardings. + + The size of every dimension has to be a multiple of the total number of + resources assigned to it. This is similar to pjit's in_shardings. + out_shardings: Like ``in_shardings``, but specifies resource + assignment for function outputs. This is similar to pjit's + out_shardings. + + The ``out_shardings`` argument is optional. If not specified, :py:func:`jax.jit` + will use GSPMD's sharding propagation to figure out what the sharding of the + output(s) should be. + static_argnums: An optional int or collection of ints that specify which + positional arguments to treat as static (compile-time constant). + Operations that only depend on static arguments will be constant-folded in + Python (during tracing), and so the corresponding argument values can be + any Python object. + + Static arguments should be hashable, meaning both ``__hash__`` and + ``__eq__`` are implemented, and immutable. Calling the jitted function + with different values for these constants will trigger recompilation. + Arguments that are not arrays or containers thereof must be marked as + static.
+ + If neither ``static_argnums`` nor ``static_argnames`` is provided, no + arguments are treated as static. If ``static_argnums`` is not provided but + ``static_argnames`` is, or vice versa, JAX uses + :code:`inspect.signature(fun)` to find any positional arguments that + correspond to ``static_argnames`` + (or vice versa). If both ``static_argnums`` and ``static_argnames`` are + provided, ``inspect.signature`` is not used, and only actual + parameters listed in either ``static_argnums`` or ``static_argnames`` will + be treated as static. + static_argnames: An optional string or collection of strings specifying + which named arguments to treat as static (compile-time constant). See the + comment on ``static_argnums`` for details. If not + provided but ``static_argnums`` is set, the default is based on calling + ``inspect.signature(fun)`` to find corresponding named arguments. + donate_argnums: Specify which positional argument buffers are "donated" to + the computation. It is safe to donate argument buffers if you no longer + need them once the computation has finished. In some cases XLA can make + use of donated buffers to reduce the amount of memory needed to perform a + computation, for example recycling one of your input buffers to store a + result. You should not reuse buffers that you donate to a computation, JAX + will raise an error if you try to. By default, no argument buffers are + donated. + + If neither ``donate_argnums`` nor ``donate_argnames`` is provided, no + arguments are donated. If ``donate_argnums`` is not provided but + ``donate_argnames`` is, or vice versa, JAX uses + :code:`inspect.signature(fun)` to find any positional arguments that + correspond to ``donate_argnames`` + (or vice versa). If both ``donate_argnums`` and ``donate_argnames`` are + provided, ``inspect.signature`` is not used, and only actual + parameters listed in either ``donate_argnums`` or ``donate_argnames`` will + be donated. + + For more details on buffer donation see the + `FAQ `_. + donate_argnames: An optional string or collection of strings specifying + which named arguments are donated to the computation. See the + comment on ``donate_argnums`` for details. If not + provided but ``donate_argnums`` is set, the default is based on calling + ``inspect.signature(fun)`` to find corresponding named arguments. + keep_unused: If `False` (the default), arguments that JAX determines to be + unused by `fun` *may* be dropped from resulting compiled XLA executables. + Such arguments will not be transferred to the device nor provided to the + underlying executable. If `True`, unused arguments will not be pruned. + device: This is an experimental feature and the API is likely to change. + Optional, the Device the jitted function will run on. (Available devices + can be retrieved via :py:func:`jax.devices`.) The default is inherited + from XLA's DeviceAssignment logic and is usually to use + ``jax.devices()[0]``. + backend: This is an experimental feature and the API is likely to change. + Optional, a string representing the XLA backend: ``'cpu'``, ``'gpu'``, or + ``'tpu'``. + inline: Specify whether this function should be inlined into enclosing + jaxprs (rather than being represented as an application of the xla_call + primitive with its own subjaxpr). Default False. + + Returns: + A wrapped version of ``fun``, set up for just-in-time compilation. 
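+
+    Example (a minimal sketch; the module and shapes below are illustrative)::
+
+      >>> from flax import nnx
+      >>> import jax.numpy as jnp
+      ...
+      >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
+      ...
+      >>> @nnx.jit
+      ... def forward(model, x):
+      ...   return model(x)
+      ...
+      >>> y = forward(model, jnp.ones((1, 2)))
+      >>> y.shape
+      (1, 3)
+
+    For graph-node arguments, ``in_shardings`` and ``out_shardings`` may also
+    contain :class:`StateSharding` entries, which map ``nnx`` filters to
+    shardings for the corresponding substates.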
+ """ + + if isinstance(fun, Missing): + return functools.partial( + jit, + in_shardings=in_shardings, + out_shardings=out_shardings, + static_argnums=static_argnums, + static_argnames=static_argnames, + donate_argnums=donate_argnums, + donate_argnames=donate_argnames, + keep_unused=keep_unused, + device=device, + backend=backend, + inline=inline, + abstracted_axes=abstracted_axes, + ) + kwarg_shardings = None + jax_in_shardings = jax.tree.map( + lambda x: extract.TreeNode.from_prefixes(x.shardings, metadata=x) + if isinstance(x, StateSharding) + else x, + in_shardings, + ) + jax_out_shardings = jax.tree.map( + lambda x: extract.TreeNode.from_prefixes(x.shardings, metadata=x) + if isinstance(x, StateSharding) + else x, + out_shardings, + ) + + jitted_fn = jax.jit( + JitFn(fun, in_shardings, out_shardings, kwarg_shardings), + in_shardings=jax_in_shardings, + out_shardings=(jax_in_shardings, kwarg_shardings, jax_out_shardings), # type: ignore + static_argnums=static_argnums, + static_argnames=static_argnames, + donate_argnums=donate_argnums, + donate_argnames=donate_argnames, + keep_unused=keep_unused, + device=device, + backend=backend, + inline=inline, + abstracted_axes=abstracted_axes, + ) + + @functools.wraps(fun) + @graph.update_context('jit') + def jit_wrapper(*args, **kwargs): + pure_args, pure_kwargs = extract.to_tree( + (args, kwargs), + prefix=(in_shardings, kwarg_shardings), + split_fn=_jit_split_fn, + ctxtag='jit', + ) + pure_args_out, pure_kwargs_out, pure_out = jitted_fn( + *pure_args, **pure_kwargs + ) + _args_out, _kwargs_out, out = extract.from_tree( + (pure_args_out, pure_kwargs_out, pure_out), ctxtag='jit' + ) + return out + + jit_wrapper.inner = jitted_fn # type: ignore + + return jit_wrapper # type: ignore diff --git a/flax/nnx/nnx/transforms/deprecated.py b/flax/nnx/nnx/transforms/deprecated.py new file mode 100644 index 0000000000..c7e61de098 --- /dev/null +++ b/flax/nnx/nnx/transforms/deprecated.py @@ -0,0 +1,1953 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# pytype: skip-file +from __future__ import annotations + +import dataclasses +import functools +import typing as tp + +from flax import struct +from flax.core.frozen_dict import FrozenDict +from flax.nnx.nnx import extract, filterlib, graph, rnglib, spmd, variables +from flax.nnx.nnx.module import GraphDef, Module +from flax.nnx.nnx.proxy_caller import DelayedAccessor +from flax.nnx.nnx.state import State +from flax.nnx.nnx.transforms.transforms import LiftedModule +from flax.typing import Leaf +import jax +from jax._src.tree_util import broadcast_prefix +import jax.core +import jax.numpy as jnp +import jax.stages +from flax import nnx + +A = tp.TypeVar('A') +C = tp.TypeVar('C') +B = tp.TypeVar('B') +F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any]) +G = tp.TypeVar('G', bound=tp.Callable[..., tp.Any]) +M = tp.TypeVar('M', bound=Module) +MA = tp.TypeVar('MA', bound=Module) +N = tp.TypeVar('N', bound=Module) +StrInt = tp.TypeVar('StrInt', str, int) +AxisName = tp.Hashable +Leaves = tp.List[Leaf] +Index = int + + +class Missing: + pass + + +MISSING = Missing() + + +# ------------------------------- +# vmap +# ------------------------------- +class _VmapForkStates(tp.NamedTuple): + split_keys: State + split_counts: State + broadcast_keys: State + broadcast_counts: State + + +def _get_axis_sizes(pytree, axes): + axes = broadcast_prefix(axes, pytree, is_leaf=lambda x: x is None) + leaves = jax.tree_util.tree_leaves(pytree) + axis_sizes = { + leaf.shape[axis] for axis, leaf in zip(axes, leaves) if axis is not None + } + return axis_sizes + + +def _fork_vmap_keys( + state: State, + split_filter: filterlib.Filter, + num_splits: int, +) -> _VmapForkStates: + split_keys, split_counts, broadcast_keys, broadcast_counts = state.split( + filterlib.All(split_filter, rnglib.RngKey), + filterlib.All(split_filter, rnglib.RngCount), + rnglib.RngKey, + rnglib.RngCount, + ) + + def split_key(key: tp.Any, count: tp.Any) -> jax.Array: + if not isinstance(key, jax.Array): + raise TypeError(f'key must be a jax.Array, got {type(key)}') + if not isinstance(count, jax.Array): + raise TypeError(f'count must be a jax.Array, got {type(count)}') + + key = jax.random.fold_in(key, count) + return jax.random.split(key, num_splits) + + split_keys_leaves, split_keys_treedef = jax.tree.flatten(split_keys) + split_counts_leaves, split_counts_treedef = jax.tree.flatten(split_counts) + + if len(split_keys_leaves) != len(split_counts_leaves): + raise ValueError( + 'split_keys and split_counts must have the same number of leaves', + f'got {len(split_keys_leaves)} and {len(split_counts_leaves)}', + ) + + split_keys_leaves = [ + split_key(key, count) + for key, count in zip(split_keys_leaves, split_counts_leaves) + ] + split_counts_leaves = [ + jnp.full((num_splits,), 0, dtype=jnp.uint32) for _ in split_counts_leaves + ] + split_keys = jax.tree.unflatten(split_keys_treedef, split_keys_leaves) + split_counts = jax.tree.unflatten(split_counts_treedef, split_counts_leaves) + + return _VmapForkStates( + split_keys, split_counts, broadcast_keys, broadcast_counts + ) + + +def _backup_vmap_keys(node: tp.Any, /): + backups: list[ + tuple[graph.PathParts, rnglib.RngStream, jax.Array, jax.Array] + ] = [] + for path, stream in graph.iter_graph(node): + if isinstance(stream, rnglib.RngStream): + backups.append((path, stream, stream.key.value, stream.count.value)) + return backups + + +def _restore_vmap_keys( + backups: list[tuple[graph.PathParts, rnglib.RngStream, jax.Array, jax.Array]], + split_rngs: filterlib.Filter, + /, +): + predicate_fn 
= filterlib.to_predicate(split_rngs) + for path, stream, key, count in backups: + stream.key.value = key + count_path = (*path, 'count') + if predicate_fn(count_path, stream.count.to_state()): + # restore count only if it was split + # add 1 to reflect the split + stream.count.value = count + 1 + + +def vmap_fn( + args: tuple[tp.Any, ...], + kwargs: dict[str, tp.Any], + graphdef: GraphDef[tuple[tp.Any, ...]], + split_keys: State, + split_counts: State, + broadcast_keys: State, + broadcast_counts: State, + vectorized_states: list[State], + broadcast_state: State, + transform_metadata: tp.Mapping[str, tp.Any], + state_axes_: list[tuple[filterlib.Filter, int]], + f: tp.Callable[..., tp.Any], + filters: tp.Tuple[filterlib.Filter, ...], + split_rngs: filterlib.Filter, +): + ctx = graph.current_update_context('vmap') + state_axes = dict(state_axes_) + # remove metadata axis name from Variable.sharding + if spmd.PARTITION_NAME in transform_metadata: + vectorized_states = [ + spmd.remove_axis(state, index, transform_metadata) + for state, index in zip(vectorized_states, state_axes.values()) + ] + + # merge module state + input_graph_nodes = ctx.merge( + graphdef, + *vectorized_states, + broadcast_state, + split_keys, + split_counts, + broadcast_keys, + broadcast_counts, + ) + + (args, kwargs) = extract.insert_graph_nodes((args, kwargs), input_graph_nodes) + + out = f(*args, **kwargs) + + out, output_graph_nodes = extract.extract_graph_nodes(out) + + # split module state + ( + graphdef_out, + rng_state_out, + *vectorized_states_out, + broadcast_state_out, + ) = ctx.split( # type: ignore[misc] + (input_graph_nodes, output_graph_nodes), + rnglib.RngState, + *filters, + ) + + split_keys_out, broadcast_keys_out = rng_state_out.split(split_rngs, ...) + + broadcast_state_out = State.merge(broadcast_state_out, broadcast_keys_out) + + # add metadata axis name to Variable.sharding + if spmd.PARTITION_NAME in transform_metadata: + vectorized_states_out = [ + spmd.add_axis(state, index, transform_metadata) + for state, index in zip(vectorized_states_out, state_axes.values()) + ] + + return ( + graphdef_out, + broadcast_state_out, + vectorized_states_out, + split_keys_out, + out, + ) + + +@tp.overload +def vmap( + *, + in_axes: int | None | tp.Sequence[tp.Any] = 0, + out_axes: tp.Any = 0, + axis_name: AxisName | None = None, + axis_size: int | None = None, + spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None, + # nnx specific + in_axes_kwargs: tp.Any = 0, + state_axes: tp.Mapping[filterlib.Filter, int | None] = FrozenDict({...: 0}), + split_rngs: filterlib.Filter = ..., + transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), +) -> tp.Callable[[F], F]: ... +@tp.overload +def vmap( + f: F, + *, + in_axes: int | None | tp.Sequence[tp.Any] = 0, + out_axes: tp.Any = 0, + axis_name: AxisName | None = None, + axis_size: int | None = None, + spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None, + # nnx specific + in_axes_kwargs: tp.Any = 0, + state_axes: tp.Mapping[filterlib.Filter, int | None] = FrozenDict({...: 0}), + split_rngs: filterlib.Filter = ..., + transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), +) -> F: ... +def vmap( + f: F | Missing = MISSING, + *, + in_axes: int | None | tp.Sequence[tp.Any] = 0, + out_axes: tp.Any = 0, + axis_name: AxisName | None = None, + axis_size: int | None = None, + spmd_axis_name: AxisName | tuple[AxisName, ...] 
| None = None, + # nnx specific + in_axes_kwargs: tp.Any = 0, + state_axes: tp.Mapping[filterlib.Filter, int | None] = FrozenDict({...: 0}), + split_rngs: filterlib.Filter = ..., + transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), +) -> F | tp.Callable[[F], F]: + if isinstance(f, Missing): + return functools.partial( + vmap, + in_axes=in_axes, + out_axes=out_axes, + axis_name=axis_name, + axis_size=axis_size, + spmd_axis_name=spmd_axis_name, + in_axes_kwargs=in_axes_kwargs, + state_axes=state_axes, + split_rngs=split_rngs, + transform_metadata=transform_metadata, + ) + + vectorized_states_axes = list(state_axes.values()) + vmapped_fn = jax.vmap( + vmap_fn, + in_axes=( + in_axes, # args + in_axes_kwargs, # kwargs + None, # graphdef + 0, # split_keys + 0, # split_counts + None, # broadcast_keys + None, # broadcast_counts + vectorized_states_axes, # vectorized_states + None, # broadcast_state + None, # transform_metadata + None, # states_axes + None, # f + None, # vectorized_states_filters + None, # split_rngs + ), + out_axes=( + None, # graphdef_out + None, # broadcast_state + vectorized_states_axes, + 0, # keys_out + out_axes, # out_axes + ), + axis_name=axis_name, + axis_size=axis_size, + spmd_axis_name=spmd_axis_name, + ) + + @functools.wraps(f) + @graph.update_context('vmap') + def vmap_wrapper(*args, **kwargs): + ctx = graph.current_update_context('vmap') + + (args, kwargs), input_graph_nodes = extract.extract_graph_nodes( + (args, kwargs) + ) + input_rng_streams = _backup_vmap_keys(input_graph_nodes) + + # split module state + filters = (*state_axes.keys(), ...) + graphdef, rng_state, *vectorized_states, broadcast_state = ctx.split( # type: ignore[misc] + input_graph_nodes, rnglib.RngState, *filters + ) + + # infer length + axis_sizes: tp.Set[int] = set() + axis_sizes.update(_get_axis_sizes(args, in_axes)) + axis_sizes.update(_get_axis_sizes(kwargs, in_axes_kwargs)) + for state, state_axis in zip(vectorized_states, state_axes.values()): + axis_sizes.update(_get_axis_sizes(state, state_axis)) + + if len(axis_sizes) > 1: + raise ValueError( + 'Inconsistent lengths between state_axes states and ' + f'arguments: {axis_sizes}' + ) + elif len(axis_sizes) == 0: + if axis_size is None: + raise ValueError( + 'Cannot infer length from state_axes states or axes_arg, ' + 'please specify `length`' + ) + _axis_size = axis_size + else: + _axis_size = axis_sizes.pop() + if axis_size is not None and axis_size != _axis_size: + raise ValueError( + f'Specified axis_size {axis_size} is not the same as the' + f' inferred length {_axis_size}' + ) + + split_keys, split_counts, broadcast_keys, broadcast_counts = ( + _fork_vmap_keys( + rng_state, + split_rngs, + _axis_size, + ) + ) + + ( + graphdef_out, + broadcast_state, + vectorized_states, + split_keys_out, + out, + ) = vmapped_fn( + args, + kwargs, + graphdef, + split_keys, + split_counts, + broadcast_keys, + broadcast_counts, + vectorized_states, + broadcast_state, + transform_metadata, + list(state_axes.items()), + f, + filters, + split_rngs, + ) + + _, output_graph_nodes = ctx.merge( + graphdef_out, + *vectorized_states, + broadcast_state, + split_keys_out, + ) + + out = extract.insert_graph_nodes(out, output_graph_nodes) + + _restore_vmap_keys(input_rng_streams, split_rngs) + + return out + + return vmap_wrapper # type: ignore + + +class Vmap(tp.Generic[M], LiftedModule[M]): + @staticmethod + def constructor( + module_constructor: tp.Callable[..., MA], + *, + in_axes: int | None | tp.Sequence[tp.Any] = 0, + out_axes: tp.Any = 0, + axis_name: 
AxisName | None = None, + axis_size: int | None = None, + spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None, + # nnx specific + in_axes_kwargs: tp.Any = 0, + state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}), + split_rngs: filterlib.Filter = ..., + transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), + ) -> tp.Callable[..., Vmap[MA]]: + def _create_vmap(*args, **kwargs): + return Vmap( + module_constructor=module_constructor, + in_axes=in_axes, + out_axes=out_axes, + axis_size=axis_size, + axis_name=axis_name, + spmd_axis_name=spmd_axis_name, + # nnx specific + in_axes_kwargs=in_axes_kwargs, + state_axes=state_axes, + split_rngs=split_rngs, + transform_metadata=transform_metadata, + # submodule args + module_init_args=args, + module_init_kwargs=kwargs, + ) + + return _create_vmap + + def __init__( + self, + module_constructor: tp.Callable[..., M], + *, + in_axes: int | None | tp.Sequence[tp.Any] = 0, + out_axes: tp.Any = 0, + axis_name: AxisName | None = None, + axis_size: int | None = None, + spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None, + # nnx specific + in_axes_kwargs: tp.Any = 0, + state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}), + split_rngs: filterlib.Filter = ..., + transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), + # submodule args + module_init_args: tuple[tp.Any, ...], + module_init_kwargs: dict[str, tp.Any], + ): + self.module_constructor = module_constructor + + @functools.partial( + vmap, + in_axes=None, + out_axes=None, + axis_name=axis_name, + axis_size=axis_size, + spmd_axis_name=spmd_axis_name, + in_axes_kwargs=None, + state_axes=state_axes, + split_rngs=split_rngs, + transform_metadata=transform_metadata, + ) + def vmap_init(*args, **kwargs): + return module_constructor(*args, **kwargs) + + self.vmap_module = vmap_init(*module_init_args, **module_init_kwargs) + + @functools.partial( + vmap, + in_axes=in_axes, + out_axes=out_axes, + axis_name=axis_name, + axis_size=axis_size, + spmd_axis_name=spmd_axis_name, + in_axes_kwargs=in_axes_kwargs, + state_axes=state_axes, + split_rngs=split_rngs, + transform_metadata=transform_metadata, + ) + def vmap_call(module, *args, _nnx_vmap_accessor: DelayedAccessor, **kwargs): + method = _nnx_vmap_accessor(module) + return method(*args, **kwargs) + + self.vmap_call = vmap_call + + @property + def _submodule(self) -> M: + return self.vmap_module + + def _call(self, accessor: DelayedAccessor, *args, **kwargs): + return self.vmap_call( + self._submodule, *args, _nnx_vmap_accessor=accessor, **kwargs + ) + + +# ------------------------------- +# pmap +# ------------------------------- +@struct.dataclass +class PmapInputs: + transform_metadata: tp.Mapping[str, tp.Any] = struct.field(pytree_node=False) + state_axes: tp.Mapping[filterlib.Filter, int] = struct.field( + pytree_node=False + ) + f: tp.Callable[..., tp.Any] = struct.field(pytree_node=False) + filters: tp.Tuple[filterlib.Filter, ...] 
= struct.field(pytree_node=False) + split_rngs: filterlib.Filter = struct.field(pytree_node=False) + + +def pmap_fn( + args: tuple[tp.Any, ...], + kwargs: dict[str, tp.Any], + graphdef: GraphDef[tuple[tp.Any, ...]], + split_keys: State, + split_counts: State, + broadcast_keys: State, + broadcast_counts: State, + vectorized_states: list[State], + broadcast_state: State, + pmap_inputs: PmapInputs, +): + transform_metadata = pmap_inputs.transform_metadata + state_axes = pmap_inputs.state_axes + f = pmap_inputs.f + filters = pmap_inputs.filters + split_rngs = pmap_inputs.split_rngs + ctx = graph.current_update_context('pmap') + # remove metadata axis name from Variable.sharding + if spmd.PARTITION_NAME in transform_metadata: + vectorized_states = [ + spmd.remove_axis(state, index, transform_metadata) + for state, index in zip(vectorized_states, state_axes.values()) + ] + + # merge module state + input_graph_nodes = ctx.merge( + graphdef, + *vectorized_states, + broadcast_state, + split_keys, + split_counts, + broadcast_keys, + broadcast_counts, + ) + + (args, kwargs) = extract.insert_graph_nodes((args, kwargs), input_graph_nodes) + + out = f(*args, **kwargs) + + out, output_graph_nodes = extract.extract_graph_nodes(out) + + # split module state + ( + graphdef_out, + rng_state_out, + *vectorized_states_out, + broadcast_state_out, + ) = ctx.split( # type: ignore[misc] + (input_graph_nodes, output_graph_nodes), + rnglib.RngState, + *filters, + ) + + not_keys_out, split_keys_out, broadcast_keys_out = rng_state_out.split( + rnglib.NotKey, split_rngs, ... + ) + + broadcast_state_out = State.merge( + broadcast_state_out, broadcast_keys_out, not_keys_out + ) + + # add metadata axis name to Variable.sharding + if spmd.PARTITION_NAME in transform_metadata: + vectorized_states_out = [ + spmd.add_axis(state, index, transform_metadata) + for state, index in zip(vectorized_states_out, state_axes.values()) + ] + + return ( + graphdef_out, + broadcast_state_out, + vectorized_states_out, + split_keys_out, + out, + ) + + +@tp.overload +def pmap( + *, + axis_name: AxisName | None = None, + in_axes: tp.Any = 0, + out_axes: tp.Any = 0, + static_broadcasted_argnums: int | tp.Iterable[int] = (), + devices: tp.Sequence[jax.Device] | None = None, # noqa: F811 + backend: str | None = None, + axis_size: int | None = None, + donate_argnums: int | tp.Iterable[int] = (), + global_arg_shapes: tuple[tuple[int, ...], ...] | None = None, + # nnx specific + in_axes_kwargs: tp.Any = 0, + state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}), + split_rngs: filterlib.Filter = ..., + transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), +) -> tp.Callable[[F], F]: ... +@tp.overload +def pmap( + f: F, + *, + axis_name: AxisName | None = None, + in_axes: tp.Any = 0, + out_axes: tp.Any = 0, + static_broadcasted_argnums: int | tp.Iterable[int] = (), + devices: tp.Sequence[jax.Device] | None = None, # noqa: F811 + backend: str | None = None, + axis_size: int | None = None, + donate_argnums: int | tp.Iterable[int] = (), + global_arg_shapes: tuple[tuple[int, ...], ...] | None = None, + # nnx specific + in_axes_kwargs: tp.Any = 0, + state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}), + split_rngs: filterlib.Filter = ..., + transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), +) -> F: ... 
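+ # The two overloads above, together with the `Missing` sentinel checked at the
+ # top of the implementation below, support both decorator forms (illustrative
+ # only; `step` is a hypothetical function):
+ #
+ #   @pmap                       # bare decorator, all defaults
+ #   def step(model, x): ...
+ #
+ #   @pmap(axis_name='devices')  # decorator factory with options
+ #   def step(model, x): ...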
+def pmap( + f: F | Missing = MISSING, + *, + axis_name: AxisName | None = None, + in_axes: tp.Any = 0, + out_axes: tp.Any = 0, + static_broadcasted_argnums: int | tp.Iterable[int] = (), + devices: tp.Sequence[jax.Device] | None = None, # noqa: F811 + backend: str | None = None, + axis_size: int | None = None, + donate_argnums: int | tp.Iterable[int] = (), + global_arg_shapes: tuple[tuple[int, ...], ...] | None = None, + # nnx specific + in_axes_kwargs: tp.Any = 0, + state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}), + split_rngs: filterlib.Filter = ..., + transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), +) -> F | tp.Callable[[F], F]: + if isinstance(f, Missing): + return functools.partial( + pmap, + axis_name=axis_name, + in_axes=in_axes, + out_axes=out_axes, + static_broadcasted_argnums=static_broadcasted_argnums, + devices=devices, + backend=backend, + axis_size=axis_size, + donate_argnums=donate_argnums, + global_arg_shapes=global_arg_shapes, + in_axes_kwargs=in_axes_kwargs, + state_axes=state_axes, + split_rngs=split_rngs, + transform_metadata=transform_metadata, + ) + if static_broadcasted_argnums: + raise NotImplementedError( + 'static_broadcasted_argnums is not yet supported in nnx.pmap' + ) + if donate_argnums != (): + raise NotImplementedError('donate_argnums is not yet supported in nnx.pmap') + + if global_arg_shapes is not None: + raise NotImplementedError( + 'global_arg_shapes is not yet supported in nnx.pmap' + ) + + vectorized_states_axes = list(state_axes.values()) + + pmapped_fn = jax.pmap( + pmap_fn, + axis_name=axis_name, + in_axes=( + in_axes, # args_axes + in_axes_kwargs, # kwargs_axes + None, # graphdef_axes + 0, # split_keys_axes + None, # split_counts_axes + None, # broadcast_keys_axes + None, # broadcast_counts_axes + vectorized_states_axes, # vectorized_states_axes + None, # broadcast_state_axes + None, # pmap_inputs_axes + ), # type: ignore + out_axes=( + None, # graphdef_out_axes + None, # broadcast_state_axes + vectorized_states_axes, + 0, # keys_axes_out + out_axes, # out_axes + ), # type: ignore + devices=devices, + backend=backend, + axis_size=axis_size, + ) + + @functools.wraps(f) + @graph.update_context('pmap') + def pmap_wrapper(*args, **kwargs): + ctx = graph.current_update_context('pmap') + + (args, kwargs), input_graph_nodes = extract.extract_graph_nodes( + (args, kwargs) + ) + input_rng_streams = rnglib.backup_keys(input_graph_nodes) + + # split module state + filters = (*state_axes.keys(), ...) 
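+ # The split below partitions the input graph nodes into a static graphdef,
+ # the RNG state, one vectorized State per `state_axes` filter, and a final
+ # broadcast State holding everything that is not mapped over devices.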
+ graphdef, rng_state, *vectorized_states, broadcast_state = ctx.split( # type: ignore[misc] + input_graph_nodes, rnglib.RngState, *filters + ) + + # infer length + axis_sizes: tp.Set[int] = set() + axis_sizes.update(_get_axis_sizes(args, in_axes)) + axis_sizes.update(_get_axis_sizes(kwargs, in_axes_kwargs)) + for state, state_axis in zip(vectorized_states, state_axes.values()): + axis_sizes.update(_get_axis_sizes(state, state_axis)) + + if len(axis_sizes) > 1: + raise ValueError( + 'Inconsistent lengths between state_axes states and ' + f'arguments: {axis_sizes}' + ) + elif len(axis_sizes) == 0: + if axis_size is None: + raise ValueError( + 'Cannot infer length from state_axes states or axes_arg, ' + 'please specify `length`' + ) + _axis_size = axis_size + else: + _axis_size = axis_sizes.pop() + if axis_size is not None and axis_size != _axis_size: + raise ValueError( + f'Specified axis_size {axis_size} is not the same as the' + f' inferred length {_axis_size}' + ) + + split_keys, split_counts, broadcast_keys, broadcast_counts = rnglib.fork( + rng_state, + split_rngs, + _axis_size, + ) + + ( + graphdef_out, + broadcast_state, + vectorized_states, + split_keys_out, + out, + ) = pmapped_fn( + args, + kwargs, + graphdef, + split_keys, + split_counts, + broadcast_keys, + broadcast_counts, + vectorized_states, + broadcast_state, + PmapInputs( + transform_metadata=transform_metadata, + state_axes=state_axes, + f=f, + filters=filters, + split_rngs=split_rngs, + ), + ) + + _, output_graph_nodes = ctx.merge( + graphdef_out, + *vectorized_states, + broadcast_state, + split_keys_out, + ) + + out = extract.insert_graph_nodes(out, output_graph_nodes) + + rnglib.restore_rngs(input_rng_streams) + + return out + + return pmap_wrapper # type: ignore + + +class Pmap(tp.Generic[M], LiftedModule[M]): + @staticmethod + def constructor( + module_constructor: tp.Callable[..., MA], + *, + axis_name: AxisName | None = None, + in_axes: tp.Any = 0, + out_axes: tp.Any = 0, + static_broadcasted_argnums: int | tp.Iterable[int] = (), + devices: tp.Sequence[jax.Device] | None = None, # noqa: F811 + backend: str | None = None, + axis_size: int | None = None, + donate_argnums: int | tp.Iterable[int] = (), + global_arg_shapes: tuple[tuple[int, ...], ...] | None = None, + # nnx specific + in_axes_kwargs: tp.Any = 0, + state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}), + split_rngs: filterlib.Filter = ..., + transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), + ) -> tp.Callable[..., Pmap[MA]]: + def _create_pmap(*args, **kwargs): + return Pmap( + module_constructor=module_constructor, + axis_name=axis_name, + in_axes=in_axes, + out_axes=out_axes, + static_broadcasted_argnums=static_broadcasted_argnums, + devices=devices, + backend=backend, + axis_size=axis_size, + # nnx specific + in_axes_kwargs=in_axes_kwargs, + state_axes=state_axes, + split_rngs=split_rngs, + transform_metadata=transform_metadata, + # submodule args + module_init_args=args, + module_init_kwargs=kwargs, + ) + + return _create_pmap + + def __init__( + self, + module_constructor: tp.Callable[..., M], + *, + axis_name: AxisName | None = None, + in_axes: tp.Any = 0, + out_axes: tp.Any = 0, + static_broadcasted_argnums: int | tp.Iterable[int] = (), + devices: tp.Sequence[jax.Device] | None = None, # noqa: F811 + backend: str | None = None, + axis_size: int | None = None, + donate_argnums: int | tp.Iterable[int] = (), + global_arg_shapes: tuple[tuple[int, ...], ...] 
| None = None, + # nnx specific + in_axes_kwargs: tp.Any = 0, + state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}), + split_rngs: filterlib.Filter = ..., + transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), + # submodule args + module_init_args: tuple[tp.Any, ...], + module_init_kwargs: dict[str, tp.Any], + ): + self.module_constructor = module_constructor + + @pmap( + axis_name=axis_name, + in_axes=None, + out_axes=None, + static_broadcasted_argnums=static_broadcasted_argnums, + devices=devices, + backend=backend, + axis_size=axis_size, + donate_argnums=(), + global_arg_shapes=None, + in_axes_kwargs=None, + state_axes=state_axes, + split_rngs=split_rngs, + transform_metadata=transform_metadata, + ) + def pmap_init(*args, **kwargs): + return module_constructor(*args, **kwargs) + + self.pmap_module = pmap_init(*module_init_args, **module_init_kwargs) + + @pmap( + axis_name=axis_name, + in_axes=in_axes, + out_axes=out_axes, + static_broadcasted_argnums=static_broadcasted_argnums, + devices=devices, + backend=backend, + axis_size=axis_size, + donate_argnums=donate_argnums, + global_arg_shapes=global_arg_shapes, + in_axes_kwargs=in_axes_kwargs, + state_axes=state_axes, + split_rngs=split_rngs, + transform_metadata=transform_metadata, + ) + def pmap_call(module, *args, _nnx_vmap_accessor: DelayedAccessor, **kwargs): + method = _nnx_vmap_accessor(module) + return method(*args, **kwargs) + + self.pmap_call = pmap_call + + @property + def _submodule(self) -> M: + return self.pmap_module + + def _call(self, accessor: DelayedAccessor, *args, **kwargs): + return self.pmap_call( + self._submodule, *args, _nnx_vmap_accessor=accessor, **kwargs + ) + + +# ------------------------------- +# scan +# ------------------------------- + + +@dataclasses.dataclass(frozen=True) +class FlatDef(tp.Generic[A]): + type: type[A] + treedef: jax.tree_util.PyTreeDef + flat_axes: list[int | None] + + +jax.tree_util.register_static(FlatDef) + + +def _transpose_tree(tree: A, axes, /, *, move_front: bool) -> A: + flatdef, flat_transposes, _ = _transpose_and_split( + tree, axes, allow_none=False, move_front=move_front + ) + return flatdef.treedef.unflatten(flat_transposes) + + +def _transpose_and_split( + tree: A, axes, /, *, allow_none: bool = True, move_front: bool = True +) -> tuple[ + FlatDef[A], + list[jax.Array | None], + list[tp.Any], +]: + flat_axes: list[int | None] = broadcast_prefix( + axes, tree, is_leaf=lambda x: x is None + ) + flat_tree, treedef = jax.tree.flatten(tree) + + flat_broadcasts: list[tp.Any] = [] + flat_transposes: list[jax.Array | None] = [] + + for i, (axis, node) in enumerate(zip(flat_axes, flat_tree)): + if axis is None: + if not allow_none: + raise ValueError('None axis not allowed') + + flat_broadcasts.append(node) + flat_transposes.append(None) + else: + if not isinstance(node, jax.Array): + raise TypeError( + f'Expected a jax.Array, got {type(node).__name__} for axis {axis}' + ) + # normalize axis + if axis < 0: + if axis < -len(node.shape): + raise ValueError( + f'Axis {axis} out of bounds for array with shape {node.shape}' + ) + axis = len(node.shape) + axis + flat_axes[i] = axis + + if node.shape == (): + raise ValueError(f'Cannot map over a scalar array, got {node}') + elif axis >= len(node.shape): + raise ValueError( + f'Axis {axis} out of bounds for array with shape {node.shape}' + ) + + if move_front: + node = jnp.moveaxis(node, axis, 0) + else: + node = jnp.moveaxis(node, 0, axis) + flat_broadcasts.append(None) + flat_transposes.append(node) + + flatdef = 
FlatDef(type(tree), treedef, flat_axes) + + return flatdef, flat_transposes, flat_broadcasts + + +def _unflatten_splits( + flatdef: FlatDef[A], + flat_transposes: list[jax.Array | None], + flat_broadcasts: list[tp.Any] | None = None, + /, + *, + allow_none: bool = True, +) -> A: + flat_axes = flatdef.flat_axes + treedef = flatdef.treedef + if flat_broadcasts is None: + if allow_none: + raise ValueError('flat_broadcasts must be provided if allow_none is True') + flat_broadcasts = [None] * len(flat_axes) + + flat_tree = [] + for axis, transpose, broadcast in zip( + flat_axes, flat_transposes, flat_broadcasts + ): + if axis is None: + if not allow_none: + raise ValueError('None axis not allowed') + flat_tree.append(broadcast) + else: + if transpose is None: + raise ValueError('None transpose not allowed') + flat_tree.append(transpose) + + tree = treedef.unflatten(flat_tree) + return tree + + +def _extract_carry_arg( + args: tuple[tp.Any, ...], carry_argnum: int, / +) -> tuple[tp.Any, tuple[tp.Any, ...]]: + # extract carry arg + if len(args) < carry_argnum + 1: + raise TypeError( + f'Expected at least {carry_argnum + 1} positional arguments, ' + f'got {len(args)}' + ) + + args_ = list(args) + carry_arg = args_[carry_argnum] + args_[carry_argnum] = None + args = tuple(args_) + + return carry_arg, args + + +def _insert_carry_arg( + args: tuple[tp.Any, ...], carry_argnum: int, carry_arg: tp.Any, / +) -> tuple[tp.Any, ...]: + args_ = list(args) + args_[carry_argnum] = carry_arg + args = tuple(args_) + + return args + + +@struct.dataclass +class ScanBroadcasts(tp.Generic[C, B]): + flatdef: FlatDef[ + tuple[tuple[tp.Any, ...], dict[str, tp.Any], list[State]] + ] = struct.field(pytree_node=False) + flat_carry: list[tp.Any] = struct.field(pytree_node=True) + graphdef: GraphDef[tuple[tp.Any, ...]] = struct.field(pytree_node=False) + filters: tuple[filterlib.Filter, ...] 
= struct.field(pytree_node=False) + f: tp.Callable[..., tuple[C, B] | C] = struct.field(pytree_node=False) + # options + carry_argnum: int = struct.field(pytree_node=False) + state_axes: tp.Mapping[filterlib.Filter, int] = struct.field( + pytree_node=False + ) + split_rngs: filterlib.Filter = struct.field(pytree_node=False) + transform_metadata: tp.Mapping[str, tp.Any] = struct.field(pytree_node=False) + scan_output: bool = struct.field(pytree_node=False) + + +def scan_fn( + carry: tuple[ + State, # split_rng_state + State, # broadcast_rng_state + State, # carry_state + tp.Any, # carry_arg + ScanBroadcasts[C, B], # broadcasts + ], + scan: tuple[ + list[jax.Array | None], # flat_scan + ], +): + split_rng_state, broadcast_rng_state, carry_state, carry_arg, broadcasts = ( + carry + ) + (flat_scan,) = scan + flatdef = broadcasts.flatdef + flat_carry = broadcasts.flat_carry + graphdef, filters = broadcasts.graphdef, broadcasts.filters + f = broadcasts.f + ctx = graph.current_update_context('scan') + + # merge args and kwargs + args, kwargs, scan_states = _unflatten_splits(flatdef, flat_scan, flat_carry) + # remove metadata axis name from Variable.sharding + if spmd.PARTITION_NAME in broadcasts.transform_metadata: + scan_states = [ + spmd.remove_axis(state, index, broadcasts.transform_metadata) + for state, index in zip(scan_states, broadcasts.state_axes.values()) + ] + + # insert carry arg + args = _insert_carry_arg(args, broadcasts.carry_argnum, carry_arg) + + # merge module state + input_graph_nodes = ctx.merge( + graphdef, *scan_states, carry_state, split_rng_state, broadcast_rng_state + ) + (args, kwargs) = extract.insert_graph_nodes((args, kwargs), input_graph_nodes) + + out = f(*args, **kwargs) + + if broadcasts.scan_output: + if not isinstance(out, tuple) or len(out) != 2: + raise ValueError( + 'Expected a tuple of length 2 as the output of the scan function, ' + f'got {out}' + ) + out = tp.cast(tuple[C, B], out) # type: ignore[invalid-annotation] + carry_arg_out, scan_args_out = out + else: + out = tp.cast(C, out) # type: ignore[invalid-annotation] + carry_arg_out = out + scan_args_out = None + + ((carry_arg_out, scan_args_out), output_graph_nodes) = ( + extract.extract_graph_nodes((carry_arg_out, scan_args_out)) + ) + + # split module state + ( + graphdef_out, + rng_state_out, + *scan_states_out, + carry_state_out, + ) = ctx.split( # type: ignore[misc] + (input_graph_nodes, output_graph_nodes), + rnglib.RngState, + *filters, + ) + + split_rng_state_out, broadcast_rng_state_out = rng_state_out.split( + broadcasts.split_rngs, ... 
+ ) + + def _extract_carry_state(state: State, /): + if 1 in state: + raise ValueError( + f'Cannot add new carry state during scan, got {state[1]}' + ) + if 0 in state: + _state = state[0] + assert isinstance(_state, State) + state = _state + + return state + + carry_state_out = _extract_carry_state(carry_state_out) + split_rng_state_out = _extract_carry_state(split_rng_state_out) + broadcast_rng_state_out = _extract_carry_state(broadcast_rng_state_out) + + # override broadcast_rng_state_out to keep the same state + # for the next iteration + broadcast_rng_state_out = broadcast_rng_state + + # add metadata axis name to Variable.sharding + if spmd.PARTITION_NAME in broadcasts.transform_metadata: + scan_states_out = [ + spmd.add_axis(state, index, broadcasts.transform_metadata) + for state, index in zip(scan_states_out, broadcasts.state_axes.values()) + ] + + carry_out = ( + split_rng_state_out, + broadcast_rng_state_out, + carry_state_out, + carry_arg_out, + broadcasts, + ) + scan_out = (graphdef_out, scan_args_out, scan_states_out) + + return carry_out, scan_out + + +@tp.overload +def scan( + *, + length: int | None = None, + reverse: bool = False, + unroll: int | bool = 1, + _split_transpose: bool = False, + # extended api + in_axes: int | None | tp.Sequence[tp.Any] = 0, + in_axes_kwargs: tp.Any = 0, + out_axes: tp.Any = 0, + carry_argnum: int = 0, + # nnx specific + state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}), + split_rngs: filterlib.Filter = ..., + transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), + scan_output: bool = True, +) -> p.Callable[[F], F]: ... +@tp.overload +def scan( + f: F, + *, + length: int | None = None, + reverse: bool = False, + unroll: int | bool = 1, + _split_transpose: bool = False, + # extended api + in_axes: int | None | tp.Sequence[tp.Any] = 0, + in_axes_kwargs: tp.Any = 0, + out_axes: tp.Any = 0, + carry_argnum: int = 0, + # nnx specific + state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}), + split_rngs: filterlib.Filter = ..., + transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), + scan_output: bool = True, +) -> F: ... +def scan( + f: F | Missing = MISSING, + *, + length: int | None = None, + reverse: bool = False, + unroll: int | bool = 1, + _split_transpose: bool = False, + # extended api + in_axes: int | None | tp.Sequence[tp.Any] = 0, + in_axes_kwargs: tp.Any = 0, + out_axes: tp.Any = 0, + carry_argnum: int = 0, + # nnx specific + state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}), + split_rngs: filterlib.Filter = ..., + transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), + scan_output: bool = True, +) -> F | tp.Callable[[F], F]: + if isinstance(f, Missing): + return functools.partial( + scan, length=length, reverse=reverse, unroll=unroll + ) + + @functools.wraps(f) + @graph.update_context('scan') + def scan_apply_wrapper(*args, **kwargs): + # extract nodes + (args, kwargs), input_graph_nodes = extract.extract_graph_nodes( + (args, kwargs) + ) + input_rng_streams = rnglib.backup_keys(input_graph_nodes) + + # extract carry arg + carry_arg, args = _extract_carry_arg(args, carry_argnum) + + ctx = graph.current_update_context('scan') + # split module state + filters = (*state_axes.keys(), ...) 
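+ # As in vmap/pmap above, the split below separates the graphdef, the RNG
+ # state, the scanned substates selected by `state_axes`, and the carried
+ # remainder; the scanned inputs are then transposed so the scan axis leads,
+ # and the scan length is inferred from their leading dimensions.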
+ graphdef, rng_state, *scan_states, carry_state = ctx.split( # type: ignore[misc] + input_graph_nodes, rnglib.RngState, *filters + ) + + # transpose axes arg + flatdef, flat_scan, flat_carry = _transpose_and_split( + (args, kwargs, scan_states), + (in_axes, in_axes_kwargs, list(state_axes.values())), + ) + + # infer length + lengths: set[int] = { + x.shape[0] # type: ignore + for x, axis in zip(flat_scan, flatdef.flat_axes) + if axis is not None + } + + if len(lengths) > 1: + raise ValueError( + 'Inconsistent lengths between state_axes states and ' + f'arguments: {lengths}' + ) + elif len(lengths) == 0: + if length is None: + raise ValueError( + 'Cannot infer length from state_axes states or axes_arg, ' + 'please specify `length`' + ) + infered_length = length + else: + infered_length = lengths.pop() + if length is not None and length != infered_length: + raise ValueError( + f'Specified length {length} is not the same as the inferred ' + f'length {infered_length}' + ) + + # split rng state + split_rng_state, broadcast_rng_state = rng_state.split(split_rngs, ...) + + broadcasts = ScanBroadcasts( + flatdef, + flat_carry, + graphdef, + filters, + f, + # options + carry_argnum, + state_axes, + split_rngs, + transform_metadata, + scan_output, + ) + carry = ( + split_rng_state, + broadcast_rng_state, + carry_state, + carry_arg, + broadcasts, + ) + scan = (flat_scan,) + + carry_out, scan_out = jax.lax.scan( + scan_fn, + carry, + scan, + length=infered_length, + reverse=reverse, + unroll=unroll, + _split_transpose=_split_transpose, + ) + ( + split_rng_state_out, + broadcast_rng_state_out, + carry_state_out, + carry_arg_out, + broadcasts, + ) = carry_out + graphdef_out, scan_args_out, scan_states_out = scan_out + + scan_args_out, scan_states_out = _transpose_tree( + (scan_args_out, scan_states_out), + (out_axes, list(state_axes.values())), + move_front=False, + ) + + if carry_state_out: + carry_state_out = State({0: carry_state_out._mapping}) + if split_rng_state_out: + split_rng_state_out = State({0: split_rng_state_out._mapping}) + if broadcast_rng_state_out: + broadcast_rng_state_out = State({0: broadcast_rng_state_out._mapping}) + + _, output_graph_nodes = ctx.merge( + graphdef_out, + *scan_states_out, + carry_state_out, + split_rng_state_out, + broadcast_rng_state_out, + ) + + carry_arg_out, scan_args_out = extract.insert_graph_nodes( + (carry_arg_out, scan_args_out), output_graph_nodes + ) + + rnglib.restore_rngs(input_rng_streams) + + if scan_output: + scan_args_out = tp.cast(B, scan_args_out) + return carry_arg_out, scan_args_out + else: + return carry_arg_out + + return scan_apply_wrapper # type: ignore + + +class Scan(tp.Generic[M], LiftedModule[M]): + @staticmethod + def constructor( + module_constructor: tp.Callable[..., MA], + *, + length: int | None = None, + reverse: bool = False, + unroll: int | bool = 1, + _split_transpose: bool = False, + # extended api + in_axes: int | None | tp.Sequence[tp.Any] = 0, + in_axes_kwargs: tp.Any = 0, + out_axes: tp.Any = 0, + carry_argnum: int = 1, + # nnx specific + state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}), + split_rngs: filterlib.Filter = ..., + transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), + scan_output: bool = True, + ) -> tp.Callable[..., Scan[MA]]: + def _create_scan(*args, **kwargs): + return Scan( + module_constructor=module_constructor, + module_init_args=args, + module_init_kwargs=kwargs, + # base api + length=length, + reverse=reverse, + unroll=unroll, + _split_transpose=_split_transpose, + # 
extended api + in_axes=in_axes, + in_axes_kwargs=in_axes_kwargs, + out_axes=out_axes, + carry_argnum=carry_argnum, + # nnx specific + state_axes=state_axes, + split_rngs=split_rngs, + transform_metadata=transform_metadata, + scan_output=scan_output, + ) + + return _create_scan + + def __init__( + self, + module_constructor: tp.Callable[..., M], + *, + length: int | None = None, + reverse: bool = False, + unroll: int | bool = 1, + _split_transpose: bool = False, + # extended api + in_axes: int | None | tp.Sequence[tp.Any] = 0, + in_axes_kwargs: tp.Any = 0, + out_axes: tp.Any = 0, + carry_argnum: int = 1, + # nnx specific + state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}), + split_rngs: filterlib.Filter = ..., + transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), + scan_output: bool = True, + # submodule args + module_init_args: tuple[tp.Any, ...], + module_init_kwargs: dict[str, tp.Any], + ): + self.module_constructor = module_constructor + # use Vmap to handle initialisation + vmapped_module = Vmap.constructor( + module_constructor, + in_axes=in_axes, + out_axes=None, + axis_name=None, + axis_size=length, + spmd_axis_name=None, + state_axes=state_axes, + split_rngs=split_rngs, + in_axes_kwargs=in_axes_kwargs, + transform_metadata=transform_metadata, + )(*module_init_args, **module_init_kwargs) + self.scan_module = vmapped_module.vmap_module + + @functools.partial( + scan, + length=length, + reverse=reverse, + unroll=unroll, + _split_transpose=_split_transpose, + in_axes=in_axes, + in_axes_kwargs=in_axes_kwargs, + out_axes=out_axes, + carry_argnum=carry_argnum, + state_axes=state_axes, + split_rngs=split_rngs, + transform_metadata=transform_metadata, + scan_output=scan_output, + ) + def scan_call(module, *args, _nnx_scan_accessor: DelayedAccessor, **kwargs): + method = _nnx_scan_accessor(module) + return method(*args, **kwargs) + + self.scan_call = scan_call + + @property + def _submodule(self) -> M: + return self.scan_module + + def _call( + self, accessor: DelayedAccessor, *args, **kwargs + ) -> tuple[tp.Any, tp.Any]: + return self.scan_call( + self._submodule, *args, _nnx_scan_accessor=accessor, **kwargs + ) + + +# ------------------------------- +# remat +# ------------------------------- + + +class Remat(tp.Generic[M], LiftedModule[M]): + @staticmethod + def constructor( + module_constructor: tp.Callable[..., MA], + prevent_cse: bool = True, + static_argnums: int | tuple[int, ...] = (), + policy: tp.Callable[..., bool] | None = None, + ) -> tp.Callable[..., Remat[MA]]: + def create_remat(*args, **kwargs): + return Remat( + module_constructor=module_constructor, + module_init_args=args, + module_init_kwargs=kwargs, + prevent_cse=prevent_cse, + static_argnums=static_argnums, + policy=policy, + ) + + return create_remat + + def __init__( + self, + *, + module_constructor: tp.Callable[..., M], + prevent_cse: bool = True, + static_argnums: int | tuple[int, ...] 
= (), + policy: tp.Callable[..., bool] | None = None, + # submodule args + module_init_args: tuple[tp.Any, ...], + module_init_kwargs: dict[str, tp.Any], + ): + self.module_constructor = module_constructor + self.remat_module = self.module_constructor( + *module_init_args, **module_init_kwargs + ) + + @nnx.remat( + prevent_cse=prevent_cse, static_argnums=static_argnums, policy=policy + ) + def remat_call(module, *args): + accessor: DelayedAccessor + *args, accessor = args + method = accessor(module) + return method(*args) + + self.rem_call = remat_call + + @property + def _submodule(self) -> M: + return self.remat_module + + def _call(self, accessor: DelayedAccessor, *args) -> tp.Any: + return self.rem_call(self._submodule, *args, accessor) + + +# ------------------------------- +# grad +# ------------------------------- + + +def grad_fn(*args): + f: tp.Callable[..., tp.Any] + graphdef: GraphDef[tuple[dict[int, tp.Any], tuple[tp.Any, ...]]] + non_diff_state: State + has_aux: bool + diff_args: list[int] + ctx = graph.current_update_context('grad') + *args, f, graphdef, non_diff_state, has_aux, diff_args = args + + # rebuild diff_state from substates in args + diff_state = State({}) + for i in diff_args: + diff_state[i] = args[i] + diff_state: graph.GraphState = State({0: diff_state.raw_mapping}) + + diff_graph_nodes, input_nodes = ctx.merge( + graphdef, diff_state, non_diff_state + ) + + # add nodes to the args + for i, arg in diff_graph_nodes.items(): + args[i] = arg + + # add other nodes to the args + args = extract.insert_graph_nodes(args, input_nodes) + + out = f(*args) + + out, out_nodes = extract.extract_graph_nodes(out) + + graphdef_out, state_out = ctx.split((input_nodes, out_nodes)) + + if has_aux: + loss, aux = out + out = (loss, (graphdef_out, state_out, aux)) + else: + out = (out, (graphdef_out, state_out)) + + return out + + +def _grad_general( + f: tp.Callable[..., tp.Any], + argnums: int | tp.Sequence[int], + has_aux: bool, + holomorphic: bool, + allow_int: bool, + reduce_axes: tp.Sequence[AxisName], + wrt: filterlib.Filter, + return_value: bool, +) -> tp.Callable[..., tp.Any]: + @graph.update_context('grad') + def grad_wrapper(*args): + ctx: graph.UpdateContext = graph.current_update_context('grad') + _argnums = _normalize_sequence(argnums) + diff_graph_nodes: dict[int, tp.Any] = { + i: arg + for i, arg in enumerate(args) + if i in _argnums and graph.is_node(arg) + } + args, input_nodes = extract.extract_graph_nodes(args) + args = list(args) + + def only_diff(path: tuple, value: tp.Any) -> bool: + # diff_graph_nodes is the first element in the tuple + return path[0] == 0 + + graphdef, diff_state, non_diff_state = ctx.split( + (diff_graph_nodes, input_nodes), filterlib.All(wrt, only_diff), ... 
+ ) # type: ignore[misc] + + # extract diff_state substates into the args + diff_args: list[int] = [] + if 0 in diff_state: + for i, diff_substate in diff_state[0].items(): # type: ignore + assert isinstance(i, int) + args[i] = diff_substate + diff_args.append(i) + transform = jax.value_and_grad if return_value else jax.grad + + _argnums = _argnums[0] if len(_argnums) == 1 else _argnums + + out = transform( + grad_fn, + argnums=_argnums, + has_aux=True, + holomorphic=holomorphic, + allow_int=allow_int, + reduce_axes=reduce_axes, + )(*args, f, graphdef, non_diff_state, has_aux, diff_args) + + if return_value: + if has_aux: + (loss, (graphdef_out, state_out, aux)), grads = out + out = (loss, aux), grads + else: + (loss, (graphdef_out, state_out)), grads = out + out = loss, grads + else: + if has_aux: + grads, (graphdef_out, state_out, aux) = out + out = grads, aux + else: + out, (graphdef_out, state_out) = out + + input_nodes, out_nodes = ctx.merge(graphdef_out, state_out) + + out = extract.insert_graph_nodes(out, out_nodes) + return out + + return grad_wrapper + + +def grad( + f: tp.Callable[..., tp.Any], + argnums: int | tp.Sequence[int] = 0, + has_aux: bool = False, + holomorphic: bool = False, + allow_int: bool = False, + reduce_axes: tp.Sequence[AxisName] = (), + *, + wrt: filterlib.Filter = variables.Param, +) -> tp.Callable[..., tp.Any]: + """Lifted version of ``jax.grad`` that can handle Modules / graph nodes as + arguments. + + The differentiable state of each graph node is defined by the `wrt` filter, + which by default is set to `nnx.Param`. Internally the ``State`` of + graph nodes is extracted, filtered according to `wrt` filter, and + passed to the underlying ``jax.grad`` function. The gradients + of graph nodes are of type ``State``. + + Example:: + + >>> from flax import nnx + >>> import jax.numpy as jnp + ... + >>> m = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) + >>> x = jnp.ones((1, 2)) + >>> y = jnp.ones((1, 3)) + ... + >>> loss_fn = lambda m, x, y: jnp.mean((m(x) - y) ** 2) + >>> grad_fn = nnx.grad(loss_fn, wrt=nnx.Param) + ... + >>> grads = grad_fn(m, x, y) + >>> jax.tree.map(jnp.shape, grads) + State({ + 'bias': VariableState( + type=Param, + value=(3,) + ), + 'kernel': VariableState( + type=Param, + value=(2, 3) + ) + }) + + Args: + fun: Function to be differentiated. Its arguments at positions specified by + ``argnums`` should be arrays, scalars, graph nodes or standard Python + containers. Argument arrays in the positions specified by ``argnums`` must + be of inexact (i.e., floating-point or complex) type. It should return a + scalar (which includes arrays with shape ``()`` but not arrays with shape + ``(1,)`` etc.) + argnums: Optional, integer or sequence of integers. Specifies which + positional argument(s) to differentiate with respect to (default 0). + has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the + first element is considered the output of the mathematical function to be + differentiated and the second element is auxiliary data. Default False. + holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be + holomorphic. If True, inputs and outputs must be complex. Default False. + allow_int: Optional, bool. Whether to allow differentiating with + respect to integer valued inputs. The gradient of an integer input will + have a trivial vector-space dtype (float0). Default False. + reduce_axes: Optional, tuple of axis names. 
If an axis is listed here, and + ``fun`` implicitly broadcasts a value over that axis, the backward pass + will perform a ``psum`` of the corresponding gradient. Otherwise, the + gradient will be per-example over named axes. For example, if ``'batch'`` + is a named batch axis, ``grad(f, reduce_axes=('batch',))`` will create a + function that computes the total gradient while ``grad(f)`` will create + one that computes the per-example gradient. + wrt: Optional, filterlib.Filter. Filter to extract the differentiable state + of each graph node. Default is `nnx.Param`. + + """ + + return _grad_general( + f, + argnums, + has_aux, + holomorphic, + allow_int, + reduce_axes, + wrt, + return_value=False, + ) + + +def value_and_grad( + f: tp.Callable[..., tp.Any], + argnums: int | tp.Sequence[int] = 0, + has_aux: bool = False, + holomorphic: bool = False, + allow_int: bool = False, + reduce_axes: tp.Sequence[AxisName] = (), + *, + wrt: filterlib.Filter = variables.Param, +) -> tp.Callable[..., tp.Any]: + return _grad_general( + f, + argnums, + has_aux, + holomorphic, + allow_int, + reduce_axes, + wrt, + return_value=True, + ) + + +class Grad(tp.Generic[M], LiftedModule[M]): + @staticmethod + def constructor( + module_constructor: tp.Callable[..., MA], + has_aux: bool = False, + holomorphic: bool = False, + allow_int: bool = False, + reduce_axes: tp.Sequence[AxisName] = (), + return_value: bool = False, + *, + wrt: filterlib.Filter = variables.Param, + ) -> tp.Callable[..., Grad[MA]]: + def _create_grad(*args, **kwargs): + return Grad( + module_constructor=module_constructor, + wrt=wrt, + has_aux=has_aux, + holomorphic=holomorphic, + allow_int=allow_int, + reduce_axes=reduce_axes, + return_value=return_value, + # submodule args + module_init_args=args, + module_init_kwargs=kwargs, + ) + + return _create_grad + + def __init__( + self, + module_constructor: tp.Callable[..., M], + argnums: int | tp.Sequence[int] = 0, + has_aux: bool = False, + holomorphic: bool = False, + allow_int: bool = False, + reduce_axes: tp.Sequence[AxisName] = (), + *, + wrt: filterlib.Filter = variables.Param, + # submodule args + module_init_args: tuple[tp.Any, ...], + module_init_kwargs: dict[str, tp.Any], + ): + self.module_constructor = module_constructor + self.grad_module = self.module_constructor( + *module_init_args, **module_init_kwargs + ) + + @functools.partial( + grad, + argnums=argnums, + has_aux=has_aux, + holomorphic=holomorphic, + allow_int=allow_int, + reduce_axes=reduce_axes, + wrt=wrt, + ) + def grad_call_apply(module, *args): + *args, accessor = args + method = accessor(module) + return method(*args) + + self.grad_apply = grad_call_apply + + @property + def _submodule(self) -> M: + return self.grad_module + + def _call(self, accessor: DelayedAccessor, *args) -> tp.Any: + return self.grad_apply(self.grad_module, *args, accessor) + + +# ------------------------------- +# jit +# ------------------------------- + + +class Jit(tp.Generic[M], LiftedModule[M]): + @staticmethod + def constructor( + module_constructor: tp.Callable[..., MA], + *, + in_shardings: tp.Any = None, + out_shardings: tp.Any = None, + static_argnums: int | tp.Sequence[int] | None = None, + static_argnames: str | tp.Iterable[str] | None = None, + donate_argnums: int | tp.Sequence[int] | None = None, + donate_argnames: str | tp.Iterable[str] | None = None, + keep_unused: bool = False, + device: tp.Optional[jax.Device] = None, + backend: tp.Optional[str] = None, + inline: bool = False, + abstracted_axes: tp.Optional[tp.Any] = None, + ) -> 
tp.Callable[..., Jit[MA]]: + def _create_jit(*args, **kwargs): + return Jit( + module_constructor=module_constructor, + in_shardings=in_shardings, + out_shardings=out_shardings, + static_argnums=static_argnums, + static_argnames=static_argnames, + donate_argnums=donate_argnums, + donate_argnames=donate_argnames, + keep_unused=keep_unused, + device=device, + backend=backend, + inline=inline, + abstracted_axes=abstracted_axes, + # submodule args + module_init_args=args, + module_init_kwargs=kwargs, + ) + + return _create_jit + + def __init__( + self, + module_constructor: tp.Callable[..., M], + *, + in_shardings: tp.Any = None, + out_shardings: tp.Any = None, + static_argnums: int | tp.Sequence[int] | None = None, + static_argnames: str | tp.Iterable[str] | None = None, + donate_argnums: int | tp.Sequence[int] | None = None, + donate_argnames: str | tp.Iterable[str] | None = None, + keep_unused: bool = False, + device: tp.Optional[jax.Device] = None, + backend: tp.Optional[str] = None, + inline: bool = False, + abstracted_axes: tp.Optional[tp.Any] = None, + # submodule args + module_init_args: tuple[tp.Any, ...], + module_init_kwargs: dict[str, tp.Any], + ): + @functools.partial( + nnx.jit, + in_shardings=in_shardings, + out_shardings=out_shardings, + static_argnums=static_argnums, + static_argnames=static_argnames, + donate_argnums=donate_argnums, + donate_argnames=donate_argnames, + keep_unused=keep_unused, + device=device, + backend=backend, + inline=inline, + abstracted_axes=abstracted_axes, + ) + def jit_call_module( + module, *args, _nnx_jit_accessor: DelayedAccessor, **kwargs + ): + method = _nnx_jit_accessor(module) + return method(*args, **kwargs) + + self.jitted_fn = jit_call_module + self.module_constructor = module_constructor + self.jit_module = self.module_constructor( + *module_init_args, **module_init_kwargs + ) + + @property + def _submodule(self) -> M: + return self.jit_module + + def _call(self, accessor: DelayedAccessor, *args, **kwargs) -> tp.Any: + out = self.jitted_fn( + self.jit_module, *args, _nnx_jit_accessor=accessor, **kwargs + ) + return out diff --git a/flax/nnx/nnx/transforms/general.py b/flax/nnx/nnx/transforms/general.py index f874619cf8..59c4d0e287 100644 --- a/flax/nnx/nnx/transforms/general.py +++ b/flax/nnx/nnx/transforms/general.py @@ -29,13 +29,10 @@ import functools import typing as tp -from flax import struct from flax.nnx.nnx import ( extract, graph, ) -from flax.nnx.nnx.module import GraphDef -from flax.nnx.nnx.state import State A = tp.TypeVar('A') F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any]) @@ -51,33 +48,21 @@ class Missing: # ------------------------------- -class ArgState(extract.ExtractionIndex, extract.ExtractableStates): - _graphdef: GraphDef[tp.Any] = struct.field(pytree_node=False) - state: State = struct.field(pytree_node=True) - - @property - def graphdef(self) -> GraphDef[tp.Any]: - return self._graphdef - - @property - def states(self) -> tp.Iterable[State]: - yield self.state - @tp.overload def split_inputs( *, - ctx_tag: str = 'split_merge_inputs', + ctxtag: str = 'split_merge_inputs', ) -> tp.Callable[[F], F]: ... @tp.overload def split_inputs( f: F, *, - ctx_tag: str = 'split_merge_inputs', + ctxtag: str = 'split_merge_inputs', ) -> F: ... 
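# A minimal usage sketch of the pair defined below (hypothetical `model` and
# input; the idea is to sandwich a JAX transform between the two wrappers so
# graph nodes cross it as jax-compatible pytrees and are restored afterwards):
#
#   @split_inputs
#   @jax.jit
#   @merge_inputs
#   def forward(model, x):
#     return model(x)
#
#   y = forward(model, jnp.ones((1, 2)))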
def split_inputs( f: F | Missing = MISSING, *, - ctx_tag: str = 'split_merge_inputs', + ctxtag: str = 'split_merge_inputs', ) -> F | tp.Callable[[F], F]: """Takes in a function that contains graph nodes in the inputs and outputs, and returns a function that replaces the graph nodes with some jax-compatible data @@ -85,7 +70,7 @@ def split_inputs( Args: f: The function to be transformed. - ctx_tag: The context tag to be used for the transformation. Defaults to + ctxtag: The context tag to be used for the transformation. Defaults to 'split_merge_inputs'. Returns: @@ -178,32 +163,14 @@ def split_inputs( `Functional API `__. """ if isinstance(f, Missing): - return functools.partial(split_inputs, ctx_tag=ctx_tag) # type: ignore[return-value] + return functools.partial(split_inputs, ctxtag=ctxtag) # type: ignore[return-value] - @graph.update_context(ctx_tag) + @graph.update_context(ctxtag) @functools.wraps(f) def split_inputs_wrapper(*args): - ctx = graph.current_update_context(ctx_tag) - args, input_graph_nodes = extract.extract_graph_nodes(args) - graphdef, states = ctx.split(input_graph_nodes) - args = extract.replace_indexes( - args, - lambda x: ArgState( - x.index, - graphdef, - states[x.index], # type: ignore - ), - ) - args_out, out = f(*args) - arg_states_out = extract.extract_indexes((args_out, out), types=ArgState) - - if arg_states_out: - graphdef_out, states_out = extract.merge_extractable_states( - arg_states_out - ) - output_nodes = ctx.merge(graphdef_out, states_out) - out = extract.insert_graph_nodes(out, output_nodes) - + pure_args = extract.to_tree(args, ctxtag=ctxtag) + pure_args_out, pure_out = f(*pure_args) + args_out, out = extract.from_tree((pure_args_out, pure_out), ctxtag=ctxtag) return out return split_inputs_wrapper # type: ignore @@ -211,18 +178,18 @@ def split_inputs_wrapper(*args): @tp.overload def merge_inputs( *, - ctx_tag: str = 'split_merge_inputs', + ctxtag: str = 'split_merge_inputs', ) -> tp.Callable[[F], F]: ... @tp.overload def merge_inputs( f: F, *, - ctx_tag: str = 'split_merge_inputs', + ctxtag: str = 'split_merge_inputs', ) -> F: ... def merge_inputs( f: F | Missing = MISSING, *, - ctx_tag: str = 'split_merge_inputs', + ctxtag: str = 'split_merge_inputs', ) -> F | tp.Callable[[F], F]: """Takes in a function that contains jax-compatible data structures in the inputs and outputs, and returns a function that replaces the jax-compatible @@ -231,7 +198,7 @@ def merge_inputs( Args: f: The function to be transformed. - ctx_tag: The context tag to be used for the transformation. Defaults to + ctxtag: The context tag to be used for the transformation. Defaults to 'split_merge_inputs'. Returns: @@ -240,36 +207,14 @@ def merge_inputs( For more information and examples, see :func:`split_inputs`. 
""" if isinstance(f, Missing): - return functools.partial(merge_inputs, ctx_tag=ctx_tag) # type: ignore[return-value] + return functools.partial(merge_inputs, ctxtag=ctxtag) # type: ignore[return-value] @functools.wraps(f) - def merge_inputs_wrapper(*args): - ctx = graph.current_update_context(ctx_tag) - arg_states = extract.extract_indexes(args, types=ArgState) - - if arg_states: - graphdef, states = extract.merge_extractable_states(arg_states) - inputs_graph_nodes = ctx.merge(graphdef, states) - args = extract.insert_graph_nodes(args, inputs_graph_nodes) - + def merge_inputs_wrapper(*pure_args): + args = extract.from_tree(pure_args, ctxtag=ctxtag) out = f(*args) - - (args_out, out), output_graph_nodes = extract.extract_graph_nodes( - (args, out) - ) - - graphdef_out, states_out = ctx.split(output_graph_nodes) - - def replace_index(x: extract.Extractable): - return ArgState( - x.index, - graphdef_out, - states_out[x.index], # type: ignore - ) - - out = extract.replace_indexes(out, replace_index) - args_out = extract.replace_indexes(args_out, replace_index, clear=True) - - return args_out, out + args_out = extract.clear_non_graph_nodes(args) + pure_args_out, pure_out = extract.to_tree((args_out, out), ctxtag=ctxtag) + return pure_args_out, pure_out return merge_inputs_wrapper # type: ignore diff --git a/flax/nnx/nnx/transforms/iteration.py b/flax/nnx/nnx/transforms/iteration.py new file mode 100644 index 0000000000..e7dc03c6e4 --- /dev/null +++ b/flax/nnx/nnx/transforms/iteration.py @@ -0,0 +1,1162 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# pytype: skip-file + +from collections import deque +import dataclasses +import functools +import typing as tp + +import numpy as np + +from flax import struct +from flax.core.frozen_dict import FrozenDict +from flax.nnx.nnx import extract, filterlib, graph, spmd +from flax.nnx.nnx.module import Module +from flax.nnx.nnx.state import State +from flax.nnx.nnx.transforms.transforms import resolve_kwargs +from flax.typing import Leaf, PytreeDeque +import jax +from jax._src.tree_util import broadcast_prefix +import jax.core +import jax.numpy as jnp +import jax.stages + +A = tp.TypeVar('A') +C = tp.TypeVar('C') +B = tp.TypeVar('B') +F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any]) +G = tp.TypeVar('G', bound=tp.Callable[..., tp.Any]) +M = tp.TypeVar('M', bound=Module) +MA = tp.TypeVar('MA', bound=Module) +N = tp.TypeVar('N', bound=Module) +StrInt = tp.TypeVar('StrInt', str, int) +AxisName = tp.Hashable +Leaves = tp.List[Leaf] +Index = int + + +class Missing: + pass + + +MISSING = Missing() + +class Carry: + pass + + +# ------------------------------- +# vmap +# ------------------------------- + + +class StateAxes: + def __init__( + self, + filter_axes: tp.Mapping[filterlib.Filter, Index | type[Carry] | None] + | tp.Iterable[tuple[filterlib.Filter, Index | type[Carry] | None]], + /, + ): + iterable = tuple( + filter_axes.items() + if isinstance(filter_axes, tp.Mapping) + else filter_axes + ) + self._filters = tuple(filter for filter, _ in iterable) + self._axes = tuple(axis for _, axis in iterable) + + @property + def filters(self) -> tuple[filterlib.Filter, ...]: + return self._filters + + @property + def axes(self) -> tuple[Index | type[Carry] | None, ...]: + return self._axes + + def __repr__(self): + return f'StateAxes({dict(zip(self.filters, self.axes))})' + + def __eq__(self, other): + return ( + isinstance(other, StateAxes) + and self.filters == other.filters + and self.axes == other.axes + ) + + def __hash__(self): + return hash((self.filters, self.axes)) + + +AxisFn = tp.Callable[ + [extract.GraphDefState, int, tp.Mapping], extract.GraphDefState +] + + +def _update_variable_sharding_metadata( + tree, transform_metadata, axis_fn: AxisFn +): + def _update_axes_fn(tree_node): + if isinstance(tree_node, extract.TreeNode) and isinstance( + tree_node.metatata, StateAxes + ): + graphdef_states_out: list[extract.GraphDefState] = [] + for graphdef_state, axis in zip( + tree_node.graphdef_states, tree_node.metatata.axes + ): + assert isinstance(graphdef_state, extract.GraphDefState) + if isinstance(axis, int): + graphdef_state = axis_fn(graphdef_state, axis, transform_metadata) + graphdef_states_out.append(graphdef_state) + return tree_node.replace(graphdef_states=tuple(graphdef_states_out)) + return tree_node + + return jax.tree.map( + _update_axes_fn, tree, is_leaf=lambda x: isinstance(x, extract.TreeNode) + ) + + +def _vmap_split_fn(ctx: graph.SplitContext, path, prefix, x): + if isinstance(prefix, StateAxes): + return extract.TreeNode.from_split( + *ctx.split(x, *prefix.filters), metadata=prefix + ) + return extract.TreeNode.from_split(*ctx.split(x)) + + +@dataclasses.dataclass(eq=False) +class VmapFn: + f: tp.Callable[..., tp.Any] + transform_metadata: tp.Mapping[str, tp.Any] + in_axes: tp.Any + out_axes: tp.Any + + def __post_init__(self): + functools.update_wrapper(self, self.f) + + def __call__(self, *pure_args: tuple[tp.Any, ...]): + if spmd.PARTITION_NAME in self.transform_metadata: + pure_args = _update_variable_sharding_metadata( + pure_args, self.transform_metadata, 
spmd.remove_axis + ) + args = extract.from_tree(pure_args, ctxtag='vmap') + + out = self.f(*args) + + args_out = extract.clear_non_graph_nodes(args) + pure_args_out, pure_out = extract.to_tree( + (args_out, out), + prefix=(self.in_axes, self.out_axes), + split_fn=_vmap_split_fn, + ctxtag='vmap', + ) + if spmd.PARTITION_NAME in self.transform_metadata: + pure_args_out, pure_out = _update_variable_sharding_metadata( + (pure_args_out, pure_out), self.transform_metadata, spmd.add_axis + ) + return pure_args_out, pure_out + + +@tp.overload +def vmap( + *, + in_axes: int | None | tp.Sequence[tp.Any] = 0, + out_axes: tp.Any = 0, + axis_name: AxisName | None = None, + axis_size: int | None = None, + spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None, + # nnx specific + transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), +) -> tp.Callable[[F], F]: ... +@tp.overload +def vmap( + f: F, + *, + in_axes: int | None | tp.Sequence[tp.Any] = 0, + out_axes: tp.Any = 0, + axis_name: AxisName | None = None, + axis_size: int | None = None, + spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None, + # nnx specific + transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), +) -> F: ... +def vmap( + f: F | Missing = MISSING, + *, + in_axes: int | None | tp.Sequence[tp.Any] = 0, + out_axes: tp.Any = 0, + axis_name: AxisName | None = None, + axis_size: int | None = None, + spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None, + # nnx specific + transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), +) -> F | tp.Callable[[F], F]: + """Reference-aware version of `jax.vmap `__. + + Args: + f: Function to be mapped over additional axes. + in_axes: An integer, None, or sequence of values specifying which input + array axes to map over (see `jax.vmap `__). + In addition to integers and None, :class:`StateAxes` can be used to control how + graph nodes like Modules are vectorized by specifying the axes to be + applied to substates of the graph node given a `Filter `__. + out_axes: An integer, None, or pytree indicating where the mapped axis should appear + in the output (see `jax.vmap `__). + axis_name: Optional, a hashable Python object used to identify the mapped + axis so that parallel collectives can be applied. + axis_size: Optional, an integer indicating the size of the axis to be + mapped. If not provided, the mapped axis size is inferred from arguments. + + Returns: + Batched/vectorized version of ``f`` with arguments that correspond to + those of ``f``, but with extra array axes at positions indicated by + ``in_axes``, and a return value that corresponds to that of ``f``, but + with extra array axes at positions indicated by ``out_axes``. + + Example:: + + >>> from flax import nnx + >>> from jax import random, numpy as jnp + ... + >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) + >>> x = jnp.ones((5, 2)) + ... + >>> @nnx.vmap(in_axes=(None, 0), out_axes=0) + ... def forward(model, x): + ... return model(x) + ... + >>> y = forward(model, x) + >>> y.shape + (5, 3) + + >>> class LinearEnsemble(nnx.Module): + ... def __init__(self, num, rngs): + ... self.w = nnx.Param(jax.random.uniform(rngs(), (num, 2, 3))) + ... + >>> model = LinearEnsemble(5, rngs=nnx.Rngs(0)) + >>> x = jnp.ones((2,)) + ... + >>> @nnx.vmap(in_axes=(0, None), out_axes=0) + ... def forward(model, x): + ... return jnp.dot(x, model.w.value) + ... 
+ >>> y = forward(model, x) + >>> y.shape + (5, 3) + + To control control how graph node substates are vectorized, ``StateAxes`` + can be passed to ``in_axes`` and ``out_axes`` specifying the axes to be + applied to each substate given a filter. The following example shows how to + share the parameters between the ensemble members which keeping different + batch statistics and dropout random state:: + + >>> class Foo(nnx.Module): + ... def __init__(self): + ... self.a = nnx.Param(jnp.arange(4)) + ... self.b = nnx.BatchStat(jnp.arange(4)) + ... + >>> state_axes = nnx.StateAxes({nnx.Param: 0, nnx.BatchStat: None}) + >>> @nnx.vmap(in_axes=(state_axes,), out_axes=0) + ... def mul(foo): + ... return foo.a * foo.b + ... + >>> foo = Foo() + >>> y = mul(foo) + >>> y + Array([[0, 0, 0, 0], + [0, 1, 2, 3], + [0, 2, 4, 6], + [0, 3, 6, 9]], dtype=int32) + """ + if isinstance(f, Missing): + return functools.partial( + vmap, + in_axes=in_axes, + out_axes=out_axes, + axis_name=axis_name, + axis_size=axis_size, + spmd_axis_name=spmd_axis_name, + transform_metadata=transform_metadata, + ) + + jax_in_axes = jax.tree.map( + lambda x: extract.TreeNode.from_prefixes(x.axes, metadata=x) + if isinstance(x, StateAxes) + else x, + in_axes, + ) + jax_out_axes = jax.tree.map( + lambda x: extract.TreeNode.from_prefixes(x.axes, metadata=x) + if isinstance(x, StateAxes) + else x, + out_axes, + ) + vmapped_fn = jax.vmap( + VmapFn(f, transform_metadata, in_axes, out_axes), + in_axes=jax_in_axes, + out_axes=(jax_in_axes, jax_out_axes), + axis_name=axis_name, + axis_size=axis_size, + spmd_axis_name=spmd_axis_name, + ) + + @functools.wraps(f) + @graph.update_context('vmap') + def vmap_wrapper(*args, **kwargs): + args = resolve_kwargs(f, args, kwargs) + pure_args = extract.to_tree( + args, prefix=in_axes, split_fn=_vmap_split_fn, ctxtag='vmap' + ) + pure_args_out, pure_out = vmapped_fn(*pure_args) + _args_out, out = extract.from_tree((pure_args_out, pure_out), ctxtag='vmap') + return out + + return vmap_wrapper # type: ignore + + +# ------------------------------- +# pmap +# ------------------------------- + + +@dataclasses.dataclass(eq=False) +class PmapFn: + f: tp.Callable[..., tp.Any] + transform_metadata: tp.Mapping[str, tp.Any] + in_axes: tp.Any + out_axes: tp.Any + + def __post_init__(self): + functools.update_wrapper(self, self.f) + + def __call__(self, *pure_args: tuple[tp.Any, ...]): + if spmd.PARTITION_NAME in self.transform_metadata: + pure_args = _update_variable_sharding_metadata( + pure_args, self.transform_metadata, spmd.remove_axis + ) + args = extract.from_tree(pure_args, ctxtag='pmap') + + out = self.f(*args) + + args_out = extract.clear_non_graph_nodes(args) + pure_args_out, pure_out = extract.to_tree( + (args_out, out), + prefix=(self.in_axes, self.out_axes), + split_fn=_vmap_split_fn, + ctxtag='pmap', + ) + if spmd.PARTITION_NAME in self.transform_metadata: + pure_args_out, pure_out = _update_variable_sharding_metadata( + (pure_args_out, pure_out), self.transform_metadata, spmd.add_axis + ) + return pure_args_out, pure_out + + +@tp.overload +def pmap( + *, + axis_name: AxisName | None = None, + in_axes: tp.Any = 0, + out_axes: tp.Any = 0, + static_broadcasted_argnums: int | tp.Iterable[int] = (), + devices: tp.Sequence[jax.Device] | None = None, # noqa: F811 + backend: str | None = None, + axis_size: int | None = None, + donate_argnums: int | tp.Iterable[int] = (), + global_arg_shapes: tuple[tuple[int, ...], ...] 
| None = None, + # nnx specific + transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), +) -> tp.Callable[[F], F]: ... +@tp.overload +def pmap( + f: F, + *, + axis_name: AxisName | None = None, + in_axes: tp.Any = 0, + out_axes: tp.Any = 0, + static_broadcasted_argnums: int | tp.Iterable[int] = (), + devices: tp.Sequence[jax.Device] | None = None, # noqa: F811 + backend: str | None = None, + axis_size: int | None = None, + donate_argnums: int | tp.Iterable[int] = (), + global_arg_shapes: tuple[tuple[int, ...], ...] | None = None, + # nnx specific + transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), +) -> F: ... +def pmap( + f: F | Missing = MISSING, + *, + axis_name: AxisName | None = None, + in_axes: tp.Any = 0, + out_axes: tp.Any = 0, + static_broadcasted_argnums: int | tp.Iterable[int] = (), + devices: tp.Sequence[jax.Device] | None = None, # noqa: F811 + backend: str | None = None, + axis_size: int | None = None, + donate_argnums: int | tp.Iterable[int] = (), + global_arg_shapes: tuple[tuple[int, ...], ...] | None = None, + # nnx specific + transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), +) -> F | tp.Callable[[F], F]: + """Reference-aware version of `jax.vmap `__. + + Args: + f: Function to be mapped over additional axes. + in_axes: An integer, None, or sequence of values specifying which input + array axes to map over (see `jax.vmap `__). + In addition to integers and None, :class:`StateAxes` can be used to control how + graph nodes like Modules are vectorized by specifying the axes to be + applied to substates of the graph node given a `Filter `__. + out_axes: An integer, None, or pytree indicating where the mapped axis should appear + in the output (see `jax.vmap `__). + axis_name: Optional, a hashable Python object used to identify the mapped + axis so that parallel collectives can be applied. + axis_size: Optional, an integer indicating the size of the axis to be + mapped. If not provided, the mapped axis size is inferred from arguments. + + Returns: + Batched/vectorized version of ``f`` with arguments that correspond to + those of ``f``, but with extra array axes at positions indicated by + ``in_axes``, and a return value that corresponds to that of ``f``, but + with extra array axes at positions indicated by ``out_axes``. + + Example:: + + >>> from flax import nnx + >>> from jax import random, numpy as jnp + ... + >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) + >>> x = jnp.ones((5, 2)) + ... + >>> @nnx.vmap(in_axes=(None, 0), out_axes=0) + ... def forward(model, x): + ... return model(x) + ... + >>> y = forward(model, x) + >>> y.shape + (5, 3) + + >>> class LinearEnsemble(nnx.Module): + ... def __init__(self, num, rngs): + ... self.w = nnx.Param(jax.random.uniform(rngs(), (num, 2, 3))) + ... + >>> model = LinearEnsemble(5, rngs=nnx.Rngs(0)) + >>> x = jnp.ones((2,)) + ... + >>> @nnx.vmap(in_axes=(0, None), out_axes=0) + ... def forward(model, x): + ... return jnp.dot(x, model.w.value) + ... + >>> y = forward(model, x) + >>> y.shape + (5, 3) + + To control control how graph node substates are vectorized, ``StateAxes`` + can be passed to ``in_axes`` and ``out_axes`` specifying the axes to be + applied to each substate given a filter. The following example shows how to + share the parameters between the ensemble members which keeping different + batch statistics and dropout random state:: + + >>> class Foo(nnx.Module): + ... def __init__(self): + ... self.a = nnx.Param(jnp.arange(4)) + ... self.b = nnx.BatchStat(jnp.arange(4)) + ... 
+ >>> state_axes = nnx.StateAxes({nnx.Param: 0, nnx.BatchStat: None}) + >>> @nnx.vmap(in_axes=(state_axes,), out_axes=0) + ... def mul(foo): + ... return foo.a * foo.b + ... + >>> foo = Foo() + >>> y = mul(foo) + >>> y + Array([[0, 0, 0, 0], + [0, 1, 2, 3], + [0, 2, 4, 6], + [0, 3, 6, 9]], dtype=int32) + """ + if isinstance(f, Missing): + return functools.partial( + pmap, + axis_name=axis_name, + in_axes=in_axes, + out_axes=out_axes, + static_broadcasted_argnums=static_broadcasted_argnums, + devices=devices, + backend=backend, + axis_size=axis_size, + donate_argnums=donate_argnums, + global_arg_shapes=global_arg_shapes, + transform_metadata=transform_metadata, + ) + + jax_in_axes = jax.tree.map( + lambda x: extract.TreeNode.from_prefixes(x.axes, metadata=x) + if isinstance(x, StateAxes) + else x, + in_axes, + ) + jax_out_axes = jax.tree.map( + lambda x: extract.TreeNode.from_prefixes(x.axes, metadata=x) + if isinstance(x, StateAxes) + else x, + out_axes, + ) + pmapped_fn = jax.pmap( + PmapFn(f, transform_metadata, in_axes, out_axes), + axis_name=axis_name, + in_axes=jax_in_axes, + out_axes=(jax_in_axes, jax_out_axes), + static_broadcasted_argnums=static_broadcasted_argnums, + devices=devices, + backend=backend, + axis_size=axis_size, + donate_argnums=donate_argnums, + global_arg_shapes=global_arg_shapes, + ) + + @functools.wraps(f) + @graph.update_context('pmap') + def vmap_wrapper(*args): + pure_args = extract.to_tree( + args, prefix=in_axes, split_fn=_vmap_split_fn, ctxtag='pmap' + ) + pure_args_out, pure_out = pmapped_fn(*pure_args) + _args_out, out = extract.from_tree((pure_args_out, pure_out), ctxtag='pmap') + return out + + return vmap_wrapper # type: ignore + + +# ------------------------------- +# scan +# ------------------------------- + + +@dataclasses.dataclass(frozen=True) +class FlatDef(tp.Generic[A]): + type: type[A] + treedef: jax.tree_util.PyTreeDef + flat_axes: list[int | None] + + +jax.tree_util.register_static(FlatDef) + + +def _transpose_tree(tree: A, axes, /, *, move_front: bool) -> A: + flatdef, flat_transposes, _ = _transpose_and_split( + tree, axes, allow_none=False, move_front=move_front + ) + return flatdef.treedef.unflatten(flat_transposes) + + +def _transpose_and_split( + tree: A, axes, /, *, allow_none: bool = True, move_front: bool = True +) -> tuple[ + FlatDef[A], + list[jax.Array | None], + list[tp.Any], +]: + flat_axes: list[int | None] = broadcast_prefix( + axes, tree, is_leaf=lambda x: x is None + ) + flat_tree, treedef = jax.tree.flatten(tree) + + flat_broadcasts: list[tp.Any] = [] + flat_transposes: list[jax.Array | None] = [] + + for i, (axis, node) in enumerate(zip(flat_axes, flat_tree)): + if axis is None: + if not allow_none: + raise ValueError('None axis not allowed') + + flat_broadcasts.append(node) + flat_transposes.append(None) + else: + if not isinstance(node, jax.Array): + raise TypeError( + f'Expected a jax.Array, got {type(node).__name__} for axis {axis}' + ) + # normalize axis + if axis < 0: + if axis < -len(node.shape): + raise ValueError( + f'Axis {axis} out of bounds for array with shape {node.shape}' + ) + axis = len(node.shape) + axis + flat_axes[i] = axis + + if node.shape == (): + raise ValueError(f'Cannot map over a scalar array, got {node}') + elif axis >= len(node.shape): + raise ValueError( + f'Axis {axis} out of bounds for array with shape {node.shape}' + ) + + if move_front: + node = jnp.moveaxis(node, axis, 0) + else: + node = jnp.moveaxis(node, 0, axis) + flat_broadcasts.append(None) + flat_transposes.append(node) + + flatdef 
= FlatDef(type(tree), treedef, flat_axes) + + return flatdef, flat_transposes, flat_broadcasts + + +def _unflatten_splits( + flatdef: FlatDef[A], + flat_transposes: list[jax.Array | None], + flat_broadcasts: list[tp.Any] | None = None, + /, + *, + allow_none: bool = True, +) -> A: + flat_axes = flatdef.flat_axes + treedef = flatdef.treedef + if flat_broadcasts is None: + if allow_none: + raise ValueError('flat_broadcasts must be provided if allow_none is True') + flat_broadcasts = [None] * len(flat_axes) + + flat_tree = [] + for axis, transpose, broadcast in zip( + flat_axes, flat_transposes, flat_broadcasts + ): + if axis is None: + if not allow_none: + raise ValueError('None axis not allowed') + flat_tree.append(broadcast) + else: + if transpose is None: + raise ValueError('None transpose not allowed') + flat_tree.append(transpose) + + tree = treedef.unflatten(flat_tree) + return tree + + +def _extract_carry_arg( + args: tuple[tp.Any, ...], carry_argnum: int, / +) -> tuple[tp.Any, tuple[tp.Any, ...]]: + # extract carry arg + if len(args) < carry_argnum + 1: + raise TypeError( + f'Expected at least {carry_argnum + 1} positional arguments, ' + f'got {len(args)}' + ) + + args_ = list(args) + carry_arg = args_[carry_argnum] + args_[carry_argnum] = None + args = tuple(args_) + + return carry_arg, args + + +class Broadcasted(struct.PyTreeNode): + data: tp.Any + + +def _scan_split_in( + carry_deque: PytreeDeque[list[State]], + broadcast_deque: PytreeDeque[list[State]], + broadcast_arrays: PytreeDeque[Broadcasted], + /, + ctx: graph.SplitContext, + path, + prefix, + x, +): + if graph.is_graph_node(x): + vectorized_states: list[State] = [] + carry_states: list[State] = [] + broadcast_states: list[State] = [] + if isinstance(prefix, StateAxes): + graphdef, *states = ctx.split(x, *prefix.filters) + + for state, axis in zip(states, prefix.axes): + if axis is None: + broadcast_states.append(state) + elif isinstance(axis, int): + state = jax.tree.map(lambda x: jnp.moveaxis(x, axis, 0), state) + vectorized_states.append(state) + else: # axis is Carry + carry_states.append(state) + + carry_deque.append(carry_states) + broadcast_deque.append(broadcast_states) + return extract.TreeNode.from_split( + graphdef, *vectorized_states, metadata=prefix + ) + elif isinstance(prefix, int): + graphdef, state = ctx.split(x) + state = jax.tree.map(lambda x: jnp.moveaxis(x, prefix, 0), state) + vectorized_states.append(state) + elif prefix is None: + graphdef, state = ctx.split(x) + broadcast_states.append(state) + vectorized_states.append(State({})) + elif prefix is Carry: + graphdef, state = ctx.split(x) + carry_states.append(state) + vectorized_states.append(State({})) + else: + raise ValueError( + f'Invalid axes {prefix} at path {jax.tree_util.keystr(path)}' + ) + + carry_deque.append(carry_states) + broadcast_deque.append(broadcast_states) + return extract.TreeNode.from_split( + graphdef, *vectorized_states, metadata=prefix + ) + else: + if isinstance(prefix, StateAxes): + raise ValueError( + f'Cannot use StateAxes on non-graph nodes, ' + f'found {prefix} at path {jax.tree_util.keystr(path)}' + ) + elif prefix is Carry: + return x + elif prefix is None: + broadcast_arrays.append(Broadcasted(x)) + return Broadcasted(None) + elif isinstance(prefix, int): + if not isinstance(x, (jax.Array, np.ndarray)): + raise ValueError( + f'Expected an array, got {type(x).__name__} at path ' + f'{jax.tree_util.keystr(path)}' + ) + return jnp.moveaxis(x, prefix, 0) + else: + raise ValueError( + f'Invalid axes {prefix} at path 
{jax.tree_util.keystr(path)}' + ) + + +def _scan_split_out( + carry_deque: PytreeDeque[list[State]], + broadcast_deque: PytreeDeque[list[State]], + /, + ctx: graph.SplitContext, + path: extract.KeyPath, + prefix, + x, +): + assert isinstance(path[0], jax.tree_util.SequenceKey) + is_input_arg = path[0].idx == 0 + + if graph.is_graph_node(x): + vectorized_states: list[State] = [] + carry_states: list[State] = [] + broadcast_states: list[State] = [] + if isinstance(prefix, StateAxes): + graphdef, *states = ctx.split(x, *prefix.filters) + + for state, filter, axis in zip(states, prefix.filters, prefix.axes): + if axis is None: + if is_input_arg: + broadcast_states.append(state) + elif state: + raise ValueError( + f'Cannot broadcast output state. ' + f'Got filter {filter} and axis None at path {jax.tree_util.keystr(path)}' + ) + elif isinstance(axis, int): + vectorized_states.append(state) + else: # axis is Carry + if is_input_arg: + carry_states.append(state) + elif state: + raise ValueError( + f'Cannot carry output state. ' + f'Got filter {filter} and axis {axis} at path {jax.tree_util.keystr(path)}' + ) + if is_input_arg: + carry_deque.append(carry_states) + broadcast_deque.append(broadcast_states) + return extract.TreeNode.from_split( + graphdef, *vectorized_states, metadata=prefix + ) + elif isinstance(prefix, int): + graphdef, state = ctx.split(x) + vectorized_states.append(state) + elif prefix is None: + graphdef, state = ctx.split(x) + if is_input_arg: + broadcast_states.append(state) + vectorized_states.append(State({})) + elif state: + raise ValueError( + f'Cannot broadcast output state. ' + f'Got out_axes=None at path {jax.tree_util.keystr(path)}' + ) + elif prefix is Carry: + graphdef, state = ctx.split(x) + if is_input_arg: + carry_states.append(state) + vectorized_states.append(State({})) + elif state: + raise ValueError( + f'Cannot carry output state. ' + f'Got out_axes=carry at path {jax.tree_util.keystr(path)}' + ) + else: + raise ValueError( + f'Invalid axes {prefix} at path {jax.tree_util.keystr(path)}' + ) + + if is_input_arg: + carry_deque.append(carry_states) + broadcast_deque.append(broadcast_states) + return extract.TreeNode.from_split( + graphdef, *vectorized_states, metadata=prefix + ) + else: + if isinstance(prefix, StateAxes): + raise ValueError( + f'Cannot use StateAxes on non-graph nodes, ' + f'found {prefix} at path {jax.tree_util.keystr(path)}' + ) + elif prefix is Carry: + return x + elif prefix is None: + if not is_input_arg: + raise ValueError( + f'Cannot broadcast outputs. 
' + f'Got out_axes=None at path {jax.tree_util.keystr(path)}' + ) + return Broadcasted(None) + elif isinstance(prefix, int): + return x + else: + raise ValueError( + f'Invalid axes {prefix} at path {jax.tree_util.keystr(path)}' + ) + + +def _scan_merge_in( + carry_deque: PytreeDeque[list[State]], + broadcast_deque: PytreeDeque[list[State]], + broadcast_arrays: PytreeDeque[Broadcasted], + /, + ctx: graph.MergeContext, + path, + prefix, + x, +): + if isinstance(x, extract.TreeNode): + carry_states = carry_deque.popleft() + broadcast_states = broadcast_deque.popleft() + return ctx.merge(x.graphdef, *x.states, *carry_states, *broadcast_states) + elif isinstance(x, Broadcasted): + assert x.data is None + return broadcast_arrays.popleft().data + else: + return x + + +def _scan_merge_out( + carry_deque: PytreeDeque[list[State]], + broadcast_deque: PytreeDeque[list[State]], + /, + ctx: graph.MergeContext, + path, + prefix, + x, +): + assert isinstance(path[0], jax.tree_util.SequenceKey) + is_input_arg = path[0].idx == 0 + + if isinstance(x, extract.TreeNode): + states: list[State] = [] + if is_input_arg: + carry_states = deque(carry_deque.popleft()) + broadcast_states = deque(broadcast_deque.popleft()) + else: + carry_states = deque[State]() + broadcast_states = deque[State]() + if isinstance(prefix, StateAxes): + vectorized_states = deque(x.states) + assert len(prefix.axes) == len(vectorized_states) + len( + carry_states + ) + len(broadcast_states) + for axis in prefix.axes: + if isinstance(axis, int): + state = vectorized_states.popleft() + state = jax.tree.map(lambda x: jnp.moveaxis(x, 0, axis), state) + states.append(state) + elif axis is None: + states.append(broadcast_states.popleft()) + else: # axis is Carry + states.append(carry_states.popleft()) + assert not vectorized_states and not carry_states and not broadcast_states + elif isinstance(prefix, int): + state = jax.tree.map(lambda x: jnp.moveaxis(x, 0, prefix), x.state) + states.extend((state, *carry_states, *broadcast_states)) + elif prefix is None: + assert is_input_arg + states.extend(broadcast_states) + elif prefix is Carry: + assert is_input_arg + states.extend(carry_states) + else: + raise ValueError( + f'Invalid axes {prefix} at path {jax.tree_util.keystr(path)}' + ) + + return ctx.merge(x.graphdef, *states) + else: + if isinstance(prefix, StateAxes): + raise ValueError( + f'Cannot use StateAxes on non-graph nodes, ' + f'found {prefix} at path {jax.tree_util.keystr(path)}' + ) + elif prefix is Carry: + return x + elif prefix is None: + return x + elif isinstance(prefix, int): + if not isinstance(x, (jax.Array, np.ndarray)): + raise ValueError( + f'Expected an array, got {type(x).__name__} at path ' + f'{jax.tree_util.keystr(path)}' + ) + return jnp.moveaxis(x, 0, prefix) + else: + raise ValueError( + f'Invalid axes {prefix} at path {jax.tree_util.keystr(path)}' + ) + + +@dataclasses.dataclass(eq=False) +class ScanFn: + f: tp.Callable[..., tp.Any] + carry_argnum: int + in_axes: tp.Any + out_axes: tp.Any + transform_metadata: tp.Mapping[str, tp.Any] + + def __post_init__(self): + functools.update_wrapper(self, self.f) + + def __call__( + self, + carry: tuple[ + tp.Any, # carry_arg + PytreeDeque[list[State]], # carry_deque + PytreeDeque[list[State]], # broadcast_deque + PytreeDeque[Broadcasted], # broadcast_arrays + ], + pure_args: list[tp.Any], + ): + pure_carry_arg, carry_deque, broadcast_deque, broadcast_arrays = carry + pure_args[self.carry_argnum] = pure_carry_arg + broadcast_deque_out = PytreeDeque(broadcast_deque) + 
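    # The lax.scan carry threads four things per step: the user's (pure) carry
    # argument plus deques of carried substates, broadcast substates, and
    # broadcast arrays lifted out of the inputs by _scan_split_in. The broadcast
    # deques are copied here because extract.from_tree below consumes the
    # originals (via _scan_merge_in's popleft) while the untouched copies are
    # returned unchanged as part of the next step's carry.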
broadcast_arrays_out = PytreeDeque(broadcast_arrays) + + if spmd.PARTITION_NAME in self.transform_metadata: + pure_args = _update_variable_sharding_metadata( + pure_args, self.transform_metadata, spmd.remove_axis + ) + + args = extract.from_tree( + pure_args, + prefix=self.in_axes, + merge_fn=functools.partial( + _scan_merge_in, carry_deque, broadcast_deque, broadcast_arrays + ), + is_leaf=lambda x: isinstance(x, (extract.TreeNode, Broadcasted)), + map_non_graph_nodes=True, + ctxtag='scan', + ) + assert not carry_deque and not broadcast_deque and not broadcast_arrays + + out = self.f(*args) + + carry_deque_out = PytreeDeque[list[State]]() + _broadcast_deque_out_tmp = PytreeDeque[list[State]]() # discarded + + args_out = extract.clear_non_graph_nodes(args) + pure_args_out, pure_out = extract.to_tree( + (args_out, out), + prefix=(self.in_axes, self.out_axes), + split_fn=functools.partial( + _scan_split_out, carry_deque_out, _broadcast_deque_out_tmp + ), + map_non_graph_nodes=True, + ctxtag='scan', + ) + if spmd.PARTITION_NAME in self.transform_metadata: + pure_args_out, pure_out = _update_variable_sharding_metadata( + (pure_args_out, pure_out), self.transform_metadata, spmd.add_axis + ) + + pure_carry_arg_out, pure_scan_out = pure_out + carry_out = ( + pure_carry_arg_out, + carry_deque_out, + broadcast_deque_out, + broadcast_arrays_out, + ) + return carry_out, (pure_args_out, pure_scan_out) + + +@tp.overload +def scan( + *, + length: int | None = None, + reverse: bool = False, + unroll: int | bool = 1, + _split_transpose: bool = False, + # extended api + in_axes: tp.Sequence[tp.Any] = (Carry, 0), + out_axes: tp.Any = 0, + # nnx specific + transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), +) -> tp.Callable[[F], F]: ... +@tp.overload +def scan( + f: F, + *, + length: int | None = None, + reverse: bool = False, + unroll: int | bool = 1, + _split_transpose: bool = False, + # extended api + in_axes: tp.Sequence[tp.Any] = (Carry, 0), + out_axes: tp.Any = 0, + # nnx specific + transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), +) -> F: ... +def scan( + f: F | Missing = MISSING, + *, + length: int | None = None, + reverse: bool = False, + unroll: int | bool = 1, + _split_transpose: bool = False, + # extended api + in_axes: tp.Sequence[tp.Any] = (Carry, 0), + out_axes: tp.Any = 0, + # nnx specific + transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), +) -> F | tp.Callable[[F], F]: + if isinstance(f, Missing): + return functools.partial( + scan, + length=length, + reverse=reverse, + unroll=unroll, + _split_transpose=_split_transpose, + in_axes=in_axes, + out_axes=out_axes, + transform_metadata=transform_metadata, + ) + + carry_argnum: int = -1 + for key, x in jax.tree_util.tree_leaves_with_path(in_axes): + if x is not Carry: + continue + assert isinstance(key[0], jax.tree_util.SequenceKey) + i = key[0].idx + if len(key) >= 2: + raise ValueError( + f'Carry must be used direcly on an input, it cannot be nested. 
' + f'Found {in_axes=}' + ) + if carry_argnum >= 0: + raise ValueError(f'Found multiple Carry axes in in_axes: {in_axes}') + carry_argnum = i + if carry_argnum < 0: + raise ValueError(f'No Carry axis specified in in_axes: {in_axes}') + + in_axes = list(in_axes) + + scan_fn = ScanFn( + f, + carry_argnum, + in_axes, + out_axes, + transform_metadata, + ) + + @functools.wraps(f) + @graph.update_context('scan') + def scan_wrapper(*args, **kwargs): + args = list(resolve_kwargs(f, args, kwargs)) + carry_deque = PytreeDeque() + broadcast_deque = PytreeDeque() + broadcast_arrays = PytreeDeque() + pure_args = extract.to_tree( + args, + prefix=in_axes, + split_fn=functools.partial( + _scan_split_in, carry_deque, broadcast_deque, broadcast_arrays + ), + map_non_graph_nodes=True, + ctxtag='scan', + ) + pure_carry_arg = pure_args[carry_argnum] + pure_args[carry_argnum] = None + + carry = (pure_carry_arg, carry_deque, broadcast_deque, broadcast_arrays) + + carry_out, (pure_args_out, pure_scan_out) = jax.lax.scan( + scan_fn, + carry, + pure_args, + length=length, + reverse=reverse, + unroll=unroll, + _split_transpose=_split_transpose, + ) + ( + pure_carry_arg_out, + carry_deque_out, + broadcast_deque_out, + broadcast_arrays_out, + ) = carry_out + pure_args_out[carry_argnum] = pure_carry_arg_out + args_out, scan_out = extract.from_tree( + (pure_args_out, pure_scan_out), + prefix=(in_axes, out_axes), + merge_fn=functools.partial( + _scan_merge_out, carry_deque_out, broadcast_deque_out + ), + is_leaf=lambda x: isinstance(x, (extract.TreeNode, Broadcasted)), + map_non_graph_nodes=True, + ctxtag='scan', + ) + carry_out = args_out[carry_argnum] + + return carry_out, scan_out + + return scan_wrapper # type: ignore \ No newline at end of file diff --git a/flax/nnx/nnx/transforms/looping.py b/flax/nnx/nnx/transforms/looping.py deleted file mode 100644 index f8d9b27c18..0000000000 --- a/flax/nnx/nnx/transforms/looping.py +++ /dev/null @@ -1,594 +0,0 @@ -# Copyright 2024 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Copyright 2023 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
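# A minimal usage sketch of the nnx.scan API added above (illustrative only;
# RNNCell, step, and the shapes are made up and not taken from the patch).
# The carry position is marked with nnx.Carry in `in_axes`, the module is
# broadcast with None, and the per-step input is scanned over axis 0.
from flax import nnx
import jax.numpy as jnp

class RNNCell(nnx.Module):
  def __init__(self, din, dhidden, rngs: nnx.Rngs):
    self.linear = nnx.Linear(din + dhidden, dhidden, rngs=rngs)

cell = RNNCell(din=3, dhidden=4, rngs=nnx.Rngs(0))

@nnx.scan(in_axes=(nnx.Carry, None, 0), out_axes=0)
def step(h, cell, x):
  h = jnp.tanh(cell.linear(jnp.concatenate([x, h], axis=-1)))
  return h, h  # (next carry, per-step output stacked along out_axes)

h0 = jnp.zeros((4,))
xs = jnp.ones((8, 3))
h_final, hs = step(h0, cell, xs)  # h_final: (4,), hs: (8, 4)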
-# pytype: skip-file -from __future__ import annotations - -import dataclasses -import functools -import typing as tp - -from flax import struct -from flax.core.frozen_dict import FrozenDict -from flax.nnx.nnx import extract, filterlib, graph, rnglib, spmd -from flax.nnx.nnx.module import GraphDef, Module -from flax.nnx.nnx.proxy_caller import DelayedAccessor -from flax.nnx.nnx.state import State -from flax.nnx.nnx.transforms.parallelization import Vmap -from flax.nnx.nnx.transforms.transforms import LiftedModule -from flax.typing import Leaf -import jax -from jax._src.tree_util import broadcast_prefix -import jax.core -import jax.numpy as jnp -import jax.stages - -A = tp.TypeVar('A') -C = tp.TypeVar('C') -B = tp.TypeVar('B') -F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any]) -G = tp.TypeVar('G', bound=tp.Callable[..., tp.Any]) -M = tp.TypeVar('M', bound=Module) -MA = tp.TypeVar('MA', bound=Module) -N = tp.TypeVar('N', bound=Module) -StrInt = tp.TypeVar('StrInt', str, int) -AxisName = tp.Hashable -Leaves = tp.List[Leaf] -Index = int - -# ------------------------------- -# scan -# ------------------------------- - - -@dataclasses.dataclass(frozen=True) -class FlatDef(tp.Generic[A]): - type: type[A] - treedef: jax.tree_util.PyTreeDef - flat_axes: list[int | None] - - -jax.tree_util.register_static(FlatDef) - - -def _transpose_tree(tree: A, axes, /, *, move_front: bool) -> A: - flatdef, flat_transposes, _ = _transpose_and_split( - tree, axes, allow_none=False, move_front=move_front - ) - return flatdef.treedef.unflatten(flat_transposes) - - -def _transpose_and_split( - tree: A, axes, /, *, allow_none: bool = True, move_front: bool = True -) -> tuple[ - FlatDef[A], - list[jax.Array | None], - list[tp.Any], -]: - flat_axes: list[int | None] = broadcast_prefix( - axes, tree, is_leaf=lambda x: x is None - ) - flat_tree, treedef = jax.tree.flatten(tree) - - flat_broadcasts: list[tp.Any] = [] - flat_transposes: list[jax.Array | None] = [] - - for i, (axis, node) in enumerate(zip(flat_axes, flat_tree)): - if axis is None: - if not allow_none: - raise ValueError('None axis not allowed') - - flat_broadcasts.append(node) - flat_transposes.append(None) - else: - if not isinstance(node, jax.Array): - raise TypeError( - f'Expected a jax.Array, got {type(node).__name__} for axis {axis}' - ) - # normalize axis - if axis < 0: - if axis < -len(node.shape): - raise ValueError( - f'Axis {axis} out of bounds for array with shape {node.shape}' - ) - axis = len(node.shape) + axis - flat_axes[i] = axis - - if node.shape == (): - raise ValueError(f'Cannot map over a scalar array, got {node}') - elif axis >= len(node.shape): - raise ValueError( - f'Axis {axis} out of bounds for array with shape {node.shape}' - ) - - if move_front: - node = jnp.moveaxis(node, axis, 0) - else: - node = jnp.moveaxis(node, 0, axis) - flat_broadcasts.append(None) - flat_transposes.append(node) - - flatdef = FlatDef(type(tree), treedef, flat_axes) - - return flatdef, flat_transposes, flat_broadcasts - - -def _unflatten_splits( - flatdef: FlatDef[A], - flat_transposes: list[jax.Array | None], - flat_broadcasts: list[tp.Any] | None = None, - /, - *, - allow_none: bool = True, -) -> A: - flat_axes = flatdef.flat_axes - treedef = flatdef.treedef - if flat_broadcasts is None: - if allow_none: - raise ValueError('flat_broadcasts must be provided if allow_none is True') - flat_broadcasts = [None] * len(flat_axes) - - flat_tree = [] - for axis, transpose, broadcast in zip( - flat_axes, flat_transposes, flat_broadcasts - ): - if axis is None: - 
if not allow_none: - raise ValueError('None axis not allowed') - flat_tree.append(broadcast) - else: - if transpose is None: - raise ValueError('None transpose not allowed') - flat_tree.append(transpose) - - tree = treedef.unflatten(flat_tree) - return tree - - -def _extract_carry_arg( - args: tuple[tp.Any, ...], carry_argnum: int, / -) -> tuple[tp.Any, tuple[tp.Any, ...]]: - # extract carry arg - if len(args) < carry_argnum + 1: - raise TypeError( - f'Expected at least {carry_argnum + 1} positional arguments, ' - f'got {len(args)}' - ) - - args_ = list(args) - carry_arg = args_[carry_argnum] - args_[carry_argnum] = None - args = tuple(args_) - - return carry_arg, args - - -def _insert_carry_arg( - args: tuple[tp.Any, ...], carry_argnum: int, carry_arg: tp.Any, / -) -> tuple[tp.Any, ...]: - args_ = list(args) - args_[carry_argnum] = carry_arg - args = tuple(args_) - - return args - - -@struct.dataclass -class ScanBroadcasts(tp.Generic[C, B]): - flatdef: FlatDef[ - tuple[tuple[tp.Any, ...], dict[str, tp.Any], list[State]] - ] = struct.field(pytree_node=False) - flat_carry: list[tp.Any] = struct.field(pytree_node=True) - graphdef: GraphDef[tuple[tp.Any, ...]] = struct.field(pytree_node=False) - filters: tuple[filterlib.Filter, ...] = struct.field(pytree_node=False) - f: tp.Callable[..., tuple[C, B] | C] = struct.field(pytree_node=False) - # options - carry_argnum: int = struct.field(pytree_node=False) - state_axes: tp.Mapping[filterlib.Filter, int] = struct.field( - pytree_node=False - ) - split_rngs: filterlib.Filter = struct.field(pytree_node=False) - transform_metadata: tp.Mapping[str, tp.Any] = struct.field(pytree_node=False) - scan_output: bool = struct.field(pytree_node=False) - - -def scan_fn( - carry: tuple[ - State, # split_rng_state - State, # broadcast_rng_state - State, # carry_state - tp.Any, # carry_arg - ScanBroadcasts[C, B], # broadcasts - ], - scan: tuple[ - list[jax.Array | None], # flat_scan - ], -): - split_rng_state, broadcast_rng_state, carry_state, carry_arg, broadcasts = ( - carry - ) - (flat_scan,) = scan - flatdef = broadcasts.flatdef - flat_carry = broadcasts.flat_carry - graphdef, filters = broadcasts.graphdef, broadcasts.filters - f = broadcasts.f - ctx = graph.current_update_context('scan') - - # merge args and kwargs - args, kwargs, scan_states = _unflatten_splits(flatdef, flat_scan, flat_carry) - # remove metadata axis name from Variable.sharding - if spmd.PARTITION_NAME in broadcasts.transform_metadata: - scan_states = [ - spmd.remove_axis(state, index, broadcasts.transform_metadata) - for state, index in zip(scan_states, broadcasts.state_axes.values()) - ] - - # insert carry arg - args = _insert_carry_arg(args, broadcasts.carry_argnum, carry_arg) - - # merge module state - input_graph_nodes = ctx.merge( - graphdef, *scan_states, carry_state, split_rng_state, broadcast_rng_state - ) - (args, kwargs) = extract.insert_graph_nodes((args, kwargs), input_graph_nodes) - - out = f(*args, **kwargs) - - if broadcasts.scan_output: - if not isinstance(out, tuple) or len(out) != 2: - raise ValueError( - 'Expected a tuple of length 2 as the output of the scan function, ' - f'got {out}' - ) - out = tp.cast(tuple[C, B], out) # type: ignore[invalid-annotation] - carry_arg_out, scan_args_out = out - else: - out = tp.cast(C, out) # type: ignore[invalid-annotation] - carry_arg_out = out - scan_args_out = None - - ((carry_arg_out, scan_args_out), output_graph_nodes) = ( - extract.extract_graph_nodes((carry_arg_out, scan_args_out)) - ) - - # split module state - ( - graphdef_out, 
- rng_state_out, - *scan_states_out, - carry_state_out, - ) = ctx.split( # type: ignore[misc] - (input_graph_nodes, output_graph_nodes), - rnglib.RngState, - *filters, - ) - - split_rng_state_out, broadcast_rng_state_out = rng_state_out.split( - broadcasts.split_rngs, ... - ) - - def _extract_carry_state(state: State, /): - if 1 in state: - raise ValueError( - f'Cannot add new carry state during scan, got {state[1]}' - ) - if 0 in state: - _state = state[0] - assert isinstance(_state, State) - state = _state - - return state - - carry_state_out = _extract_carry_state(carry_state_out) - split_rng_state_out = _extract_carry_state(split_rng_state_out) - broadcast_rng_state_out = _extract_carry_state(broadcast_rng_state_out) - - # override broadcast_rng_state_out to keep the same state - # for the next iteration - broadcast_rng_state_out = broadcast_rng_state - - # add metadata axis name to Variable.sharding - if spmd.PARTITION_NAME in broadcasts.transform_metadata: - scan_states_out = [ - spmd.add_axis(state, index, broadcasts.transform_metadata) - for state, index in zip(scan_states_out, broadcasts.state_axes.values()) - ] - - carry_out = ( - split_rng_state_out, - broadcast_rng_state_out, - carry_state_out, - carry_arg_out, - broadcasts, - ) - scan_out = (graphdef_out, scan_args_out, scan_states_out) - - return carry_out, scan_out - - -def scan( - f: F, - *, - length: int | None = None, - reverse: bool = False, - unroll: int | bool = 1, - _split_transpose: bool = False, - # extended api - in_axes: int | None | tp.Sequence[tp.Any] = 0, - in_axes_kwargs: tp.Any = 0, - out_axes: tp.Any = 0, - carry_argnum: int = 0, - # nnx specific - state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}), - split_rngs: filterlib.Filter = ..., - transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), - scan_output: bool = True, -) -> F: - @functools.wraps(f) - @graph.update_context('scan') - def scan_apply_wrapper(*args, **kwargs): - # extract nodes - (args, kwargs), input_graph_nodes = extract.extract_graph_nodes( - (args, kwargs) - ) - input_rng_streams = rnglib.backup_keys(input_graph_nodes) - - # extract carry arg - carry_arg, args = _extract_carry_arg(args, carry_argnum) - - ctx = graph.current_update_context('scan') - # split module state - filters = (*state_axes.keys(), ...) - graphdef, rng_state, *scan_states, carry_state = ctx.split( # type: ignore[misc] - input_graph_nodes, rnglib.RngState, *filters - ) - - # transpose axes arg - flatdef, flat_scan, flat_carry = _transpose_and_split( - (args, kwargs, scan_states), - (in_axes, in_axes_kwargs, list(state_axes.values())), - ) - - # infer length - lengths: set[int] = { - x.shape[0] # type: ignore - for x, axis in zip(flat_scan, flatdef.flat_axes) - if axis is not None - } - - if len(lengths) > 1: - raise ValueError( - 'Inconsistent lengths between state_axes states and ' - f'arguments: {lengths}' - ) - elif len(lengths) == 0: - if length is None: - raise ValueError( - 'Cannot infer length from state_axes states or axes_arg, ' - 'please specify `length`' - ) - infered_length = length - else: - infered_length = lengths.pop() - if length is not None and length != infered_length: - raise ValueError( - f'Specified length {length} is not the same as the inferred ' - f'length {infered_length}' - ) - - # split rng state - split_rng_state, broadcast_rng_state = rng_state.split(split_rngs, ...) 
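# For comparison with the new iteration.py API above: the scan being removed
# here located the carry by position (`carry_argnum`) and vectorized module
# state through `state_axes`/`split_rngs`, whereas the replacement marks the
# carry position directly in `in_axes` with nnx.Carry and expresses per-substate
# axes with nnx.StateAxes. A rough sketch of the new style (not part of the patch):
from flax import nnx
import jax.numpy as jnp

@nnx.scan(in_axes=(nnx.Carry, 0), out_axes=0)
def cumsum(total, x):
  total = total + x
  return total, total  # (carry, per-step output)

total, partials = cumsum(jnp.float32(0.0), jnp.arange(5, dtype=jnp.float32))
# total == 10.0, partials == [0., 1., 3., 6., 10.]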
- - broadcasts = ScanBroadcasts( - flatdef, - flat_carry, - graphdef, - filters, - f, - # options - carry_argnum, - state_axes, - split_rngs, - transform_metadata, - scan_output, - ) - carry = ( - split_rng_state, - broadcast_rng_state, - carry_state, - carry_arg, - broadcasts, - ) - scan = (flat_scan,) - - carry_out, scan_out = jax.lax.scan( - scan_fn, - carry, - scan, - length=infered_length, - reverse=reverse, - unroll=unroll, - _split_transpose=_split_transpose, - ) - ( - split_rng_state_out, - broadcast_rng_state_out, - carry_state_out, - carry_arg_out, - broadcasts, - ) = carry_out - graphdef_out, scan_args_out, scan_states_out = scan_out - - scan_args_out, scan_states_out = _transpose_tree( - (scan_args_out, scan_states_out), - (out_axes, list(state_axes.values())), - move_front=False, - ) - - if carry_state_out: - carry_state_out = State({0: carry_state_out._mapping}) - if split_rng_state_out: - split_rng_state_out = State({0: split_rng_state_out._mapping}) - if broadcast_rng_state_out: - broadcast_rng_state_out = State({0: broadcast_rng_state_out._mapping}) - - _, output_graph_nodes = ctx.merge( - graphdef_out, - *scan_states_out, - carry_state_out, - split_rng_state_out, - broadcast_rng_state_out, - ) - - carry_arg_out, scan_args_out = extract.insert_graph_nodes( - (carry_arg_out, scan_args_out), output_graph_nodes - ) - - rnglib.restore_keys(input_rng_streams) - - if scan_output: - scan_args_out = tp.cast(B, scan_args_out) - return carry_arg_out, scan_args_out - else: - return carry_arg_out - - return scan_apply_wrapper # type: ignore - - -class Scan(tp.Generic[M], LiftedModule[M]): - @staticmethod - def constructor( - module_constructor: tp.Callable[..., MA], - *, - length: int | None = None, - reverse: bool = False, - unroll: int | bool = 1, - _split_transpose: bool = False, - # extended api - in_axes: int | None | tp.Sequence[tp.Any] = 0, - in_axes_kwargs: tp.Any = 0, - out_axes: tp.Any = 0, - carry_argnum: int = 1, - # nnx specific - state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}), - split_rngs: filterlib.Filter = ..., - transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), - scan_output: bool = True, - ) -> tp.Callable[..., Scan[MA]]: - def _create_scan(*args, **kwargs): - return Scan( - module_constructor=module_constructor, - module_init_args=args, - module_init_kwargs=kwargs, - # base api - length=length, - reverse=reverse, - unroll=unroll, - _split_transpose=_split_transpose, - # extended api - in_axes=in_axes, - in_axes_kwargs=in_axes_kwargs, - out_axes=out_axes, - carry_argnum=carry_argnum, - # nnx specific - state_axes=state_axes, - split_rngs=split_rngs, - transform_metadata=transform_metadata, - scan_output=scan_output, - ) - - return _create_scan - - def __init__( - self, - module_constructor: tp.Callable[..., M], - *, - length: int | None = None, - reverse: bool = False, - unroll: int | bool = 1, - _split_transpose: bool = False, - # extended api - in_axes: int | None | tp.Sequence[tp.Any] = 0, - in_axes_kwargs: tp.Any = 0, - out_axes: tp.Any = 0, - carry_argnum: int = 1, - # nnx specific - state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}), - split_rngs: filterlib.Filter = ..., - transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), - scan_output: bool = True, - # submodule args - module_init_args: tuple[tp.Any, ...], - module_init_kwargs: dict[str, tp.Any], - ): - self.module_constructor = module_constructor - # use Vmap to handle initialisation - vmapped_module = Vmap.constructor( - module_constructor, - 
in_axes=in_axes, - out_axes=None, - axis_name=None, - axis_size=length, - spmd_axis_name=None, - state_axes=state_axes, - split_rngs=split_rngs, - in_axes_kwargs=in_axes_kwargs, - transform_metadata=transform_metadata, - )(*module_init_args, **module_init_kwargs) - self.scan_module = vmapped_module.vmap_module - - @functools.partial( - scan, - length=length, - reverse=reverse, - unroll=unroll, - _split_transpose=_split_transpose, - in_axes=in_axes, - in_axes_kwargs=in_axes_kwargs, - out_axes=out_axes, - carry_argnum=carry_argnum, - state_axes=state_axes, - split_rngs=split_rngs, - transform_metadata=transform_metadata, - scan_output=scan_output, - ) - def scan_call(module, *args, _nnx_scan_accessor: DelayedAccessor, **kwargs): - method = _nnx_scan_accessor(module) - return method(*args, **kwargs) - - self.scan_call = scan_call - - @property - def _submodule(self) -> M: - return self.scan_module - - def _call( - self, accessor: DelayedAccessor, *args, **kwargs - ) -> tuple[tp.Any, tp.Any]: - return self.scan_call( - self._submodule, *args, _nnx_scan_accessor=accessor, **kwargs - ) diff --git a/flax/nnx/nnx/transforms/parallelization.py b/flax/nnx/nnx/transforms/parallelization.py deleted file mode 100644 index 9e86a77d32..0000000000 --- a/flax/nnx/nnx/transforms/parallelization.py +++ /dev/null @@ -1,830 +0,0 @@ -# Copyright 2024 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Copyright 2023 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
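# The Vmap/Pmap lifted-module classes and the state_axes/split_rngs keyword API
# removed below are superseded by plain decorators in transforms/iteration.py:
# per-substate axes now go through nnx.StateAxes, and RNG forking is handled by
# nnx.split_rngs (newly exported by this change). The decorator usage below is
# a sketch of the intended pattern, not code taken from the patch.
from flax import nnx

@nnx.split_rngs(splits=5)             # fork one key per ensemble member
@nnx.vmap(in_axes=(0,), out_axes=0)   # vectorize over the forked Rngs
def create_members(rngs: nnx.Rngs):
  return nnx.Linear(2, 3, rngs=rngs)

members = create_members(nnx.Rngs(0))  # weights stacked along axis 0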
-# pytype: skip-file -from __future__ import annotations - -import functools -import typing as tp - -import jax -import jax.core -import jax.stages -from jax._src.tree_util import broadcast_prefix -import jax.numpy as jnp - -from flax import struct -from flax.core.frozen_dict import FrozenDict -from flax.nnx.nnx import ( - extract, - filterlib, - graph, - rnglib, - spmd, -) -from flax.nnx.nnx.module import GraphDef, Module -from flax.nnx.nnx.proxy_caller import ( - DelayedAccessor, -) -from flax.nnx.nnx.state import State -from flax.nnx.nnx.transforms.transforms import LiftedModule -from flax.typing import Leaf - -A = tp.TypeVar('A') -C = tp.TypeVar('C') -B = tp.TypeVar('B') -F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any]) -G = tp.TypeVar('G', bound=tp.Callable[..., tp.Any]) -M = tp.TypeVar('M', bound=Module) -MA = tp.TypeVar('MA', bound=Module) -N = tp.TypeVar('N', bound=Module) -StrInt = tp.TypeVar('StrInt', str, int) -AxisName = tp.Hashable -Leaves = tp.List[Leaf] -Index = int -AxesValue = tp.Union[int, None] -SplitPattern = tp.Union[AxesValue, tuple[AxesValue, ...]] - - -# ------------------------------- -# vmap -# ------------------------------- -class _VmapForkStates(tp.NamedTuple): - split_keys: State - split_counts: State - broadcast_keys: State - broadcast_counts: State - - -def _get_axis_sizes(pytree, axes): - axes = broadcast_prefix(axes, pytree, is_leaf=lambda x: x is None) - leaves = jax.tree_util.tree_leaves(pytree) - axis_sizes = { - leaf.shape[axis] for axis, leaf in zip(axes, leaves) if axis is not None - } - return axis_sizes - - -def _fork_vmap_keys( - state: State, - split_filter: filterlib.Filter, - num_splits: int, -) -> _VmapForkStates: - split_keys, split_counts, broadcast_keys, broadcast_counts = state.split( - filterlib.All(split_filter, rnglib.RngKey), - filterlib.All(split_filter, rnglib.RngCount), - rnglib.RngKey, - rnglib.RngCount, - ) - - def split_key(key: tp.Any, count: tp.Any) -> jax.Array: - if not isinstance(key, jax.Array): - raise TypeError(f'key must be a jax.Array, got {type(key)}') - if not isinstance(count, jax.Array): - raise TypeError(f'count must be a jax.Array, got {type(count)}') - - key = jax.random.fold_in(key, count) - return jax.random.split(key, num_splits) - - split_keys_leaves, split_keys_treedef = jax.tree.flatten(split_keys) - split_counts_leaves, split_counts_treedef = jax.tree.flatten(split_counts) - - if len(split_keys_leaves) != len(split_counts_leaves): - raise ValueError( - 'split_keys and split_counts must have the same number of leaves', - f'got {len(split_keys_leaves)} and {len(split_counts_leaves)}', - ) - - split_keys_leaves = [ - split_key(key, count) - for key, count in zip(split_keys_leaves, split_counts_leaves) - ] - split_counts_leaves = [ - jnp.full((num_splits,), 0, dtype=jnp.uint32) for _ in split_counts_leaves - ] - split_keys = jax.tree.unflatten(split_keys_treedef, split_keys_leaves) - split_counts = jax.tree.unflatten(split_counts_treedef, split_counts_leaves) - - return _VmapForkStates( - split_keys, split_counts, broadcast_keys, broadcast_counts - ) - - -def _backup_vmap_keys(node: tp.Any, /): - backups: list[ - tuple[graph.PathParts, rnglib.RngStream, jax.Array, jax.Array] - ] = [] - for path, stream in graph.iter_graph(node): - if isinstance(stream, rnglib.RngStream): - backups.append((path, stream, stream.key.value, stream.count.value)) - return backups - - -def _restore_vmap_keys( - backups: list[tuple[graph.PathParts, rnglib.RngStream, jax.Array, jax.Array]], - split_rngs: filterlib.Filter, - /, -): - 
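# The removed _fork_vmap_keys above forked each RNG stream by folding the
# current count into the key and then splitting once per vectorized member;
# rnglib.split_rngs now plays this role. Equivalent key arithmetic in plain
# JAX (illustrative; `num_members` is made up for this sketch):
import jax

num_members = 5
key, count = jax.random.PRNGKey(0), 3
member_keys = jax.random.split(jax.random.fold_in(key, count), num_members)
# member_keys.shape == (5, 2): one independent key per ensemble member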
predicate_fn = filterlib.to_predicate(split_rngs) - for path, stream, key, count in backups: - stream.key.value = key - count_path = (*path, 'count') - if predicate_fn(count_path, stream.count.to_state()): - # restore count only if it was split - # add 1 to reflect the split - stream.count.value = count + 1 - - -def vmap_fn( - args: tuple[tp.Any, ...], - kwargs: dict[str, tp.Any], - graphdef: GraphDef[tuple[tp.Any, ...]], - split_keys: State, - split_counts: State, - broadcast_keys: State, - broadcast_counts: State, - vectorized_states: list[State], - broadcast_state: State, - transform_metadata: tp.Mapping[str, tp.Any], - state_axes_: list[tuple[filterlib.Filter, int]], - f: tp.Callable[..., tp.Any], - filters: tp.Tuple[filterlib.Filter, ...], - split_rngs: filterlib.Filter, -): - ctx = graph.current_update_context('vmap') - state_axes = dict(state_axes_) - # remove metadata axis name from Variable.sharding - if spmd.PARTITION_NAME in transform_metadata: - vectorized_states = [ - spmd.remove_axis(state, index, transform_metadata) - for state, index in zip(vectorized_states, state_axes.values()) - ] - - # merge module state - input_graph_nodes = ctx.merge( - graphdef, - *vectorized_states, - broadcast_state, - split_keys, - split_counts, - broadcast_keys, - broadcast_counts, - ) - - (args, kwargs) = extract.insert_graph_nodes((args, kwargs), input_graph_nodes) - - out = f(*args, **kwargs) - - out, output_graph_nodes = extract.extract_graph_nodes(out) - - # split module state - ( - graphdef_out, - rng_state_out, - *vectorized_states_out, - broadcast_state_out, - ) = ctx.split( # type: ignore[misc] - (input_graph_nodes, output_graph_nodes), - rnglib.RngState, - *filters, - ) - - split_keys_out, broadcast_keys_out = rng_state_out.split(split_rngs, ...) - - broadcast_state_out = State.merge(broadcast_state_out, broadcast_keys_out) - - # add metadata axis name to Variable.sharding - if spmd.PARTITION_NAME in transform_metadata: - vectorized_states_out = [ - spmd.add_axis(state, index, transform_metadata) - for state, index in zip(vectorized_states_out, state_axes.values()) - ] - - return ( - graphdef_out, - broadcast_state_out, - vectorized_states_out, - split_keys_out, - out, - ) - - -def vmap( - f: F, - *, - in_axes: int | None | tp.Sequence[tp.Any] = 0, - out_axes: tp.Any = 0, - axis_name: AxisName | None = None, - axis_size: int | None = None, - spmd_axis_name: AxisName | tuple[AxisName, ...] 
| None = None, - # nnx specific - in_axes_kwargs: tp.Any = 0, - state_axes: tp.Mapping[filterlib.Filter, int | None] = FrozenDict({...: 0}), - split_rngs: filterlib.Filter = ..., - transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), -) -> F: - vectorized_states_axes = list(state_axes.values()) - - vmapped_fn = jax.vmap( - vmap_fn, - in_axes=( - in_axes, # args - in_axes_kwargs, # kwargs - None, # graphdef - 0, # split_keys - 0, # split_counts - None, # broadcast_keys - None, # broadcast_counts - vectorized_states_axes, # vectorized_states - None, # broadcast_state - None, # transform_metadata - None, # states_axes - None, # f - None, # vectorized_states_filters - None, # split_rngs - ), - out_axes=( - None, # graphdef_out - None, # broadcast_state - vectorized_states_axes, - 0, # keys_out - out_axes, # out_axes - ), - axis_name=axis_name, - axis_size=axis_size, - spmd_axis_name=spmd_axis_name, - ) - - @functools.wraps(f) - @graph.update_context('vmap') - def vmap_wrapper(*args, **kwargs): - ctx = graph.current_update_context('vmap') - - (args, kwargs), input_graph_nodes = extract.extract_graph_nodes( - (args, kwargs) - ) - input_rng_streams = _backup_vmap_keys(input_graph_nodes) - - # split module state - filters = (*state_axes.keys(), ...) - graphdef, rng_state, *vectorized_states, broadcast_state = ctx.split( # type: ignore[misc] - input_graph_nodes, rnglib.RngState, *filters - ) - - # infer length - axis_sizes: tp.Set[int] = set() - axis_sizes.update(_get_axis_sizes(args, in_axes)) - axis_sizes.update(_get_axis_sizes(kwargs, in_axes_kwargs)) - for state, state_axis in zip(vectorized_states, state_axes.values()): - axis_sizes.update(_get_axis_sizes(state, state_axis)) - - if len(axis_sizes) > 1: - raise ValueError( - 'Inconsistent lengths between state_axes states and ' - f'arguments: {axis_sizes}' - ) - elif len(axis_sizes) == 0: - if axis_size is None: - raise ValueError( - 'Cannot infer length from state_axes states or axes_arg, ' - 'please specify `length`' - ) - _axis_size = axis_size - else: - _axis_size = axis_sizes.pop() - if axis_size is not None and axis_size != _axis_size: - raise ValueError( - f'Specified axis_size {axis_size} is not the same as the' - f' inferred length {_axis_size}' - ) - - split_keys, split_counts, broadcast_keys, broadcast_counts = ( - _fork_vmap_keys( - rng_state, - split_rngs, - _axis_size, - ) - ) - - ( - graphdef_out, - broadcast_state, - vectorized_states, - split_keys_out, - out, - ) = vmapped_fn( - args, - kwargs, - graphdef, - split_keys, - split_counts, - broadcast_keys, - broadcast_counts, - vectorized_states, - broadcast_state, - transform_metadata, - list(state_axes.items()), - f, - filters, - split_rngs, - ) - - _, output_graph_nodes = ctx.merge( - graphdef_out, - *vectorized_states, - broadcast_state, - split_keys_out, - ) - - out = extract.insert_graph_nodes(out, output_graph_nodes) - - _restore_vmap_keys(input_rng_streams, split_rngs) - - return out - - return vmap_wrapper # type: ignore - - -class Vmap(tp.Generic[M], LiftedModule[M]): - @staticmethod - def constructor( - module_constructor: tp.Callable[..., MA], - *, - in_axes: int | None | tp.Sequence[tp.Any] = 0, - out_axes: tp.Any = 0, - axis_name: AxisName | None = None, - axis_size: int | None = None, - spmd_axis_name: AxisName | tuple[AxisName, ...] 
| None = None, - # nnx specific - in_axes_kwargs: tp.Any = 0, - state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}), - split_rngs: filterlib.Filter = ..., - transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), - ) -> tp.Callable[..., Vmap[MA]]: - def _create_vmap(*args, **kwargs): - return Vmap( - module_constructor=module_constructor, - in_axes=in_axes, - out_axes=out_axes, - axis_size=axis_size, - axis_name=axis_name, - spmd_axis_name=spmd_axis_name, - # nnx specific - in_axes_kwargs=in_axes_kwargs, - state_axes=state_axes, - split_rngs=split_rngs, - transform_metadata=transform_metadata, - # submodule args - module_init_args=args, - module_init_kwargs=kwargs, - ) - - return _create_vmap - - def __init__( - self, - module_constructor: tp.Callable[..., M], - *, - in_axes: int | None | tp.Sequence[tp.Any] = 0, - out_axes: tp.Any = 0, - axis_name: AxisName | None = None, - axis_size: int | None = None, - spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None, - # nnx specific - in_axes_kwargs: tp.Any = 0, - state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}), - split_rngs: filterlib.Filter = ..., - transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), - # submodule args - module_init_args: tuple[tp.Any, ...], - module_init_kwargs: dict[str, tp.Any], - ): - self.module_constructor = module_constructor - - @functools.partial( - vmap, - in_axes=None, - out_axes=None, - axis_name=axis_name, - axis_size=axis_size, - spmd_axis_name=spmd_axis_name, - in_axes_kwargs=None, - state_axes=state_axes, - split_rngs=split_rngs, - transform_metadata=transform_metadata, - ) - def vmap_init(*args, **kwargs): - return module_constructor(*args, **kwargs) - - self.vmap_module = vmap_init(*module_init_args, **module_init_kwargs) - - @functools.partial( - vmap, - in_axes=in_axes, - out_axes=out_axes, - axis_name=axis_name, - axis_size=axis_size, - spmd_axis_name=spmd_axis_name, - in_axes_kwargs=in_axes_kwargs, - state_axes=state_axes, - split_rngs=split_rngs, - transform_metadata=transform_metadata, - ) - def vmap_call(module, *args, _nnx_vmap_accessor: DelayedAccessor, **kwargs): - method = _nnx_vmap_accessor(module) - return method(*args, **kwargs) - - self.vmap_call = vmap_call - - @property - def _submodule(self) -> M: - return self.vmap_module - - def _call(self, accessor: DelayedAccessor, *args, **kwargs): - return self.vmap_call( - self._submodule, *args, _nnx_vmap_accessor=accessor, **kwargs - ) - - -# ------------------------------- -# pmap -# ------------------------------- -@struct.dataclass -class PmapInputs: - transform_metadata: tp.Mapping[str, tp.Any] = struct.field(pytree_node=False) - state_axes: tp.Mapping[filterlib.Filter, int] = struct.field( - pytree_node=False - ) - f: tp.Callable[..., tp.Any] = struct.field(pytree_node=False) - filters: tp.Tuple[filterlib.Filter, ...] 
= struct.field(pytree_node=False) - split_rngs: filterlib.Filter = struct.field(pytree_node=False) - - -def pmap_fn( - args: tuple[tp.Any, ...], - kwargs: dict[str, tp.Any], - graphdef: GraphDef[tuple[tp.Any, ...]], - split_keys: State, - split_counts: State, - broadcast_keys: State, - broadcast_counts: State, - vectorized_states: list[State], - broadcast_state: State, - pmap_inputs: PmapInputs, -): - transform_metadata = pmap_inputs.transform_metadata - state_axes = pmap_inputs.state_axes - f = pmap_inputs.f - filters = pmap_inputs.filters - split_rngs = pmap_inputs.split_rngs - ctx = graph.current_update_context('pmap') - # remove metadata axis name from Variable.sharding - if spmd.PARTITION_NAME in transform_metadata: - vectorized_states = [ - spmd.remove_axis(state, index, transform_metadata) - for state, index in zip(vectorized_states, state_axes.values()) - ] - - # merge module state - input_graph_nodes = ctx.merge( - graphdef, - *vectorized_states, - broadcast_state, - split_keys, - split_counts, - broadcast_keys, - broadcast_counts, - ) - - (args, kwargs) = extract.insert_graph_nodes((args, kwargs), input_graph_nodes) - - out = f(*args, **kwargs) - - out, output_graph_nodes = extract.extract_graph_nodes(out) - - # split module state - ( - graphdef_out, - rng_state_out, - *vectorized_states_out, - broadcast_state_out, - ) = ctx.split( # type: ignore[misc] - (input_graph_nodes, output_graph_nodes), - rnglib.RngState, - *filters, - ) - - not_keys_out, split_keys_out, broadcast_keys_out = rng_state_out.split( - rnglib.NotKey, split_rngs, ... - ) - - broadcast_state_out = State.merge( - broadcast_state_out, broadcast_keys_out, not_keys_out - ) - - # add metadata axis name to Variable.sharding - if spmd.PARTITION_NAME in transform_metadata: - vectorized_states_out = [ - spmd.add_axis(state, index, transform_metadata) - for state, index in zip(vectorized_states_out, state_axes.values()) - ] - - return ( - graphdef_out, - broadcast_state_out, - vectorized_states_out, - split_keys_out, - out, - ) - - -def pmap( - f: F, - axis_name: AxisName | None = None, - *, - in_axes: tp.Any = 0, - out_axes: tp.Any = 0, - static_broadcasted_argnums: int | tp.Iterable[int] = (), - devices: tp.Sequence[jax.Device] | None = None, # noqa: F811 - backend: str | None = None, - axis_size: int | None = None, - donate_argnums: int | tp.Iterable[int] = (), - global_arg_shapes: tuple[tuple[int, ...], ...] 
| None = None, - # nnx specific - in_axes_kwargs: tp.Any = 0, - state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}), - split_rngs: filterlib.Filter = ..., - transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), -) -> F: - if static_broadcasted_argnums: - raise NotImplementedError( - 'static_broadcasted_argnums is not yet supported in nnx.pmap' - ) - if donate_argnums != (): - raise NotImplementedError('donate_argnums is not yet supported in nnx.pmap') - - if global_arg_shapes is not None: - raise NotImplementedError( - 'global_arg_shapes is not yet supported in nnx.pmap' - ) - - vectorized_states_axes = list(state_axes.values()) - - pmapped_fn = jax.pmap( - pmap_fn, - axis_name=axis_name, - in_axes=( - in_axes, # args_axes - in_axes_kwargs, # kwargs_axes - None, # graphdef_axes - 0, # split_keys_axes - None, # split_counts_axes - None, # broadcast_keys_axes - None, # broadcast_counts_axes - vectorized_states_axes, # vectorized_states_axes - None, # broadcast_state_axes - None, # pmap_inputs_axes - ), # type: ignore - out_axes=( - None, # graphdef_out_axes - None, # broadcast_state_axes - vectorized_states_axes, - 0, # keys_axes_out - out_axes, # out_axes - ), # type: ignore - devices=devices, - backend=backend, - axis_size=axis_size, - ) - - @functools.wraps(f) - @graph.update_context('pmap') - def pmap_wrapper(*args, **kwargs): - ctx = graph.current_update_context('pmap') - - (args, kwargs), input_graph_nodes = extract.extract_graph_nodes( - (args, kwargs) - ) - input_rng_streams = rnglib.backup_keys(input_graph_nodes) - - # split module state - filters = (*state_axes.keys(), ...) - graphdef, rng_state, *vectorized_states, broadcast_state = ctx.split( # type: ignore[misc] - input_graph_nodes, rnglib.RngState, *filters - ) - - # infer length - axis_sizes: tp.Set[int] = set() - axis_sizes.update(_get_axis_sizes(args, in_axes)) - axis_sizes.update(_get_axis_sizes(kwargs, in_axes_kwargs)) - for state, state_axis in zip(vectorized_states, state_axes.values()): - axis_sizes.update(_get_axis_sizes(state, state_axis)) - - if len(axis_sizes) > 1: - raise ValueError( - 'Inconsistent lengths between state_axes states and ' - f'arguments: {axis_sizes}' - ) - elif len(axis_sizes) == 0: - if axis_size is None: - raise ValueError( - 'Cannot infer length from state_axes states or axes_arg, ' - 'please specify `length`' - ) - _axis_size = axis_size - else: - _axis_size = axis_sizes.pop() - if axis_size is not None and axis_size != _axis_size: - raise ValueError( - f'Specified axis_size {axis_size} is not the same as the' - f' inferred length {_axis_size}' - ) - - split_keys, split_counts, broadcast_keys, broadcast_counts = rnglib.fork( - rng_state, - split_rngs, - _axis_size, - ) - - ( - graphdef_out, - broadcast_state, - vectorized_states, - split_keys_out, - out, - ) = pmapped_fn( - args, - kwargs, - graphdef, - split_keys, - split_counts, - broadcast_keys, - broadcast_counts, - vectorized_states, - broadcast_state, - PmapInputs( - transform_metadata=transform_metadata, - state_axes=state_axes, - f=f, - filters=filters, - split_rngs=split_rngs, - ), - ) - - _, output_graph_nodes = ctx.merge( - graphdef_out, - *vectorized_states, - broadcast_state, - split_keys_out, - ) - - out = extract.insert_graph_nodes(out, output_graph_nodes) - - rnglib.restore_keys(input_rng_streams) - - return out - - return pmap_wrapper # type: ignore - - -class Pmap(tp.Generic[M], LiftedModule[M]): - @staticmethod - def constructor( - module_constructor: tp.Callable[..., MA], - *, - axis_name: AxisName | 
None = None, - in_axes: tp.Any = 0, - out_axes: tp.Any = 0, - static_broadcasted_argnums: int | tp.Iterable[int] = (), - devices: tp.Sequence[jax.Device] | None = None, # noqa: F811 - backend: str | None = None, - axis_size: int | None = None, - donate_argnums: int | tp.Iterable[int] = (), - global_arg_shapes: tuple[tuple[int, ...], ...] | None = None, - # nnx specific - in_axes_kwargs: tp.Any = 0, - state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}), - split_rngs: filterlib.Filter = ..., - transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), - ) -> tp.Callable[..., Pmap[MA]]: - def _create_pmap(*args, **kwargs): - return Pmap( - module_constructor=module_constructor, - axis_name=axis_name, - in_axes=in_axes, - out_axes=out_axes, - static_broadcasted_argnums=static_broadcasted_argnums, - devices=devices, - backend=backend, - axis_size=axis_size, - # nnx specific - in_axes_kwargs=in_axes_kwargs, - state_axes=state_axes, - split_rngs=split_rngs, - transform_metadata=transform_metadata, - # submodule args - module_init_args=args, - module_init_kwargs=kwargs, - ) - - return _create_pmap - - def __init__( - self, - module_constructor: tp.Callable[..., M], - *, - axis_name: AxisName | None = None, - in_axes: tp.Any = 0, - out_axes: tp.Any = 0, - static_broadcasted_argnums: int | tp.Iterable[int] = (), - devices: tp.Sequence[jax.Device] | None = None, # noqa: F811 - backend: str | None = None, - axis_size: int | None = None, - donate_argnums: int | tp.Iterable[int] = (), - global_arg_shapes: tuple[tuple[int, ...], ...] | None = None, - # nnx specific - in_axes_kwargs: tp.Any = 0, - state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}), - split_rngs: filterlib.Filter = ..., - transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), - # submodule args - module_init_args: tuple[tp.Any, ...], - module_init_kwargs: dict[str, tp.Any], - ): - self.module_constructor = module_constructor - - @functools.partial( - pmap, - axis_name=axis_name, - in_axes=None, - out_axes=None, - static_broadcasted_argnums=static_broadcasted_argnums, - devices=devices, - backend=backend, - axis_size=axis_size, - donate_argnums=(), - global_arg_shapes=None, - in_axes_kwargs=None, - state_axes=state_axes, - split_rngs=split_rngs, - transform_metadata=transform_metadata, - ) - def pmap_init(*args, **kwargs): - return module_constructor(*args, **kwargs) - - self.pmap_module = pmap_init(*module_init_args, **module_init_kwargs) - - @functools.partial( - pmap, - axis_name=axis_name, - in_axes=in_axes, - out_axes=out_axes, - static_broadcasted_argnums=static_broadcasted_argnums, - devices=devices, - backend=backend, - axis_size=axis_size, - donate_argnums=donate_argnums, - global_arg_shapes=global_arg_shapes, - in_axes_kwargs=in_axes_kwargs, - state_axes=state_axes, - split_rngs=split_rngs, - transform_metadata=transform_metadata, - ) - def pmap_call(module, *args, _nnx_vmap_accessor: DelayedAccessor, **kwargs): - method = _nnx_vmap_accessor(module) - return method(*args, **kwargs) - - self.pmap_call = pmap_call - - @property - def _submodule(self) -> M: - return self.pmap_module - - def _call(self, accessor: DelayedAccessor, *args, **kwargs): - return self.pmap_call( - self._submodule, *args, _nnx_vmap_accessor=accessor, **kwargs - ) diff --git a/flax/nnx/nnx/transforms/transforms.py b/flax/nnx/nnx/transforms/transforms.py index 4d25ee02ed..093d3c9308 100644 --- a/flax/nnx/nnx/transforms/transforms.py +++ b/flax/nnx/nnx/transforms/transforms.py @@ -29,18 +29,15 @@ from __future__ 
import annotations from abc import abstractmethod -import dataclasses import functools +import inspect import typing as tp from flax.nnx.nnx import ( extract, - filterlib, graph, - spmd, - variables, ) -from flax.nnx.nnx.module import GraphDef, Module +from flax.nnx.nnx.module import Module from flax.nnx.nnx.proxy_caller import ( CallableProxy, DelayedAccessor, @@ -65,6 +62,52 @@ Leaves = tp.List[Leaf] Index = int +class Missing: + pass + + +MISSING = Missing() + +@tp.overload +def resolve_kwargs( + fun: tp.Callable[..., tp.Any], + args: tuple, + kwargs: dict[str, tp.Any], +) -> tuple: ... +@tp.overload +def resolve_kwargs() -> tp.Callable[[F], F]: ... +def resolve_kwargs( + fun: tp.Callable[..., tp.Any] | Missing = MISSING, + args: tuple | Missing = MISSING, + kwargs: dict[str, tp.Any] | Missing = MISSING, +) -> tuple | tp.Callable[[F], F]: + if isinstance(fun, Missing): + + def resolve_kwargs_decorator(f): + @functools.wraps(f) + def resolve_kwargs_wrapper(*args, **kwargs): + args = resolve_kwargs(f, args, kwargs) + return f(*args) + + return resolve_kwargs_wrapper + + return resolve_kwargs_decorator # type: ignore + + if isinstance(args, Missing): + raise ValueError('args must be provided') + if isinstance(kwargs, Missing): + raise ValueError('kwargs must be provided') + + if isinstance(fun, functools.partial): + # functools.partial should have an opaque signature. + fun = lambda *args, **kwargs: None + ba = inspect.signature(fun).bind(*args, **kwargs) + ba.apply_defaults() + if ba.kwargs: + raise TypeError('keyword arguments could not be resolved to positions') + else: + return ba.args + def _normalize_sequence( x: StrInt | tp.Iterable[StrInt] | None, / @@ -106,806 +149,6 @@ def check_and_call(accessor: DelayedAccessor, *args, **kwargs): return proxy # type: ignore -# ------------------------------- -# jit -# ------------------------------- - -UNSPECIFIED = object() - - -def _default_constrain_state(state: State) -> State: - state_spec = spmd.get_partition_spec(state) - state = jax.lax.with_sharding_constraint(state, state_spec) - return state - - -@dataclasses.dataclass(frozen=True) -class JitStaticInputs: - graphdef: GraphDef[tuple[tp.Any, ...]] - constrain_state: tp.Callable[[State], State] | None - f: tp.Callable[..., tp.Any] - - -jax.tree_util.register_static(JitStaticInputs) - - -@dataclasses.dataclass(frozen=True) -class JitStaticOutputs: - graphdef: GraphDef[tuple[tp.Any, ...]] - index_mapping: dict[Index, Index] - - -jax.tree_util.register_static(JitStaticOutputs) - - -def jit_fn( - *args, - _nnx_jit_static: JitStaticInputs, - _nnx_jit_state: State, - **kwargs, -) -> tuple[tp.Any, State, GraphDef[tuple[tp.Any, ...]]]: - ctx = graph.current_update_context('jit') - graphdef = _nnx_jit_static.graphdef - constrain_state = _nnx_jit_static.constrain_state - f = _nnx_jit_static.f - state: State = _nnx_jit_state - - if constrain_state is not None: - state = constrain_state(state) - - input_graph_nodes = ctx.merge(graphdef, state) - - (args, kwargs) = extract.insert_graph_nodes((args, kwargs), input_graph_nodes) - - out = f(*args, **kwargs) - - out, output_graph_nodes = extract.extract_graph_nodes(out) - - graphdef, state = ctx.split((input_graph_nodes, output_graph_nodes)) - - if constrain_state is not None: - state = constrain_state(state) - - return out, state, graphdef - - -def jit( - fun: F, - *, - in_shardings: tp.Any = UNSPECIFIED, - out_shardings: tp.Any = UNSPECIFIED, - static_argnums: int | tp.Sequence[int] | None = None, - static_argnames: str | tp.Iterable[str] | None = 
None, - donate_argnums: int | tp.Sequence[int] | None = None, - donate_argnames: str | tp.Iterable[str] | None = None, - keep_unused: bool = False, - device: tp.Optional[jax.Device] = None, - backend: tp.Optional[str] = None, - inline: bool = False, - abstracted_axes: tp.Optional[tp.Any] = None, - # nnx specific - donate_state: bool = False, - constrain_state: bool | tp.Callable[[State], State] = False, -) -> F: - """ - Lifted version of ``jax.jit`` that can handle Modules / graph nodes as - arguments. - - Args: - fun: Function to be jitted. ``fun`` should be a pure function, as - side-effects may only be executed once. - - The arguments and return value of ``fun`` should be arrays, - scalars, or (nested) standard Python containers (tuple/list/dict) thereof. - Positional arguments indicated by ``static_argnums`` can be anything at - all, provided they are hashable and have an equality operation defined. - Static arguments are included as part of a compilation cache key, which is - why hash and equality operators must be defined. - - JAX keeps a weak reference to ``fun`` for use as a compilation cache key, - so the object ``fun`` must be weakly-referenceable. Most :class:`Callable` - objects will already satisfy this requirement. - in_shardings: Pytree of structure matching that of arguments to ``fun``, - with all actual arguments replaced by resource assignment specifications. - It is also valid to specify a pytree prefix (e.g. one value in place of a - whole subtree), in which case the leaves get broadcast to all values in - that subtree. - - The ``in_shardings`` argument is optional. JAX will infer the shardings - from the input :py:class:`jax.Array`'s and defaults to replicating the input - if the sharding cannot be inferred. - - The valid resource assignment specifications are: - - :py:class:`Sharding`, which will decide how the value - will be partitioned. With this, using a mesh context manager is not - required. - - :py:obj:`None`, will give JAX the freedom to choose whatever sharding - it wants. - For in_shardings, JAX will mark is as replicated but this behavior - can change in the future. - For out_shardings, we will rely on the XLA GSPMD partitioner to - determine the output shardings. - - The size of every dimension has to be a multiple of the total number of - resources assigned to it. This is similar to pjit's in_shardings. - out_shardings: Like ``in_shardings``, but specifies resource - assignment for function outputs. This is similar to pjit's - out_shardings. - - The ``out_shardings`` argument is optional. If not specified, :py:func:`jax.jit` - will use GSPMD's sharding propagation to figure out what the sharding of the - output(s) should be. - static_argnums: An optional int or collection of ints that specify which - positional arguments to treat as static (compile-time constant). - Operations that only depend on static arguments will be constant-folded in - Python (during tracing), and so the corresponding argument values can be - any Python object. - - Static arguments should be hashable, meaning both ``__hash__`` and - ``__eq__`` are implemented, and immutable. Calling the jitted function - with different values for these constants will trigger recompilation. - Arguments that are not arrays or containers thereof must be marked as - static. - - If neither ``static_argnums`` nor ``static_argnames`` is provided, no - arguments are treated as static. 
If ``static_argnums`` is not provided but - ``static_argnames`` is, or vice versa, JAX uses - :code:`inspect.signature(fun)` to find any positional arguments that - correspond to ``static_argnames`` - (or vice versa). If both ``static_argnums`` and ``static_argnames`` are - provided, ``inspect.signature`` is not used, and only actual - parameters listed in either ``static_argnums`` or ``static_argnames`` will - be treated as static. - static_argnames: An optional string or collection of strings specifying - which named arguments to treat as static (compile-time constant). See the - comment on ``static_argnums`` for details. If not - provided but ``static_argnums`` is set, the default is based on calling - ``inspect.signature(fun)`` to find corresponding named arguments. - donate_argnums: Specify which positional argument buffers are "donated" to - the computation. It is safe to donate argument buffers if you no longer - need them once the computation has finished. In some cases XLA can make - use of donated buffers to reduce the amount of memory needed to perform a - computation, for example recycling one of your input buffers to store a - result. You should not reuse buffers that you donate to a computation, JAX - will raise an error if you try to. By default, no argument buffers are - donated. - - If neither ``donate_argnums`` nor ``donate_argnames`` is provided, no - arguments are donated. If ``donate_argnums`` is not provided but - ``donate_argnames`` is, or vice versa, JAX uses - :code:`inspect.signature(fun)` to find any positional arguments that - correspond to ``donate_argnames`` - (or vice versa). If both ``donate_argnums`` and ``donate_argnames`` are - provided, ``inspect.signature`` is not used, and only actual - parameters listed in either ``donate_argnums`` or ``donate_argnames`` will - be donated. - - For more details on buffer donation see the - `FAQ `_. - donate_argnames: An optional string or collection of strings specifying - which named arguments are donated to the computation. See the - comment on ``donate_argnums`` for details. If not - provided but ``donate_argnums`` is set, the default is based on calling - ``inspect.signature(fun)`` to find corresponding named arguments. - keep_unused: If `False` (the default), arguments that JAX determines to be - unused by `fun` *may* be dropped from resulting compiled XLA executables. - Such arguments will not be transferred to the device nor provided to the - underlying executable. If `True`, unused arguments will not be pruned. - device: This is an experimental feature and the API is likely to change. - Optional, the Device the jitted function will run on. (Available devices - can be retrieved via :py:func:`jax.devices`.) The default is inherited - from XLA's DeviceAssignment logic and is usually to use - ``jax.devices()[0]``. - backend: This is an experimental feature and the API is likely to change. - Optional, a string representing the XLA backend: ``'cpu'``, ``'gpu'``, or - ``'tpu'``. - inline: Specify whether this function should be inlined into enclosing - jaxprs (rather than being represented as an application of the xla_call - primitive with its own subjaxpr). Default False. - donate_state: Optional, bool. If True, the object state of the - graph node's state will be donated to the computation. Default False. - constrain_state: Optional, bool or callable. 
If True, the object - state of the graph node's state will be constrained to the partition - specified by the graph node's partition spec as computed by - :func:`nnx.spmd.get_partition_spec`. If a callable, the object State will - passed to the callable which must return the constrained object State. If - False, the object state will not be constrained. Default False. - - Returns: - A wrapped version of ``fun``, set up for just-in-time compilation. - """ - - _static_argnums = _normalize_sequence(static_argnums) - _static_argnames = _normalize_sequence(static_argnames) - _donate_argnums = _normalize_sequence(donate_argnums) - _donate_argnames = _normalize_sequence(donate_argnames) - - if donate_state: - _donate_argnames = (*_donate_argnames, '_nnx_jit_state') - - if callable(constrain_state): - _constrain_state = constrain_state - elif constrain_state: - _constrain_state = _default_constrain_state - else: - _constrain_state = None - - jit_kwargs = {} - if in_shardings is not UNSPECIFIED: - jit_kwargs['in_shardings'] = in_shardings - if out_shardings is not UNSPECIFIED: - jit_kwargs['out_shardings'] = out_shardings - - jitted_fn = jax.jit( - jit_fn, - static_argnums=_static_argnums, - static_argnames=_static_argnames, - donate_argnums=_donate_argnums, - donate_argnames=_donate_argnames, - keep_unused=keep_unused, - device=device, - backend=backend, - inline=inline, - abstracted_axes=abstracted_axes, - **jit_kwargs, - ) - - @functools.wraps(fun) - @graph.update_context('jit') - def jit_wrapper(*args, **kwargs): - ctx = graph.current_update_context('jit') - (args, kwargs), input_graph_nodes = extract.extract_graph_nodes( - (args, kwargs) - ) - graphdef, state = ctx.split(input_graph_nodes) - out, output_state, output_graphdef = jitted_fn( - *args, - _nnx_jit_static=JitStaticInputs(graphdef, _constrain_state, fun), - _nnx_jit_state=state, - **kwargs, - ) - input_graph_nodes, output_graph_nodes = ctx.merge( - output_graphdef, output_state - ) - out = extract.insert_graph_nodes(out, output_graph_nodes) - return out - - jit_wrapper.inner = jitted_fn # type: ignore - - return jit_wrapper # type: ignore - - -class Jit(tp.Generic[M], LiftedModule[M]): - @staticmethod - def constructor( - module_constructor: tp.Callable[..., MA], - *, - in_shardings: tp.Any = UNSPECIFIED, - out_shardings: tp.Any = UNSPECIFIED, - static_argnums: int | tp.Sequence[int] | None = None, - static_argnames: str | tp.Iterable[str] | None = None, - donate_argnums: int | tp.Sequence[int] | None = None, - donate_argnames: str | tp.Iterable[str] | None = None, - keep_unused: bool = False, - device: tp.Optional[jax.Device] = None, - backend: tp.Optional[str] = None, - inline: bool = False, - abstracted_axes: tp.Optional[tp.Any] = None, - # nnx specific - donate_state: bool = False, - constrain_state: bool | tp.Callable[[State], State] = False, - ) -> tp.Callable[..., Jit[MA]]: - def _create_jit(*args, **kwargs): - return Jit( - module_constructor=module_constructor, - in_shardings=in_shardings, - out_shardings=out_shardings, - static_argnums=static_argnums, - static_argnames=static_argnames, - donate_argnums=donate_argnums, - donate_argnames=donate_argnames, - keep_unused=keep_unused, - device=device, - backend=backend, - inline=inline, - abstracted_axes=abstracted_axes, - # nnx specific - donate_state=donate_state, - constrain_state=constrain_state, - # submodule args - module_init_args=args, - module_init_kwargs=kwargs, - ) - - return _create_jit - - def __init__( - self, - module_constructor: tp.Callable[..., M], - *, - 
in_shardings: tp.Any = UNSPECIFIED, - out_shardings: tp.Any = UNSPECIFIED, - static_argnums: int | tp.Sequence[int] | None = None, - static_argnames: str | tp.Iterable[str] | None = None, - donate_argnums: int | tp.Sequence[int] | None = None, - donate_argnames: str | tp.Iterable[str] | None = None, - keep_unused: bool = False, - device: tp.Optional[jax.Device] = None, - backend: tp.Optional[str] = None, - inline: bool = False, - abstracted_axes: tp.Optional[tp.Any] = None, - # nnx specific - donate_state: bool = False, - constrain_state: bool | tp.Callable[[State], State] = False, - # submodule args - module_init_args: tuple[tp.Any, ...], - module_init_kwargs: dict[str, tp.Any], - ): - @functools.partial( - jit, - in_shardings=in_shardings, - out_shardings=out_shardings, - static_argnums=static_argnums, - static_argnames=static_argnames, - donate_argnums=donate_argnums, - donate_argnames=donate_argnames, - keep_unused=keep_unused, - device=device, - backend=backend, - inline=inline, - abstracted_axes=abstracted_axes, - donate_state=donate_state, - constrain_state=constrain_state, - ) - def jit_call_module( - module, *args, _nnx_jit_accessor: DelayedAccessor, **kwargs - ): - method = _nnx_jit_accessor(module) - return method(*args, **kwargs) - - self.jitted_fn = jit_call_module - self.module_constructor = module_constructor - self.jit_module = self.module_constructor( - *module_init_args, **module_init_kwargs - ) - - @property - def _submodule(self) -> M: - return self.jit_module - - def _call(self, accessor: DelayedAccessor, *args, **kwargs) -> tp.Any: - out = self.jitted_fn( - self.jit_module, *args, _nnx_jit_accessor=accessor, **kwargs - ) - return out - - -# ------------------------------- -# grad -# ------------------------------- - - -def grad_fn(*args): - f: tp.Callable[..., tp.Any] - graphdef: GraphDef[tuple[dict[int, tp.Any], tuple[tp.Any, ...]]] - non_diff_state: State - has_aux: bool - diff_args: list[int] - ctx = graph.current_update_context('grad') - *args, f, graphdef, non_diff_state, has_aux, diff_args = args - - # rebuild diff_state from substates in args - diff_state = State({}) - for i in diff_args: - diff_state[i] = args[i] - diff_state: graph.GraphState = State({0: diff_state.raw_mapping}) - - diff_graph_nodes, input_nodes = ctx.merge( - graphdef, diff_state, non_diff_state - ) - - # add nodes to the args - for i, arg in diff_graph_nodes.items(): - args[i] = arg - - # add other nodes to the args - args = extract.insert_graph_nodes(args, input_nodes) - - out = f(*args) - - out, out_nodes = extract.extract_graph_nodes(out) - - graphdef_out, state_out = ctx.split((input_nodes, out_nodes)) - - if has_aux: - loss, aux = out - out = (loss, (graphdef_out, state_out, aux)) - else: - out = (out, (graphdef_out, state_out)) - - return out - - -def _grad_general( - f: tp.Callable[..., tp.Any], - argnums: int | tp.Sequence[int], - has_aux: bool, - holomorphic: bool, - allow_int: bool, - reduce_axes: tp.Sequence[AxisName], - wrt: filterlib.Filter, - return_value: bool, -) -> tp.Callable[..., tp.Any]: - @graph.update_context('grad') - def grad_wrapper(*args): - ctx: graph.UpdateContext = graph.current_update_context('grad') - _argnums = _normalize_sequence(argnums) - diff_graph_nodes: dict[int, tp.Any] = { - i: arg - for i, arg in enumerate(args) - if i in _argnums and graph.is_node(arg) - } - args, input_nodes = extract.extract_graph_nodes(args) - args = list(args) - - def only_diff(path: tuple, value: tp.Any) -> bool: - # diff_graph_nodes is the first element in the tuple - return 
path[0] == 0 - - graphdef, diff_state, non_diff_state = ctx.split( - (diff_graph_nodes, input_nodes), filterlib.All(wrt, only_diff), ... - ) # type: ignore[misc] - - # extract diff_state substates into the args - diff_args: list[int] = [] - if 0 in diff_state: - for i, diff_substate in diff_state[0].items(): # type: ignore - assert isinstance(i, int) - args[i] = diff_substate - diff_args.append(i) - transform = jax.value_and_grad if return_value else jax.grad - - _argnums = _argnums[0] if len(_argnums) == 1 else _argnums - - out = transform( - grad_fn, - argnums=_argnums, - has_aux=True, - holomorphic=holomorphic, - allow_int=allow_int, - reduce_axes=reduce_axes, - )(*args, f, graphdef, non_diff_state, has_aux, diff_args) - - if return_value: - if has_aux: - (loss, (graphdef_out, state_out, aux)), grads = out - out = (loss, aux), grads - else: - (loss, (graphdef_out, state_out)), grads = out - out = loss, grads - else: - if has_aux: - grads, (graphdef_out, state_out, aux) = out - out = grads, aux - else: - out, (graphdef_out, state_out) = out - - input_nodes, out_nodes = ctx.merge(graphdef_out, state_out) - - out = extract.insert_graph_nodes(out, out_nodes) - return out - - return grad_wrapper - - -def grad( - f: tp.Callable[..., tp.Any], - argnums: int | tp.Sequence[int] = 0, - has_aux: bool = False, - holomorphic: bool = False, - allow_int: bool = False, - reduce_axes: tp.Sequence[AxisName] = (), - *, - wrt: filterlib.Filter = variables.Param, -) -> tp.Callable[..., tp.Any]: - """Lifted version of ``jax.grad`` that can handle Modules / graph nodes as - arguments. - - The differentiable state of each graph node is defined by the `wrt` filter, - which by default is set to `nnx.Param`. Internally the ``State`` of - graph nodes is extracted, filtered according to `wrt` filter, and - passed to the underlying ``jax.grad`` function. The gradients - of graph nodes are of type ``State``. - - Example:: - - >>> from flax import nnx - >>> import jax.numpy as jnp - ... - >>> m = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) - >>> x = jnp.ones((1, 2)) - >>> y = jnp.ones((1, 3)) - ... - >>> loss_fn = lambda m, x, y: jnp.mean((m(x) - y) ** 2) - >>> grad_fn = nnx.grad(loss_fn, wrt=nnx.Param) - ... - >>> grads = grad_fn(m, x, y) - >>> jax.tree_util.tree_map(jnp.shape, grads) - State({ - 'bias': VariableState( - type=Param, - value=(3,) - ), - 'kernel': VariableState( - type=Param, - value=(2, 3) - ) - }) - - Args: - fun: Function to be differentiated. Its arguments at positions specified by - ``argnums`` should be arrays, scalars, graph nodes or standard Python - containers. Argument arrays in the positions specified by ``argnums`` must - be of inexact (i.e., floating-point or complex) type. It should return a - scalar (which includes arrays with shape ``()`` but not arrays with shape - ``(1,)`` etc.) - argnums: Optional, integer or sequence of integers. Specifies which - positional argument(s) to differentiate with respect to (default 0). - has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the - first element is considered the output of the mathematical function to be - differentiated and the second element is auxiliary data. Default False. - holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be - holomorphic. If True, inputs and outputs must be complex. Default False. - allow_int: Optional, bool. Whether to allow differentiating with - respect to integer valued inputs. The gradient of an integer input will - have a trivial vector-space dtype (float0). Default False. 
- reduce_axes: Optional, tuple of axis names. If an axis is listed here, and - ``fun`` implicitly broadcasts a value over that axis, the backward pass - will perform a ``psum`` of the corresponding gradient. Otherwise, the - gradient will be per-example over named axes. For example, if ``'batch'`` - is a named batch axis, ``grad(f, reduce_axes=('batch',))`` will create a - function that computes the total gradient while ``grad(f)`` will create - one that computes the per-example gradient. - wrt: Optional, filterlib.Filter. Filter to extract the differentiable state - of each graph node. Default is `nnx.Param`. - - """ - - return _grad_general( - f, - argnums, - has_aux, - holomorphic, - allow_int, - reduce_axes, - wrt, - return_value=False, - ) - - -def value_and_grad( - f: tp.Callable[..., tp.Any], - argnums: int | tp.Sequence[int] = 0, - has_aux: bool = False, - holomorphic: bool = False, - allow_int: bool = False, - reduce_axes: tp.Sequence[AxisName] = (), - *, - wrt: filterlib.Filter = variables.Param, -) -> tp.Callable[..., tp.Any]: - return _grad_general( - f, - argnums, - has_aux, - holomorphic, - allow_int, - reduce_axes, - wrt, - return_value=True, - ) - - -class Grad(tp.Generic[M], LiftedModule[M]): - @staticmethod - def constructor( - module_constructor: tp.Callable[..., MA], - has_aux: bool = False, - holomorphic: bool = False, - allow_int: bool = False, - reduce_axes: tp.Sequence[AxisName] = (), - return_value: bool = False, - *, - wrt: filterlib.Filter = variables.Param, - ) -> tp.Callable[..., Grad[MA]]: - def _create_grad(*args, **kwargs): - return Grad( - module_constructor=module_constructor, - wrt=wrt, - has_aux=has_aux, - holomorphic=holomorphic, - allow_int=allow_int, - reduce_axes=reduce_axes, - return_value=return_value, - # submodule args - module_init_args=args, - module_init_kwargs=kwargs, - ) - - return _create_grad - - def __init__( - self, - module_constructor: tp.Callable[..., M], - argnums: int | tp.Sequence[int] = 0, - has_aux: bool = False, - holomorphic: bool = False, - allow_int: bool = False, - reduce_axes: tp.Sequence[AxisName] = (), - *, - wrt: filterlib.Filter = variables.Param, - # submodule args - module_init_args: tuple[tp.Any, ...], - module_init_kwargs: dict[str, tp.Any], - ): - self.module_constructor = module_constructor - self.grad_module = self.module_constructor( - *module_init_args, **module_init_kwargs - ) - - @functools.partial( - grad, - argnums=argnums, - has_aux=has_aux, - holomorphic=holomorphic, - allow_int=allow_int, - reduce_axes=reduce_axes, - wrt=wrt, - ) - def grad_call_apply(module, *args): - *args, accessor = args - method = accessor(module) - return method(*args) - - self.grad_apply = grad_call_apply - - @property - def _submodule(self) -> M: - return self.grad_module - - def _call(self, accessor: DelayedAccessor, *args) -> tp.Any: - return self.grad_apply(self.grad_module, *args, accessor) - - -# ------------------------------- -# remat -# ------------------------------- - - -@dataclasses.dataclass -class RematOptions: - prevent_cse: bool - static_argnums: int | tuple[int, ...] 
- policy: tp.Callable[..., bool] | None - - def __post_init__(self): - if isinstance(self.static_argnums, int): - self.static_argnums = (self.static_argnums,) - - # add 1 as an offset to account for state parameter - self.static_argnums = tuple( - x + 1 if x >= 0 else x for x in self.static_argnums - ) - - -class Remat(tp.Generic[M], LiftedModule[M]): - @staticmethod - def constructor( - module_constructor: tp.Callable[..., MA], - prevent_cse: bool = True, - static_argnums: int | tuple[int, ...] = (), - policy: tp.Callable[..., bool] | None = None, - ) -> tp.Callable[..., Remat[MA]]: - def create_remat(*args, **kwargs): - return Remat( - module_constructor=module_constructor, - module_init_args=args, - module_init_kwargs=kwargs, - prevent_cse=prevent_cse, - static_argnums=static_argnums, - policy=policy, - ) - - return create_remat - - def __init__( - self, - *, - module_constructor: tp.Callable[..., M], - prevent_cse: bool = True, - static_argnums: int | tuple[int, ...] = (), - policy: tp.Callable[..., bool] | None = None, - # submodule args - module_init_args: tuple[tp.Any, ...], - module_init_kwargs: dict[str, tp.Any], - ): - self.options = RematOptions( - prevent_cse=prevent_cse, - static_argnums=static_argnums, - policy=policy, - ) - self.module_constructor = module_constructor - self.remat_module = self.module_constructor( - *module_init_args, **module_init_kwargs - ) - - @property - def _submodule(self) -> M: - return self.remat_module - - def _call(self, accessor: DelayedAccessor, *args) -> tp.Any: - def remat_apply_call(module, *args): - method = accessor(module) - return method(*args) - - return remat_apply( - self.options, - remat_apply_call, - (self.remat_module, *args), - ) - - -@graph.update_context('remat') -def remat_apply( - options: RematOptions, - f: tp.Callable[..., tp.Any], - args: tuple[tp.Any, ...], -): - ctx = graph.current_update_context('remat') - args, input_nodes = extract.extract_graph_nodes(args) - graphdef, state = ctx.split(input_nodes) - - def _remat_fn(state: State, *args): - input_nodes = ctx.merge(graphdef, state) - args = extract.insert_graph_nodes(args, input_nodes) - out = f(*args) - - out, output_nodes = extract.extract_graph_nodes(out) - new_graphdef, new_state = ctx.split((input_nodes, output_nodes)) - return (new_graphdef, new_state), out - - (new_graphdef, new_state), out = jax.checkpoint( - _remat_fn, - prevent_cse=options.prevent_cse, - static_argnums=options.static_argnums, - policy=options.policy, - )(state, *args) - - _, output_nodes = ctx.merge(new_graphdef, new_state) - out = extract.insert_graph_nodes(out, output_nodes) - - return out - - -def remat( - f: F, - *, - prevent_cse: bool = True, - static_argnums: int | tuple[int, ...] 
= (), - policy: tp.Callable[..., bool] | None = None, -) -> F: - options = RematOptions( - prevent_cse=prevent_cse, - static_argnums=static_argnums, - policy=policy, - ) - - @functools.wraps(f) - def remat_wrapper(*args): - return remat_apply(options, f, args) - - return remat_wrapper # type: ignore - - # ------------------------------- # eval_shape # ------------------------------- @@ -941,7 +184,7 @@ def _eval_shape_fn(state: State, *args, **kwargs): # cond # ------------------------------- -@general.split_inputs(ctx_tag='cond') +@general.split_inputs(ctxtag='cond') def cond( pred, true_fun: tp.Callable[..., A], @@ -951,8 +194,8 @@ def cond( ) -> A: return jax.lax.cond( pred, - general.merge_inputs(true_fun, ctx_tag='cond'), - general.merge_inputs(false_fun, ctx_tag='cond'), + general.merge_inputs(true_fun, ctxtag='cond'), + general.merge_inputs(false_fun, ctxtag='cond'), *operands, **kwargs, ) diff --git a/flax/nnx/nnx/traversals.py b/flax/nnx/nnx/traversals.py index eb9e5896bb..4d9c80603c 100644 --- a/flax/nnx/nnx/traversals.py +++ b/flax/nnx/nnx/traversals.py @@ -69,7 +69,7 @@ def flatten_mapping(xs: Mapping[Any, Any], Example:: - >>> from flax.experimental import nnx + >>> from flax import nnx >>> xs = {'foo': 1, 'bar': {'a': 2, 'b': {}}} >>> flat_xs = nnx.traversals.flatten_mapping(xs) >>> flat_xs @@ -148,7 +148,7 @@ def unflatten_mapping(xs: Any, Example:: - >>> from flax.experimental import nnx + >>> from flax import nnx >>> flat_xs = { ... ('foo',): 1, ... ('bar', 'a'): 2, diff --git a/flax/nnx/scripts/run-all-examples.bash b/flax/nnx/scripts/run-all-examples.bash index 570e9c98e9..ab896ebd6a 100644 --- a/flax/nnx/scripts/run-all-examples.bash +++ b/flax/nnx/scripts/run-all-examples.bash @@ -1,6 +1,5 @@ set -e -cd ../../.. source .venv/bin/activate cd flax/nnx diff --git a/flax/nnx/tests/deprecated_transforms_test.py b/flax/nnx/tests/deprecated_transforms_test.py new file mode 100644 index 0000000000..eed0685f2e --- /dev/null +++ b/flax/nnx/tests/deprecated_transforms_test.py @@ -0,0 +1,370 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
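+ +"""Tests for the deprecated lifted transforms (vmap/Vmap and pmap/Pmap) now housed in flax.nnx.nnx.transforms.deprecated."""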
+ +from functools import partial + +from absl.testing import absltest +from flax import nnx +import jax +import jax.numpy as jnp + +from flax.nnx.nnx.transforms.deprecated import vmap, Vmap, pmap, Pmap + + +class TestVmap(absltest.TestCase): + def test_basic(self): + class Block(nnx.Module): + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(3, 3, rngs=rngs) + self.dropout = nnx.Dropout(0.5, deterministic=False, rngs=rngs) + + def __call__(self, x: jax.Array) -> jax.Array: + x = self.linear(x) + x = nnx.relu(x) + x = self.dropout(x) + return x + + def create_block(rngs: nnx.Rngs): + return Block(rngs) + + vectorized_create_block = vmap( + create_block, state_axes={nnx.Param: 0}, axis_size=5 + ) + + rngs = nnx.Rngs(0) + initial_key = rngs.default.key.value + module = vectorized_create_block(rngs) + + assert rngs.default.count.value == 1 + assert rngs.default.key.value == initial_key + assert not jnp.allclose( + module.linear.kernel.value[0], + module.linear.kernel.value[1], + ) + assert module.linear.kernel.value.shape == (5, 3, 3) + assert module.linear.bias.value.shape == (5, 3) + + x = jnp.ones((5, 1, 3)) + + def forward_block(module, x): + return module(x) + + vectorized_forward_block = vmap( + forward_block, state_axes={nnx.Param: 0}, axis_size=5 + ) + + y = vectorized_forward_block(module, x) + + assert y.shape == (5, 1, 3) + assert rngs.default.count.value == 2 + assert rngs.default.key.value == initial_key + + y2 = vectorized_forward_block(module, x) + + assert not jnp.allclose(y, y2) + + def test_basic_demo(self): + class Block(nnx.Module): + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(3, 3, rngs=rngs) + self.dropout = nnx.Dropout(0.5, deterministic=False, rngs=rngs) + + def __call__(self, x: jax.Array) -> jax.Array: + return self.dropout(nnx.relu(self.linear(x))) + + @partial(vmap, axis_size=5) + def create_block(rngs: nnx.Rngs): + return Block(rngs) + + @partial(vmap, axis_size=5) + def forward_block(module: Block, x): + return module(x) + + rngs = nnx.Rngs(0) + module = create_block(rngs) + + assert rngs.default.count.value == 1 + assert module.linear.kernel.value.shape == (5, 3, 3) + assert module.linear.bias.value.shape == (5, 3) + assert not jnp.allclose( + module.linear.kernel.value[0], + module.linear.kernel.value[1], + ) + + x = jnp.ones((5, 1, 3)) + + y = forward_block(module, x) + + assert y.shape == (5, 1, 3) + assert rngs.default.count.value == 2 + + y2 = forward_block(module, x) + + # dropout is working! 
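+ # (each call advances the Dropout rng count, so the second forward pass samples a different mask)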
+ assert not jnp.allclose(y, y2) + + def test_replicate(self): + din = 3 + dout = 10 + + class Block(nnx.Module): + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(din, dout, rngs=rngs) + self.dropout = nnx.Dropout(0.5, deterministic=False, rngs=rngs) + + def __call__(self, x: jax.Array) -> jax.Array: + return self.dropout(nnx.relu(self.linear(x))) + + def create_block(rngs: nnx.Rngs): + return Block(rngs) + + @partial( + vmap, + state_axes={}, # replicate all state + split_rngs=True, # different rngs for each replica + ) + def forward_block(module: Block, x): + return module(x) + + rngs = nnx.Rngs(0) + initial_key = rngs.default.key.value + module = create_block(rngs) + + assert rngs.default.count.value == 2 + assert module.linear.kernel.value.shape == (din, dout) + assert module.linear.bias.value.shape == (dout,) + + x = jnp.ones((5, 1, din)) + + y = forward_block(module, x) + + assert y.shape == (5, 1, dout) + assert rngs.default.count.value == 3 + + assert not jnp.allclose(y[0], y[1]) + + y2 = forward_block(module, x) + + # dropout is working! + assert not jnp.allclose(y, y2) + + assert rngs.default.key.value == initial_key + + def test_combinator(self): + class Block(nnx.Module): + def __init__(self, *, rngs: nnx.Rngs): + self.linear = nnx.Linear(3, 3, rngs=rngs) + + def __call__(self, x: jax.Array) -> jax.Array: + x = self.linear(x) + x = nnx.gelu(x) + return x + + MLP = Vmap.constructor(Block, state_axes={nnx.Param: 0}, axis_size=5) + + module = MLP(rngs=nnx.Rngs(0)) + + assert not jnp.allclose( + module.vmap_module.linear.kernel.value[0], + module.vmap_module.linear.kernel.value[1], + ) + assert module.vmap_module.linear.kernel.value.shape == (5, 3, 3) + assert module.vmap_module.linear.bias.value.shape == (5, 3) + + x = jnp.ones((5, 1, 3)) + y = module(x) + + assert y.shape == (5, 1, 3) + + def test_combinator_init(self): + class Block(nnx.Module): + def __init__(self, *, graphdef: str, rngs: nnx.Rngs): + self.linear = nnx.Linear(3, 3, rngs=rngs) + self.graphdef = graphdef + + def __call__(self, x: jax.Array) -> jax.Array: + x = self.linear(x) + x = nnx.gelu(x) + return x + + MLP = Vmap.constructor(Block, state_axes={nnx.Param: 0}, axis_size=5) + + module = MLP(graphdef='hello', rngs=nnx.Rngs(0)) + + assert module.vmap_module.graphdef == 'hello' + + def test_state_axes(self): + class Foo(nnx.Module): + def __init__(self): + self.param = nnx.Param(jnp.arange(5)) + + foo = Foo() + + @partial(vmap, state_axes={...: 0}) + def f(foo: Foo): + assert foo.param.value.shape == () + + f(foo) + + +class TestPmap(absltest.TestCase): + def test_basic_single(self): + class Block(nnx.Module): + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(3, 10, rngs=rngs) + self.dropout = nnx.Dropout(0.5, deterministic=False, rngs=rngs) + + def __call__(self, x: jax.Array) -> jax.Array: + x = self.linear(x) + x = nnx.elu(x) + x = self.dropout(x) + return x + + def create_block(rngs: nnx.Rngs): + return Block(rngs) + + vectorized_create_block = pmap( + create_block, state_axes={nnx.Param: 0}, axis_size=1 + ) + + rngs = nnx.Rngs(0) + initial_key = rngs.default.key.value + module = vectorized_create_block(rngs) + + assert rngs.default.count.value == 2 + assert rngs.default.key.value == initial_key + assert module.linear.kernel.value.shape == (1, 3, 10) + assert module.linear.bias.value.shape == (1, 10) + + x = jnp.ones((1, 1, 3)) + + def forward_block(module, x): + return module(x) + + vectorized_forward_block = vmap( + forward_block, state_axes={nnx.Param: 0}, axis_size=1 + ) + + 
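+ # note: the module above was created with pmap, while the forward pass below is vectorized with vmap using the same state_axes and axis_size=1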
y = vectorized_forward_block(module, x) + + assert y.shape == (1, 1, 10) + assert rngs.default.count.value == 3 + assert rngs.default.key.value == initial_key + + y2 = vectorized_forward_block(module, x) + + assert not jnp.allclose(y, y2) + + def test_basic_demo_single(self): + class Block(nnx.Module): + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(3, 3, rngs=rngs) + self.dropout = nnx.Dropout(0.5, deterministic=False, rngs=rngs) + + def __call__(self, x: jax.Array) -> jax.Array: + return self.dropout(nnx.relu(self.linear(x))) + + @partial(pmap, axis_size=1) + def create_block(rngs: nnx.Rngs): + return Block(rngs) + + @partial(pmap, axis_size=1) + def forward_block(module: Block, x): + return module(x) + + rngs = nnx.Rngs(0) + module = create_block(rngs) + + assert rngs.default.count.value == 2 + assert module.linear.kernel.value.shape == (1, 3, 3) + assert module.linear.bias.value.shape == (1, 3) + + x = jnp.ones((1, 10, 3)) + + y = forward_block(module, x) + + assert y.shape == (1, 10, 3) + assert rngs.default.count.value == 3 + + y2 = forward_block(module, x) + + # dropout is working! + assert not jnp.allclose(y, y2) + + def test_replicate_single(self): + din = 3 + dout = 10 + + class Block(nnx.Module): + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(din, dout, rngs=rngs) + self.dropout = nnx.Dropout(0.5, deterministic=False, rngs=rngs) + + def __call__(self, x: jax.Array) -> jax.Array: + return self.dropout(nnx.relu(self.linear(x))) + + def create_block(rngs: nnx.Rngs): + return Block(rngs) + + @partial( + pmap, + state_axes={}, # replicate all state + split_rngs=True, # different rngs for each replica + ) + def forward_block(module: Block, x): + return module(x) + + rngs = nnx.Rngs(0) + initial_key = rngs.default.key.value + module = create_block(rngs) + + assert rngs.default.count.value == 2 + assert module.linear.kernel.value.shape == (din, dout) + assert module.linear.bias.value.shape == (dout,) + + x = jnp.ones((1, 5, din)) + + y = forward_block(module, x) + + assert y.shape == (1, 5, dout) + assert rngs.default.count.value == 3 + + y2 = forward_block(module, x) + + # dropout is working! + assert not jnp.allclose(y, y2) + + assert rngs.default.key.value == initial_key + + def test_combinator_single(self): + class Block(nnx.Module): + def __init__(self, *, rngs: nnx.Rngs): + self.linear = nnx.Linear(3, 3, rngs=rngs) + + def __call__(self, x: jax.Array) -> jax.Array: + x = self.linear(x) + x = nnx.gelu(x) + return x + + MLP = Pmap.constructor(Block, state_axes={nnx.Param: 0}, axis_size=1) + + module = MLP(rngs=nnx.Rngs(0)) + + assert module.pmap_module.linear.kernel.value.shape == (1, 3, 3) + assert module.pmap_module.linear.bias.value.shape == (1, 3) + + x = jnp.ones((1, 5, 3)) + y = module(x) + + assert y.shape == (1, 5, 3) diff --git a/flax/nnx/tests/experimental_test.py b/flax/nnx/tests/experimental_test.py new file mode 100644 index 0000000000..6d6f77993d --- /dev/null +++ b/flax/nnx/tests/experimental_test.py @@ -0,0 +1,21 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest + + + + +if __name__ == '__main__': + absltest.main() \ No newline at end of file diff --git a/flax/nnx/tests/graph_utils_test.py b/flax/nnx/tests/graph_utils_test.py index b25ebf829c..6ae95804fb 100644 --- a/flax/nnx/tests/graph_utils_test.py +++ b/flax/nnx/tests/graph_utils_test.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import dataclasses from functools import partial from threading import Thread +from typing import Any import jax import jax.numpy as jnp @@ -42,8 +44,8 @@ def test_flatten(self): a = {'a': 1, 'b': nnx.Param(2)} g = [a, 3, a, nnx.Param(4)] - graphdef, state, refmap = nnx.graph.flatten(g) - assert refmap is not None + refmap = nnx.graph.RefMap() + graphdef, state = nnx.graph.flatten(g, ref_index=refmap) state[0]['b'].raw_value = 2 state[3].raw_value = 4 @@ -298,7 +300,7 @@ def __init__(self): assert 'tree' in state assert 'a' in state.tree - assert graphdef.nodedef.subgraphs['tree'].type is nnx.graph.PytreeType + assert graphdef.subgraphs['tree'].type is nnx.graph.PytreeType m2 = nnx.merge(graphdef, state) @@ -321,14 +323,17 @@ def f(m: Foo): a = m.a b = m.b + ref_out_idx_out = nnx.graph.RefMap() graphdef: nnx.graph.GraphDef[Foo] - graphdef, state, ref_out_idx_out = nnx.graph.flatten(m) + graphdef, state = nnx.graph.flatten(m, ref_index=ref_out_idx_out) @partial(jax.jit, static_argnums=(0,)) def f_pure(graphdef: nnx.graph.GraphDef[Foo], state): - m, idx_out_ref_in = nnx.graph.unflatten(graphdef, state) + idx_out_ref_in: dict[int, Any] = {} + m = nnx.graph.unflatten(graphdef, state, index_ref=idx_out_ref_in) f(m) - graphdef, state, ref_in_idx_in = nnx.graph.flatten(m) + ref_in_idx_in = nnx.graph.RefMap[Any, int]() + graphdef, state = nnx.graph.flatten(m, ref_index=ref_in_idx_in) idx_out_idx_in = nnx.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in) static_out = nnx.graph.Static((graphdef, idx_out_idx_in)) return state, static_out @@ -340,7 +345,7 @@ def f_pure(graphdef: nnx.graph.GraphDef[Foo], state): idx_in_ref_out = nnx.graph.compose_mapping_reversed( ref_out_idx_out, idx_out_idx_in ) - m2, _ = nnx.graph.unflatten(graphdef, state, idxmap=idx_in_ref_out) + m2 = nnx.graph.unflatten(graphdef, state, index_ref_cache=idx_in_ref_out) assert m2 is m assert m2.a is b assert m2.b is a @@ -358,14 +363,17 @@ def f(m: Foo): a = m.a b = m.b + ref_out_idx_out = nnx.graph.RefMap[Any, int]() graphdef: nnx.graph.GraphDef[Foo] - graphdef, state, ref_out_idx_out = nnx.graph.flatten(m) + graphdef, state = nnx.graph.flatten(m, ref_index=ref_out_idx_out) @partial(jax.jit, static_argnums=(0,)) def f_pure(graphdef: nnx.graph.GraphDef[Foo], state): - m, idx_out_ref_in = nnx.graph.unflatten(graphdef, state) + idx_out_ref_in: dict[int, Any] = {} + m = nnx.graph.unflatten(graphdef, state, index_ref=idx_out_ref_in) f(m) - graphdef, state, ref_in_idx_in = nnx.graph.flatten(m) + ref_in_idx_in = nnx.graph.RefMap[Any, int]() + graphdef, state = nnx.graph.flatten(m, ref_index=ref_in_idx_in) idx_out_idx_in = nnx.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in) static_out = nnx.graph.Static((graphdef, idx_out_idx_in)) return state, static_out @@ -377,7 +385,7 @@ def f_pure(graphdef: nnx.graph.GraphDef[Foo], state): idx_in_ref_out = nnx.graph.compose_mapping_reversed( ref_out_idx_out, idx_out_idx_in ) - m2, _ = nnx.graph.unflatten(graphdef, state, idxmap=idx_in_ref_out) + m2 = 
nnx.graph.unflatten(graphdef, state, index_ref_cache=idx_in_ref_out) assert m2 is m assert m2.a is b assert m2.b is a @@ -392,14 +400,17 @@ def f(m: Foo): m = Foo() + ref_out_idx_out = nnx.graph.RefMap() graphdef: nnx.graph.GraphDef[Foo] - graphdef, state, ref_out_idx_out = nnx.graph.flatten(m) + graphdef, state = nnx.graph.flatten(m, ref_index=ref_out_idx_out) @partial(jax.jit, static_argnums=(0,)) def f_pure(graphdef: nnx.graph.GraphDef[Foo], state): - m, idx_out_ref_in = nnx.graph.unflatten(graphdef, state) + idx_out_ref_in: dict[int, Any] = {} + m = nnx.graph.unflatten(graphdef, state, index_ref=idx_out_ref_in) f(m) - graphdef, state, ref_in_idx_in = nnx.graph.flatten(m) + ref_in_idx_in = nnx.graph.RefMap[Any, int]() + graphdef, state = nnx.graph.flatten(m, ref_index=ref_in_idx_in) idx_out_idx_in = nnx.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in) static_out = nnx.graph.Static((graphdef, idx_out_idx_in)) return state, static_out @@ -411,7 +422,7 @@ def f_pure(graphdef: nnx.graph.GraphDef[Foo], state): idx_in_ref_out = nnx.graph.compose_mapping_reversed( ref_out_idx_out, idx_out_idx_in ) - m2, _ = nnx.graph.unflatten(graphdef, state, idxmap=idx_in_ref_out) + m2 = nnx.graph.unflatten(graphdef, state, index_ref_cache=idx_in_ref_out) assert m2 is m assert m2.ref is m2 @@ -470,6 +481,279 @@ def test_getitem(self): self.assertEqual(nodes['a'].count.value, 0) self.assertEqual(nodes['b'].count.value, 1) + def test_split_merge_context(self): + m = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) + with nnx.graph.split_context() as ctx: + graphdef1, state1 = ctx.split(m) + graphdef2, state2 = ctx.split(m) + + self.assertFalse(hasattr(ctx, 'ref_index')) + self.assertFalse(hasattr(ctx, 'ctxtag')) + self.assertIsInstance(graphdef1, nnx.graph.NodeDef) + self.assertIsInstance(graphdef2, nnx.graph.NodeRef) + self.assertLen(state1.flat_state(), 2) + self.assertLen(state2.flat_state(), 0) + + with nnx.graph.merge_context() as ctx: + m1 = ctx.merge(graphdef1, state1) + m2 = ctx.merge(graphdef2, state2) + + self.assertIs(m1, m2) + self.assertFalse(hasattr(ctx, 'index_ref')) + self.assertFalse(hasattr(ctx, 'ctxtag')) + + def test_split_merge_context_nested(self): + m2 = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) + m1 = nnx.Sequential(m2) + with nnx.graph.split_context() as ctx: + graphdef1, state1 = ctx.split(m1) + graphdef2, state2 = ctx.split(m2) + + self.assertIsInstance(graphdef1, nnx.graph.NodeDef) + self.assertIsInstance(graphdef2, nnx.graph.NodeRef) + self.assertLen(state1.flat_state(), 2) + self.assertLen(state2.flat_state(), 0) + + with nnx.graph.merge_context() as ctx: + m1 = ctx.merge(graphdef1, state1) + m2 = ctx.merge(graphdef2, state2) + + self.assertIs(m2, m1.layers[0]) + self.assertFalse(hasattr(ctx, 'index_ref')) + self.assertFalse(hasattr(ctx, 'ctxtag')) + + def test_split_merge_update_context(self): + class Foo(nnx.Module): + def __init__(self): + self.a = nnx.Param(1) + self.b = 2 + + m = Foo() + ctxtag = 'test' + + with nnx.update_context(ctxtag): + with nnx.graph.split_context(ctxtag) as ctx: + graphdef1, state1 = ctx.split(m) + graphdef2, state2 = ctx.split(m) + + self.assertFalse(hasattr(ctx, 'ref_index')) + self.assertFalse(hasattr(ctx, 'ctxtag')) + self.assertIsInstance(graphdef1, nnx.graph.NodeDef) + self.assertIsInstance(graphdef2, nnx.graph.NodeRef) + self.assertLen(state1.flat_state(), 1) + self.assertLen(state2.flat_state(), 0) + + @jax.jit + def f(graphdef1, state1, graphdef2, state2): + with nnx.graph.merge_context(ctxtag) as ctx: + m1 = ctx.merge(graphdef1, state1) + m2 = 
ctx.merge(graphdef2, state2) + + self.assertIs(m1, m2) + self.assertFalse(hasattr(ctx, 'index_ref')) + self.assertFalse(hasattr(ctx, 'ctxtag')) + + # swap a and b + m1.a, m1.b = m1.b, m1.a + + with nnx.graph.split_context(ctxtag) as ctx: + graphdef1, state1 = ctx.split(m1) + graphdef2, state2 = ctx.split(m2) + + return graphdef1, state1, graphdef2, state2 + + graphdef1, state1, graphdef2, state2 = f( + graphdef1, state1, graphdef2, state2 + ) + + with nnx.graph.merge_context(ctxtag) as ctx: + m1_out = ctx.merge(graphdef1, state1) + m2_out = ctx.merge(graphdef2, state2) + + self.assertIs(m, m1_out) + self.assertIs(m, m2_out) + self.assertEqual(m.a, 2) + self.assertEqual(m.b.value, 1) # type: ignore + + self.assertFalse(hasattr(ctx, 'index_ref')) + self.assertFalse(hasattr(ctx, 'ctxtag')) + + def test_to_tree_simple(self): + m = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) + impure_tree = (m, 1, {'b': m}) + + pure_tree = nnx.to_tree(impure_tree) + + t1 = pure_tree[0] + t2 = pure_tree[2]['b'] + + self.assertEqual(pure_tree[1], 1) + self.assertIsInstance(t1, nnx.TreeNode) + assert isinstance(t1, nnx.TreeNode) + self.assertIsInstance(t2, nnx.TreeNode) + assert isinstance(t2, nnx.TreeNode) + self.assertIsInstance(t1.graphdef, nnx.graph.NodeDef) + self.assertIsInstance(t2.graphdef, nnx.graph.NodeRef) + self.assertLen(t1.states[0].flat_state(), 2) + self.assertLen(t2.states[0].flat_state(), 0) + + impure_tree2 = nnx.from_tree(pure_tree) + + m1_out = impure_tree2[0] + m2_out = impure_tree2[2]['b'] + + self.assertIs(m1_out, m2_out) + self.assertEqual(impure_tree2[1], 1) + + def test_to_tree_update_context(self): + class Foo(nnx.Module): + def __init__(self): + self.a = nnx.Param(1) + self.b = 2 + + m = Foo() + impure_tree = (m, 1, {'b': m}) + ctxtag = 'test' + + with nnx.update_context(ctxtag): + pure_tree = nnx.to_tree(impure_tree, ctxtag=ctxtag) + + t1 = pure_tree[0] + t2 = pure_tree[2]['b'] + + self.assertEqual(pure_tree[1], 1) + self.assertIsInstance(t1, nnx.TreeNode) + assert isinstance(t1, nnx.TreeNode) + self.assertIsInstance(t2, nnx.TreeNode) + assert isinstance(t2, nnx.TreeNode) + self.assertIsInstance(t1.graphdef, nnx.graph.NodeDef) + self.assertIsInstance(t2.graphdef, nnx.graph.NodeRef) + self.assertLen(t1.states[0].flat_state(), 1) + self.assertLen(t2.states[0].flat_state(), 0) + + @jax.jit + def f(pure_tree): + impure_tree2 = nnx.from_tree(pure_tree, ctxtag=ctxtag) + m1_out = impure_tree2[0] + m2_out = impure_tree2[2]['b'] + + self.assertIs(m1_out, m2_out) + # self.assertEqual(impure_tree2[1], 1) + + # swap a and b + m1_out.a, m1_out.b = m1_out.b, m1_out.a + + pure_tree2 = nnx.to_tree(impure_tree2, ctxtag=ctxtag) + + t1 = pure_tree2[0] + t2 = pure_tree2[2]['b'] + + # self.assertEqual(pure_tree2[1], 1) + self.assertIsInstance(t1, nnx.TreeNode) + assert isinstance(t1, nnx.TreeNode) + self.assertIsInstance(t2, nnx.TreeNode) + assert isinstance(t2, nnx.TreeNode) + self.assertIsInstance(t1.graphdef, nnx.graph.NodeDef) + self.assertIsInstance(t2.graphdef, nnx.graph.NodeRef) + self.assertLen(t1.states[0].flat_state(), 1) + self.assertLen(t2.states[0].flat_state(), 0) + + return pure_tree2 + + pure_tree2 = f(pure_tree) + + impure_tree2 = nnx.from_tree(pure_tree2, ctxtag=ctxtag) + + m1_out = impure_tree2[0] + m2_out = impure_tree2[2]['b'] + + self.assertIs(m, m1_out) + self.assertIs(m, m2_out) + self.assertEqual(m.a, 2) + self.assertEqual(m.b.value, 1) # type: ignore + self.assertEqual(impure_tree2[1], 1) + + def test_to_tree_consistent_prefix(self): + m = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) + impure_tree 
= (m, 1, {'b': m}) + prefix = (0, None, 0) + pure_tree = nnx.to_tree(impure_tree, prefix=prefix) + + prefix = (0, None, 1) + with pytest.raises(ValueError, match='Inconsistent aliasing detected'): + nnx.to_tree(impure_tree, prefix=prefix) + + def test_simple_vmap(self): + @dataclasses.dataclass(frozen=True) + class StateAxes: + params: Any + batch_stats: Any + + class Foo(nnx.Module): + def __init__(self, a, b): + self.a = nnx.Param(a) + self.b = nnx.BatchStat(b) + + ctxtag = 'test' + with nnx.update_context(ctxtag): + m1 = Foo(a=jnp.array(0), b=jnp.arange(5)) + m2 = Foo(a=jnp.array(1), b=jnp.array(2)) + + args = (m1, m2, {'b': m1}) + m1_axes = StateAxes(None, 0) + in_axes = (m1_axes, None, {'b': m1_axes}) + jax_in_axes = jax.tree.map( + lambda x: nnx.TreeNode.from_prefixes((x.params, x.batch_stats)) + if isinstance(x, StateAxes) + else x, + in_axes, + ) + out_axes = 0 + + def split_fn(ctx: nnx.SplitContext, path, prefix, x): + if isinstance(prefix, StateAxes): + return nnx.TreeNode.from_split( + *ctx.split(x, nnx.Param, nnx.BatchStat) + ) + return nnx.TreeNode.from_split(*ctx.split(x)) + + pure_args = nnx.to_tree( + args, ctxtag=ctxtag, prefix=in_axes, split_fn=split_fn + ) + + @partial(jax.vmap, in_axes=jax_in_axes, out_axes=(jax_in_axes, out_axes)) + def f(*pure_args): + args = nnx.from_tree(pure_args, ctxtag=ctxtag) + + y = 0 + + self.assertIs(args[0], args[2]['b']) + for path, m in nnx.iter_graph(args): + if isinstance(m, Foo): + self.assertEqual(m.a.shape, ()) + self.assertEqual(m.b.shape, ()) + y += m.a + m.b + + args_out = nnx.extract.clear_non_graph_nodes(args) + + pure_args_out, y = nnx.to_tree( + (args_out, y), + prefix=(in_axes, out_axes), + ctxtag=ctxtag, + split_fn=split_fn, + ) + return pure_args_out, y + + pure_args_out, y = f(*pure_args) + + args_out, y = nnx.from_tree((pure_args_out, y), ctxtag=ctxtag) + + self.assertEqual(y.shape, (5,)) + self.assertGreater(y.sum(), 5) + self.assertIs(m1, args_out[0]) + self.assertIs(m1, args_out[2]['b']) + self.assertIs(m2, args_out[1]) + class SimpleModule(nnx.Module): pass diff --git a/flax/nnx/tests/helpers_test.py b/flax/nnx/tests/helpers_test.py index e97b5c0828..7d140d3b6f 100644 --- a/flax/nnx/tests/helpers_test.py +++ b/flax/nnx/tests/helpers_test.py @@ -69,7 +69,7 @@ def __call__(self, x: jax.Array, train: bool) -> jax.Array: assert y.shape == (1, 4) # fake gradient - grads = jax.tree_util.tree_map(jnp.ones_like, state.params) + grads = jax.tree.map(jnp.ones_like, state.params) # test apply_gradients state = state.apply_gradients(grads) diff --git a/flax/nnx/tests/integration_test.py b/flax/nnx/tests/integration_test.py index e33eebe7ed..1742e379cb 100644 --- a/flax/nnx/tests/integration_test.py +++ b/flax/nnx/tests/integration_test.py @@ -57,7 +57,7 @@ def loss_fn(model: Model): grads = loss_fn(model) nnx.update( model, - jax.tree_util.tree_map( + jax.tree.map( lambda w, g: w - 0.1 * g, nnx.state(model, nnx.Param), grads ), ) @@ -110,7 +110,7 @@ def loss_fn(model: Model): grads = loss_fn(model) nnx.update( model, - jax.tree_util.tree_map( + jax.tree.map( lambda w, g: w - 0.1 * g, nnx.state(model, nnx.Param), grads ), ) @@ -162,11 +162,11 @@ def loss_fn(model): return jax.numpy.mean((y_pred - y) ** 2) # compute gradient - grads: nnx.State = nnx.grad(loss_fn, wrt=nnx.Param)(model) + grads: nnx.State = nnx.grad(loss_fn)(model) # SGD update nnx.update( model, - jax.tree_util.tree_map( + jax.tree.map( lambda w, g: w - 0.1 * g, nnx.state(model, nnx.Param), grads ), ) @@ -208,7 +208,7 @@ def loss_fn(params): # compute gradient grads, 
counts = jax.grad(loss_fn, has_aux=True)(params) # SGD update - params = jax.tree_util.tree_map(lambda w, g: w - 0.1 * g, params, grads) + params = jax.tree.map(lambda w, g: w - 0.1 * g, params, grads) return params, counts diff --git a/flax/nnx/tests/module_test.py b/flax/nnx/tests/module_test.py index f627d32337..6d297bdafb 100644 --- a/flax/nnx/tests/module_test.py +++ b/flax/nnx/tests/module_test.py @@ -52,14 +52,14 @@ def test_tree_map(self): graphdef, state = nnx.split(m) - state = jax.tree_util.tree_map(lambda x: x + 1, state) + state = jax.tree.map(lambda x: x + 1, state) def test_split_2(self): m = nnx.Dict(a=nnx.Param(1)) graphdef, empty, some = nnx.split(m, None, ...) - some = jax.tree_util.tree_map(lambda x: x + 1, some) + some = jax.tree.map(lambda x: x + 1, some) def test_split_merge(self): m = nnx.Dict(a=nnx.Param(1)) @@ -484,7 +484,7 @@ def __init__(self): m = Foo() - m = jax.tree_util.tree_map(lambda x: x + 1, m) + m = jax.tree.map(lambda x: x + 1, m) assert m.node.value == 2 assert m.graphdef == 1 diff --git a/flax/nnx/tests/optimizer_test.py b/flax/nnx/tests/optimizer_test.py index d2bcaf609f..1d11254114 100644 --- a/flax/nnx/tests/optimizer_test.py +++ b/flax/nnx/tests/optimizer_test.py @@ -84,7 +84,7 @@ def jax_jit_train_step(graphdef, state, x, y): initial_loss = loss_fn(state.model, x, y) def nnx_jit_train_step(optimizer: nnx.Optimizer, x, y): - grads = nnx.grad(loss_fn, wrt=nnx.Param)(optimizer.model, x, y) + grads = nnx.grad(loss_fn)(optimizer.model, x, y) optimizer.update(grads) jit_decorator(nnx_jit_train_step)(state, x, y) @@ -114,7 +114,7 @@ def update(self, *, grads, **updates): # type: ignore[signature-mismatch] state = TrainState(model, tx, metrics) loss_fn = lambda model: ((model(x) - y) ** 2).mean() - grads = nnx.grad(loss_fn, wrt=nnx.Param)(state.model) + grads = nnx.grad(loss_fn)(state.model) state.update(grads=grads, values=loss_fn(state.model)) initial_loss = state.metrics.compute() state.update(grads=grads, values=loss_fn(state.model)) @@ -144,7 +144,9 @@ def test_wrt_update(self, variable): y = jnp.ones((1, 10)) loss_fn = lambda model, x, y: ((model(x) - y) ** 2).mean() - grads = nnx.grad(loss_fn, wrt=variable)(state.model, x, y) + grads = nnx.grad(loss_fn, argnums=nnx.DiffState(0, variable))( + state.model, x, y + ) initial_loss = loss_fn(model, x, y) state.update(grads=grads) self.assertTrue(loss_fn(model, x, y) < initial_loss) diff --git a/flax/nnx/tests/partitioning_test.py b/flax/nnx/tests/partitioning_test.py index e390887aed..1661342bc7 100644 --- a/flax/nnx/tests/partitioning_test.py +++ b/flax/nnx/tests/partitioning_test.py @@ -96,7 +96,7 @@ def test_update_from(self): state = nnx.split( m, )[1] - state = jax.tree_util.tree_map(lambda x: x * 2, state) + state = jax.tree.map(lambda x: x * 2, state) nnx.update(m, state) @@ -115,7 +115,7 @@ def test_update_from_with_array_leaf(self): graphdef, state = nnx.split( m, ) - state = jax.tree_util.tree_map(lambda x: x * 2, state) + state = jax.tree.map(lambda x: x * 2, state) nnx.update(m, state) diff --git a/flax/nnx/tests/rngs_test.py b/flax/nnx/tests/rngs_test.py index 400c59a0d7..5fe4025f4d 100644 --- a/flax/nnx/tests/rngs_test.py +++ b/flax/nnx/tests/rngs_test.py @@ -138,13 +138,11 @@ def __call__(self, x): assert len(rng_counts.flat_state()) == 2 # split dropout keys - split_dropout_keys = jax.tree_util.tree_map( + split_dropout_keys = jax.tree.map( lambda x: jax.random.split(x, 4), dropout_keys ) # replicate params - params = jax.tree_util.tree_map( - lambda x: jnp.stack([x] * 4, axis=0), 
params - ) + params = jax.tree.map(lambda x: jnp.stack([x] * 4, axis=0), params) @partial( jax.vmap, diff --git a/flax/nnx/tests/test_traversals.py b/flax/nnx/tests/test_traversals.py index 40c08e56e9..5f2641933e 100644 --- a/flax/nnx/tests/test_traversals.py +++ b/flax/nnx/tests/test_traversals.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for flax.experimental.nnx.traversal.""" +"""Tests for flax.nnx.traversal.""" from absl.testing import absltest from flax.core import freeze from flax.nnx import traversals diff --git a/flax/nnx/tests/transforms_test.py b/flax/nnx/tests/transforms_test.py index 7dcc784995..19f5e0c599 100644 --- a/flax/nnx/tests/transforms_test.py +++ b/flax/nnx/tests/transforms_test.py @@ -327,24 +327,29 @@ def test_apply_shardings(self): devices = mesh_utils.create_device_mesh((n_devices, n_devices)) mesh = jax.sharding.Mesh(devices, ('a', 'b')) - rngs = nnx.Rngs(0) - m = nnx.Linear( - 16, - 32, - rngs=rngs, - kernel_init=nnx.with_partitioning( - nnx.initializers.lecun_normal(), ('a', 'b') - ), + def sharding(*args): + return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(*args)) + + state_sharding = nnx.StateSharding( + { + nnx.PathContains('kernel'): sharding('a', 'b'), + nnx.PathContains('bias'): sharding('b'), + } + ) + + m = nnx.Linear(16, 32, rngs=nnx.Rngs(0)) + + self.assertNotIsInstance( + m.kernel.value.sharding, jax.sharding.NamedSharding ) - @partial(nnx.jit, constrain_state=True) + @nnx.jit(in_shardings=(state_sharding,)) def constrain_object(m): pass - with mesh: - constrain_object(m) + constrain_object(m) - m.kernel.value.sharding + self.assertIsInstance(m.kernel.value.sharding, jax.sharding.NamedSharding) class TestGrad(parameterized.TestCase): @@ -419,7 +424,7 @@ def test_grad_with_type_predicate(self): d=5.0, ) - @partial(nnx.grad, wrt=nnx.BatchStat) + @nnx.grad(argnums=nnx.DiffState(0, nnx.BatchStat)) def f(m: nnx.Dict): # sum all params return m.a[0].value + m.a[1].value + m.b.value @@ -443,7 +448,7 @@ def test_multiple_inputs(self): rngs = nnx.Rngs(0) m = nnx.Linear(2, 3, rngs=rngs) loss_fn = lambda m, x, y: jnp.mean((m(x) - y) ** 2) - grad_fn = nnx.grad(loss_fn, wrt=nnx.Param) + grad_fn = nnx.grad(loss_fn) x = jax.random.uniform(rngs(), (1, 2)) y = jnp.ones((1, 3)) grads = grad_fn(m, x, y) @@ -467,7 +472,7 @@ def test_multiple_graph_nodes(self, loss_fn, argnums): rngs = nnx.Rngs(0) m1 = nnx.Linear(2, 3, rngs=rngs) m2 = nnx.Linear(3, 3, rngs=rngs) - grad_fn = nnx.grad(loss_fn, argnums=argnums, wrt=nnx.Param) + grad_fn = nnx.grad(loss_fn, argnums=argnums) x = jax.random.uniform(rngs(), (1, 2)) y = jnp.ones((1, 3)) inputs = [x, y] @@ -484,6 +489,285 @@ def test_multiple_graph_nodes(self, loss_fn, argnums): assert 'bias' in grads_m2 assert grads_m2.bias.value.shape == (3,) + def test_multiple_args(self): + m1 = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) + m2 = nnx.Linear(2, 3, rngs=nnx.Rngs(1)) + + m1_diffstate = nnx.DiffState(0, nnx.PathContains('kernel')) + m2_diffstate = nnx.DiffState(1, nnx.PathContains('bias')) + + @nnx.grad(argnums=(m1_diffstate, m2_diffstate)) + def loss_fn(m1: nnx.Linear, m2: nnx.Linear): + return jnp.mean(m1.kernel * m2.kernel) + jnp.mean(m1.bias * m2.bias) + + grads_m1, grads_m2 = loss_fn(m1, m2) + + self.assertIn('kernel', grads_m1) + self.assertNotIn('bias', grads_m1) + self.assertNotIn('kernel', grads_m2) + self.assertIn('bias', grads_m2) + + def test_multiple_args_in_pytrees(self): + m1 = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) + m2 = 
nnx.Linear(2, 3, rngs=nnx.Rngs(1)) + + m1_diffstate = nnx.DiffState(0, nnx.PathContains('kernel')) + m2_diffstate = nnx.DiffState(1, nnx.PathContains('bias')) + + @nnx.grad(argnums=(m1_diffstate, m2_diffstate)) + def loss_fn(l1: list[nnx.Linear], l2: list[nnx.Linear]): + return jnp.mean(l1[0].kernel * l2[0].kernel) + jnp.mean( + l1[0].bias * l2[0].bias + ) + + grads_m1, grads_m2 = loss_fn([m1], [m2]) + + self.assertIn('kernel', grads_m1[0]) + self.assertNotIn('bias', grads_m1[0]) + self.assertNotIn('kernel', grads_m2[0]) + self.assertIn('bias', grads_m2[0]) + + def test_value_and_grad_multiple_args_in_pytrees(self): + m1 = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) + m2 = nnx.Linear(2, 3, rngs=nnx.Rngs(1)) + + m1_diffstate = nnx.DiffState(0, nnx.PathContains('kernel')) + m2_diffstate = nnx.DiffState(1, nnx.PathContains('bias')) + + @nnx.value_and_grad(argnums=(m1_diffstate, m2_diffstate)) + def loss_fn(l1: list[nnx.Linear], l2: list[nnx.Linear]): + return jnp.mean(l1[0].kernel * l2[0].kernel) + jnp.mean( + l1[0].bias * l2[0].bias + ) + + loss, (grads_m1, grads_m2) = loss_fn([m1], [m2]) + + self.assertEqual(loss.shape, ()) + self.assertIn('kernel', grads_m1[0]) + self.assertNotIn('bias', grads_m1[0]) + self.assertNotIn('kernel', grads_m2[0]) + self.assertIn('bias', grads_m2[0]) + + def test_value_and_grad_with_aux(self): + m1 = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) + m2 = nnx.Linear(2, 3, rngs=nnx.Rngs(1)) + + m1_diffstate = nnx.DiffState(0, nnx.PathContains('kernel')) + m2_diffstate = nnx.DiffState(1, nnx.PathContains('bias')) + + @nnx.value_and_grad(argnums=(m1_diffstate, m2_diffstate), has_aux=True) + def loss_fn(l1: list[nnx.Linear], l2: list[nnx.Linear]): + loss = jnp.mean(l1[0].kernel * l2[0].kernel) + jnp.mean( + l1[0].bias * l2[0].bias + ) + l1[0].kernel.value = jnp.array(-1.0) + m3 = nnx.Linear(2, 3, rngs=nnx.Rngs(2)) + return loss, m3 + + (loss, m3), (grads_m1, grads_m2) = loss_fn([m1], [m2]) + + self.assertEqual(m1.kernel.value, -1.0) + self.assertEqual(loss.shape, ()) + self.assertIsInstance(m3, nnx.Linear) + self.assertIn('kernel', grads_m1[0]) + self.assertNotIn('bias', grads_m1[0]) + self.assertNotIn('kernel', grads_m2[0]) + self.assertIn('bias', grads_m2[0]) + +class TestCustomVJP(absltest.TestCase): + def test_basic_call(self): + m1 = nnx.Linear(1, 1, rngs=nnx.Rngs(0)) + m2 = nnx.Linear(1, 1, rngs=nnx.Rngs(1)) + + @nnx.custom_vjp + def f(m1: nnx.Linear, m2: nnx.Linear): + y = m1.kernel * m2.kernel + m1.kernel.value = jnp.array(-1.0) + return y + + def f_fwd(m1, m2): + y = f(m1, m2) + return y, (m1, m2) + + def f_bwd(res, g): + inputs_g, out_g = g + m1, m2 = res + return inputs_g + + f.defvjp(f_fwd, f_bwd) + + y = f(m1, m2) + + self.assertEqual(m1.kernel.value, -1.0) + self.assertEqual(y.shape, (1, 1)) + + def test_jax_example(self): + @dataclasses.dataclass + class Foo(nnx.Module): + x: jax.Array + y: jax.Array + z: int + + @nnx.custom_vjp + def f(m: Foo): + m.z += 1 + return jnp.sin(m.x) * m.y + + def f_fwd(m: Foo): + y = f(m) + res = (jnp.cos(m.x), jnp.sin(m.x), m) + return y, res + + def f_bwd(res, g): + inputs_g, out_g = g + cos_x, sin_x, m = res + + self.assertIsInstance(inputs_g, tuple) + self.assertLen(inputs_g, 1) + self.assertIsInstance(inputs_g[0], nnx.State) + self.assertEqual(out_g.shape, ()) + self.assertIsInstance(m, Foo) + + m_g = nnx.State({'x': cos_x * out_g * m.y, 'y': sin_x * out_g}) + return (m_g,) + + f.defvjp(f_fwd, f_bwd) + + m = Foo(jnp.array(1.0), jnp.array(2.0), 0) + + grad: nnx.State = nnx.grad(f, argnums=nnx.DiffState(0, ...))(m) + + 
np.testing.assert_allclose(grad['x'], jnp.cos(1.0) * 2.0) # type: ignore + np.testing.assert_allclose(grad['y'], jnp.sin(1.0)) # type: ignore + self.assertEqual(m.z, 1) + + def test_two_args(self): + @dataclasses.dataclass + class Foo(nnx.Module): + x: jax.Array + y: jax.Array + z: int + + @nnx.custom_vjp + def f(m1: Foo, m2: Foo): + m1.z += 1 + y = jnp.sin(m1.x) * m1.y + return y, m2 + + def f_fwd(m1: Foo, m2: Foo): + y, m2 = f(m1, m2) + res = (jnp.cos(m1.x), jnp.sin(m1.x), m1) + return (y, m2), res + + def f_bwd(res, g): + (m1_g, m2_g), (y_g, _) = g + cos_x, sin_x, m = res + + self.assertIsInstance(m1_g, nnx.State) + self.assertIsInstance(m2_g, nnx.State) + self.assertEqual(y_g.shape, ()) + self.assertIsInstance(m, Foo) + + m1_g = nnx.State(dict(x=cos_x * y_g * m.y, y=sin_x * y_g)) + m2_g = nnx.State(dict(x=m2_g['x'], y=m2_g['y'])) + + return m1_g, m2_g + + f.defvjp(f_fwd, f_bwd) + + m1 = Foo(jnp.array(1.0), jnp.array(2.0), 0) + m2 = Foo(jnp.array(3.0), jnp.array(4.0), 0) + + def loss_fn(m1, m2): + y, m2 = f(m1, m2) + return y + m2.x * m2.y + + m1_grad: nnx.State + m2_grad: nnx.State + m1_grad, m2_grad = nnx.grad( + loss_fn, argnums=(nnx.DiffState(0, ...), nnx.DiffState(1, ...)) + )(m1, m2) + + np.testing.assert_allclose(m1_grad['x'], jnp.cos(1.0) * 2.0) # type: ignore + np.testing.assert_allclose(m1_grad['y'], jnp.sin(1.0)) # type: ignore + self.assertEqual(m1.z, 1) + np.testing.assert_allclose(m2_grad['x'], 4.0) # type: ignore + np.testing.assert_allclose(m2_grad['y'], 3.0) # type: ignore + + def test_non_diff_args(self): + @dataclasses.dataclass + class Foo(nnx.Module): + x: jax.Array + y: jax.Array + z: int + + @nnx.custom_vjp(nondiff_argnums=(1, 2)) + def f(m1: Foo, m2: Foo, m3): + m1.z += 1 + y = jnp.sin(m1.x) * m1.y + return y, m2 + + def f_fwd(m1: Foo, m2: Foo, m3): + y, m2 = f(m1, m2, m3) + res = (jnp.cos(m1.x), jnp.sin(m1.x), m1) + return (y, m2), res + + def f_bwd(m2, m3, res, g): + (m1_g, m2_g, m3_g), (y_g, _) = g + cos_x, sin_x, m = res + + self.assertIsInstance(m1_g, nnx.State) + self.assertIsInstance(m2_g, nnx.State) + self.assertEqual(y_g.shape, ()) + self.assertIsInstance(m, Foo) + + m1_g = nnx.State(dict(x=cos_x * y_g * m.y, y=sin_x * y_g)) + + return (m1_g,) + + f.defvjp(f_fwd, f_bwd) + + m1 = Foo(jnp.array(1.0), jnp.array(2.0), 0) + m2 = Foo(jnp.array(3.0), jnp.array(4.0), 0) + + def loss_fn(m1, m2, m3): + y, m2 = f(m1, m2, m3) + return y + m2.x * m2.y + + m1_grad: nnx.State + m1_grad = nnx.grad(loss_fn, argnums=nnx.DiffState(0, ...))(m1, m2, m2) + + np.testing.assert_allclose(m1_grad['x'], jnp.cos(1.0) * 2.0) # type: ignore + np.testing.assert_allclose(m1_grad['y'], jnp.sin(1.0)) # type: ignore + self.assertEqual(m1.z, 1) + + def test_docs_example(self): + import jax.numpy as jnp + from flax import nnx + + class Foo(nnx.Module): + def __init__(self, x, y): + self.x = nnx.Param(x) + self.y = nnx.Param(y) + + @nnx.custom_vjp + def f(m: Foo): + return jnp.sin(m.x) * m.y + + def f_fwd(m: Foo): + return f(m), (jnp.cos(m.x), jnp.sin(m.x), m) + + def f_bwd(res, g): + ins_g, out_g = g + cos_x, sin_x, m = res + tangent_m = nnx.State(dict(x=cos_x * out_g * m.y, y=sin_x * out_g)) + return (tangent_m,) + + f.defvjp(f_fwd, f_bwd) + + m = Foo(x=jnp.array(1.0), y=jnp.array(2.0)) + grads = nnx.grad(f)(m) + class TestScan(absltest.TestCase): def test_basic(self): @@ -497,7 +781,8 @@ def __call__(self, x: jax.Array): x = nnx.gelu(x) return x - @partial(nnx.scan, state_axes={nnx.Param: 0}, length=5) + @nnx.split_rngs(splits=5) + @nnx.scan(in_axes=(nnx.Carry, 0), length=5) def 
create_block(_, rngs: nnx.Rngs): return None, Block(rngs=rngs) @@ -507,7 +792,7 @@ def create_block(_, rngs: nnx.Rngs): assert module.linear.bias.value.shape == (5, 3) # assert module.node.value.shape == (2,) - @partial(nnx.scan, in_axes=None, state_axes={nnx.Param: 0}, length=5) + @nnx.scan(in_axes=(nnx.Carry, 0, None), length=5) def forward_block(_, block: Block, x: jax.Array): return None, block(x) @@ -576,28 +861,26 @@ def __call__(self, x: jax.Array): assert y.shape == (1, 3) def test_out_axes(self): - class Block(nnx.Module): - def __init__(self, *, rngs: nnx.Rngs): + state_axes = nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}) + + class MLP(nnx.Module): + @nnx.split_rngs(splits=5) + @nnx.vmap(in_axes=(state_axes, state_axes), axis_size=5) + def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(3, 3, rngs=rngs) - self.node = nnx.Variable(jnp.ones((2,))) + self.node = nnx.BatchStat(jnp.ones((2,))) + @nnx.scan(in_axes=(state_axes, nnx.Carry), out_axes=(1, 2)) def __call__(self, x: jax.Array): x = self.linear(x) x = nnx.gelu(x) return x, (x, x) - MLP = nnx.Scan.constructor( - Block, - state_axes={nnx.Param: 0}, - length=5, - out_axes=(1, 2), - ) - module = MLP(rngs=nnx.Rngs(0)) - assert module.scan_module.linear.kernel.value.shape == (5, 3, 3) - assert module.scan_module.linear.bias.value.shape == (5, 3) - assert module.scan_module.node.value.shape == (2,) + assert module.linear.kernel.value.shape == (5, 3, 3) + assert module.linear.bias.value.shape == (5, 3) + assert module.node.value.shape == (2,) x = jnp.ones((1, 3)) c, (y1, y2) = module(x) @@ -607,6 +890,39 @@ def __call__(self, x: jax.Array): assert y2.shape == (1, 3, 5) def test_in_axes(self): + state_axes = nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}) + + class MLP(nnx.Module): + @nnx.split_rngs(splits=5) + @nnx.vmap(in_axes=(state_axes, state_axes)) + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(3, 3, rngs=rngs) + self.node = nnx.Variable(jnp.ones((2,))) + + @nnx.scan(in_axes=(state_axes, nnx.Carry, 0)) + def __call__( + self, x: jax.Array, a: jax.Array + ) -> tp.Tuple[jax.Array, None]: + assert x.shape == a.shape + x = x + a + x = self.linear(x) + x = nnx.gelu(x) + return x, None + + module = MLP(rngs=nnx.Rngs(0)) + + assert module.linear.kernel.value.shape == (5, 3, 3) + assert module.linear.bias.value.shape == (5, 3) + assert module.node.value.shape == (2,) + + x = jnp.ones((1, 3)) + a = jnp.ones((5, 1, 3)) + y, out = module(x, a) + + assert y.shape == (1, 3) + assert out is None + + def test_in_axes_combinator(self): class Block(nnx.Module): def __init__(self, *, rngs: nnx.Rngs): self.linear = nnx.Linear(3, 3, rngs=rngs) @@ -641,6 +957,42 @@ def __call__( assert out is None def test_in_axes_broadcast(self): + test = self + state_axes = nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}) + + class MLP(nnx.Module): + @nnx.split_rngs(splits=5) + @nnx.vmap(in_axes=(state_axes, state_axes)) + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(3, 3, rngs=rngs) + self.node = nnx.BatchStat(jnp.ones((2,))) + + @nnx.scan(in_axes=(state_axes, nnx.Carry, 0, None)) + def __call__( + self, x: jax.Array, a: jax.Array, b: jax.Array + ) -> tp.Tuple[jax.Array, None]: + test.assertEqual(x.shape, a.shape) + test.assertEqual(x.shape, b.shape) + x = x + a + b + x = self.linear(x) + x = nnx.gelu(x) + return x, None + + module = MLP(rngs=nnx.Rngs(0)) + + self.assertEqual(module.linear.kernel.value.shape, (5, 3, 3)) + self.assertEqual(module.linear.bias.value.shape, (5, 3)) + 
self.assertEqual(module.node.value.shape, (2,)) + + x = jnp.ones((1, 3)) + a = jnp.ones((5, 1, 3)) + b = jnp.ones((1, 3)) + y, out = module(x, a, b) + + self.assertEqual(y.shape, (1, 3)) + self.assertIsNone(out) + + def test_in_axes_broadcast_combinator(self): class Block(nnx.Module): def __init__(self, *, rngs: nnx.Rngs): self.linear = nnx.Linear(3, 3, rngs=rngs) @@ -649,8 +1001,7 @@ def __init__(self, *, rngs: nnx.Rngs): def __call__( self, x: jax.Array, a: jax.Array, b: jax.Array ) -> tp.Tuple[jax.Array, None]: - assert x.shape == a.shape - assert x.shape == b.shape + assert x.shape == a.shape and x.shape == b.shape x = x + a + b x = self.linear(x) x = nnx.gelu(x) @@ -678,6 +1029,40 @@ def __call__( assert out is None def test_complex(self): + state_axes = nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}) + + class MLP(nnx.Module): + @nnx.split_rngs(splits=5) + @nnx.vmap(in_axes=(state_axes, state_axes)) + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(3, 3, rngs=rngs) + self.bn = nnx.BatchNorm(3, rngs=rngs) + self.dropout = nnx.Dropout(0.5, rngs=rngs) + self.node = nnx.Variable(jnp.ones((2,))) + + @nnx.split_rngs(splits=5) + @nnx.scan(in_axes=(state_axes, nnx.Carry)) + def __call__(self, x: jax.Array): + x = self.linear(x) + x = self.bn(x) + x = self.dropout(x) + x = nnx.gelu(x) + return x, None + + + module = MLP(rngs=nnx.Rngs(0)) + module.set_attributes(deterministic=False, use_running_average=False) + + assert module.linear.kernel.value.shape == (5, 3, 3) + assert module.linear.bias.value.shape == (5, 3) + assert module.node.value.shape == (2,) + + x = jnp.ones((1, 3)) + y, _ = module(x) + + assert y.shape == (1, 3) + + def test_complex_combinator(self): class Block(nnx.Module): def __init__(self, *, rngs: nnx.Rngs): self.linear = nnx.Linear(3, 3, rngs=rngs) @@ -709,17 +1094,50 @@ def __call__(self, x: jax.Array, *, rngs: nnx.Rngs) -> jax.Array: assert y.shape == (1, 3) def test_complex_broadcast_dropout(self): + state_axes = nnx.StateAxes({(nnx.Param, 'params'): 0, ...: None}) + + class MLP(nnx.Module): + @nnx.split_rngs(splits=5, only='params') + @nnx.vmap(in_axes=(state_axes, state_axes)) + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(3, 3, rngs=rngs) + self.bn = nnx.BatchNorm(3, rngs=rngs) + self.dropout = nnx.Dropout(0.5, rngs=rngs) + self.node = nnx.Variable(jnp.ones((2,))) + + @nnx.split_rngs(splits=5, only='params') + @nnx.scan(in_axes=(state_axes, nnx.Carry)) + def __call__(self, x: jax.Array): + x = self.linear(x) + x = self.bn(x) + x = self.dropout(x) + x = nnx.gelu(x) + return x, None + + module = MLP(rngs=nnx.Rngs(params=0, dropout=1)) + module.set_attributes(deterministic=False, use_running_average=False) + + assert module.linear.kernel.value.shape == (5, 3, 3) + assert module.linear.bias.value.shape == (5, 3) + assert module.node.value.shape == (2,) + + x = jnp.ones((1, 3)) + y, _ = module(x) + + assert y.shape == (1, 3) + + def test_complex_broadcast_dropout_combinator(self): class Block(nnx.Module): - def __init__(self, *, rngs: nnx.Rngs): + def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(3, 3, rngs=rngs) self.bn = nnx.BatchNorm(3, rngs=rngs) - self.dropout = nnx.Dropout(0.5) + self.dropout = nnx.Dropout(0.5, rngs=rngs) self.node = nnx.Variable(jnp.ones((2,))) - def __call__(self, x: jax.Array, *, rngs: nnx.Rngs) -> jax.Array: + def __call__(self, x: jax.Array) -> jax.Array: x = self.linear(x) x = self.bn(x) - x = self.dropout(x, rngs=rngs) + x = self.dropout(x) x = nnx.gelu(x) return x @@ -728,11 +1146,11 @@ def 
__call__(self, x: jax.Array, *, rngs: nnx.Rngs) -> jax.Array: state_axes={nnx.Param: 0}, length=5, # params is split, dropout is broadcast - split_rngs=['dropout'], + split_rngs=['params'], scan_output=False, ) - module = MLP(rngs=nnx.Rngs(0)) + module = MLP(nnx.Rngs(params=0, dropout=1)) module.set_attributes(deterministic=False, use_running_average=False) assert module.scan_module.linear.kernel.value.shape == (5, 3, 3) @@ -740,36 +1158,29 @@ def __call__(self, x: jax.Array, *, rngs: nnx.Rngs) -> jax.Array: assert module.scan_module.node.value.shape == (2,) x = jnp.ones((1, 3)) - y = module(x, rngs=nnx.Rngs(1)) + y = module(x) assert y.shape == (1, 3) def test_complex_decorator(self): + state_axes = nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}) + class Block(nnx.Module): - @partial( - nnx.vmap, - state_axes={nnx.Param: 0}, - axis_size=5, - ) - def __init__(self, *, rngs: nnx.Rngs): + @nnx.split_rngs(splits=5) + @nnx.vmap(in_axes=(state_axes, state_axes), axis_size=5) + def __init__(self, rngs: nnx.Rngs): self.d = 3 self.linear = nnx.Linear(3, 3, rngs=rngs) self.bn = nnx.BatchNorm(3, rngs=rngs) - self.dropout = nnx.Dropout(0.5) + self.dropout = nnx.Dropout(0.5, rngs=rngs) self.node = nnx.Variable(jnp.ones((2,))) - @partial( - nnx.scan, - state_axes={nnx.Param: 0}, - length=5, - carry_argnum=1, - ) - def __call__( - self, x: jax.Array, _, *, rngs: nnx.Rngs - ) -> tp.Tuple[jax.Array, None]: + @nnx.split_rngs(splits=5) + @nnx.scan(in_axes=(state_axes, nnx.Carry)) + def __call__(self, x: jax.Array): x = self.linear(x) x = self.bn(x) - x = self.dropout(x, rngs=rngs) + x = self.dropout(x) x = nnx.gelu(x) return x, None @@ -782,12 +1193,64 @@ def __call__( assert module.node.value.shape == (2,) x = jnp.ones((1, 3)) - y, out = module(x, None, rngs=nnx.Rngs(dropout=1)) + y, out = module(x) assert y.shape == (1, 3) assert out is None def test_scan_with_sharding(self): + test = self + state_axes = nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}) + transform_metadata = {nnx.PARTITION_NAME: 'layers'} + + class MLP(nnx.Module): + @nnx.split_rngs(splits=5) + @nnx.vmap( + in_axes=(state_axes, state_axes), transform_metadata=transform_metadata + ) + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear( + 3, + 3, + kernel_init=nnx.with_metadata( + nnx.initializers.lecun_normal(), sharding=('din', 'dout') + ), + bias_init=nnx.with_metadata( + nnx.initializers.zeros_init(), sharding=('dout',) + ), + rngs=rngs, + ) + + @nnx.scan( + in_axes=(state_axes, nnx.Carry), transform_metadata=transform_metadata + ) + def __call__(self, x: jax.Array): + x = self.linear(x) + # test sharding layer axes is not present inside scan + test.assertEqual(self.linear.kernel.shape, (3, 3)) + test.assertEqual(self.linear.kernel.sharding, ('din', 'dout')) + test.assertEqual(self.linear.bias.shape, (3,)) + test.assertEqual(self.linear.bias.sharding, ('dout',)) + return x, None + + m = MLP(rngs=nnx.Rngs(0)) + + # test sharding layers axes is set + self.assertEqual(m.linear.kernel.shape, (5, 3, 3)) + self.assertEqual(m.linear.kernel.sharding, ('layers', 'din', 'dout')) + self.assertEqual(m.linear.bias.shape, (5, 3)) + self.assertEqual(m.linear.bias.sharding, ('layers', 'dout')) + + x = jnp.ones((1, 3)) + y, out = m(x) + + # test sharding axes is preserved + self.assertEqual(m.linear.kernel.shape, (5, 3, 3)) + self.assertEqual(m.linear.kernel.sharding, ('layers', 'din', 'dout')) + self.assertEqual(m.linear.bias.shape, (5, 3)) + self.assertEqual(m.linear.bias.sharding, ('layers', 'dout')) + + def 
test_scan_with_sharding_decorator(self): class Block(nnx.Module): def __init__(self, *, rngs: nnx.Rngs): self.linear = nnx.Linear( @@ -827,21 +1290,10 @@ def __call__(self, x: jax.Array, _) -> tp.Tuple[jax.Array, None]: # test sharding layers axes is set state = nnx.state(m) - assert state.scan_module.linear.kernel.value.shape == ( - 5, - 3, - 3, - ) - assert state.scan_module.linear.kernel.sharding == ( - 'layers', - 'din', - 'dout', - ) + assert state.scan_module.linear.kernel.value.shape == (5, 3, 3) + assert state.scan_module.linear.kernel.sharding == ('layers', 'din', 'dout') assert state.scan_module.linear.bias.value.shape == (5, 3) - assert state.scan_module.linear.bias.sharding == ( - 'layers', - 'dout', - ) + assert state.scan_module.linear.bias.sharding == ('layers', 'dout') x = jnp.ones((1, 3)) y, out = m(x, None) @@ -849,16 +1301,9 @@ def __call__(self, x: jax.Array, _) -> tp.Tuple[jax.Array, None]: # test sharding axes is preserved state = nnx.state(m) assert state.scan_module.linear.kernel.value.shape == (5, 3, 3) - assert state.scan_module.linear.kernel.sharding == ( - 'layers', - 'din', - 'dout', - ) + assert state.scan_module.linear.kernel.sharding == ('layers', 'din', 'dout') assert state.scan_module.linear.bias.value.shape == (5, 3) - assert state.scan_module.linear.bias.sharding == ( - 'layers', - 'dout', - ) + assert state.scan_module.linear.bias.sharding == ('layers', 'dout') def test_type_error_less_than_one_args(self): class Block(nnx.Module): @@ -908,14 +1353,15 @@ def test_cache_tracing_object(self): @dataclasses.dataclass class Foo(nnx.Object): - @partial(nnx.vmap, axis_size=5) - def __init__(self, *, rngs: nnx.Rngs): + @nnx.split_rngs(splits=5) + @nnx.vmap(axis_size=5) + def __init__(self, rngs: nnx.Rngs): self.x = nnx.Param(jax.random.normal(rngs(), shape=(3,))) foo = Foo(rngs=nnx.Rngs(0)) assert foo.x.value.shape == (5, 3) - @nnx.scan + @nnx.scan(in_axes=(nnx.Carry, 0, 0)) def f(count, x, foo): nonlocal n n += 1 @@ -932,9 +1378,11 @@ def f(count, x, foo): assert count == 10 def test_scan_broadcast_keys(self): - rngs = nnx.Rngs(params=0, dropout=1) + params_key = jax.random.split(jax.random.key(0), 3) + rngs = nnx.Rngs(params=params_key, dropout=1) + state_axes = nnx.StateAxes({'params': 0, ...: None}) - @partial(nnx.scan, split_rngs='params', length=3) + @nnx.scan(in_axes=(nnx.Carry, state_axes), length=3) def f(_, rngs: nnx.Rngs): param_key = rngs.params() dropout_key = rngs.dropout() @@ -960,7 +1408,7 @@ def test_basic_remat(self): def test_remat_decorator(self): class RematLinear(nnx.Module): - @partial(nnx.remat, static_argnums=(1, 2)) + @nnx.remat(static_argnums=(1, 2)) def __init__(self, din: int, dout: int, rngs: nnx.Rngs): self.linear = nnx.Linear(din, dout, rngs=rngs) @@ -1003,38 +1451,57 @@ def __call__(self, x: jax.Array, _) -> tp.Tuple[jax.Array, None]: assert y.shape == (1, 3) def test_remat_with_scan_decorator(self): + state_axes = nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}) + class ScanLinear(nnx.Module): - @partial( - nnx.vmap, - state_axes={nnx.Param: 0}, - axis_size=5, - ) - def __init__(self, *, rngs: nnx.Rngs): + @nnx.split_rngs(splits=5) + @nnx.vmap(in_axes=(state_axes, state_axes), axis_size=5) + def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(3, 3, rngs=rngs) - @partial( - nnx.scan, - in_axes=None, - state_axes={nnx.Param: 0}, - length=5, - carry_argnum=1, - ) + @nnx.scan(in_axes=(state_axes, nnx.Carry)) @nnx.remat - def __call__(self, x: jax.Array, _) -> tp.Tuple[jax.Array, None]: + def __call__(self, x: 
jax.Array) -> tp.Tuple[jax.Array, None]: x = self.linear(x) return x, None - m = ScanLinear(rngs=nnx.Rngs(0)) + m = ScanLinear(nnx.Rngs(0)) assert m.linear.kernel.value.shape == (5, 3, 3) assert m.linear.bias.value.shape == (5, 3) - y, _ = m(jnp.ones((1, 3)), None) + y, _ = m(jnp.ones((1, 3))) assert y.shape == (1, 3) class TestVmap(absltest.TestCase): def test_basic(self): + @partial(nnx.vmap, in_axes=0, out_axes=0, axis_size=5) + def create_block(rngs: nnx.Rngs): + return nnx.Linear(2, 3, rngs=rngs) + + rngs = nnx.Rngs(0) + backups = nnx.split_rngs(rngs, splits=5) + + block = create_block(rngs) + nnx.restore_rngs(backups) + + self.assertEqual(block.kernel.value.shape, (5, 2, 3)) + self.assertEqual(rngs.default.count.value, 1) + + @partial(nnx.vmap, in_axes=(0, 1), out_axes=1) + def forward(block: nnx.Linear, x): + self.assertEqual(block.kernel.value.shape, (2, 3)) + self.assertEqual(block.bias.value.shape, (3,)) + self.assertEqual(x.shape, (2,)) + return block(x) + + x = jax.random.uniform(rngs(), (2, 5)) + y = forward(block, x) + + self.assertEqual(y.shape, (3, 5)) + + def test_state_axes(self): class Block(nnx.Module): def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(3, 3, rngs=rngs) @@ -1046,16 +1513,20 @@ def __call__(self, x: jax.Array) -> jax.Array: x = self.dropout(x) return x + @nnx.vmap( + in_axes=0, + out_axes=nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}), + ) def create_block(rngs: nnx.Rngs): + rngs = nnx.clone(rngs) return Block(rngs) - vectorized_create_block = nnx.vmap( - create_block, state_axes={nnx.Param: 0}, axis_size=5 - ) - rngs = nnx.Rngs(0) initial_key = rngs.default.key.value - module = vectorized_create_block(rngs) + + backups = nnx.split_rngs(rngs, splits=5) + module = create_block(rngs) + nnx.restore_rngs(backups) assert rngs.default.count.value == 1 assert rngs.default.key.value == initial_key @@ -1068,63 +1539,236 @@ def create_block(rngs: nnx.Rngs): x = jnp.ones((5, 1, 3)) + @nnx.vmap( + in_axes=(nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}), 0), + ) def forward_block(module, x): return module(x) - vectorized_forward_block = nnx.vmap( - forward_block, state_axes={nnx.Param: 0}, axis_size=5 - ) - - y = vectorized_forward_block(module, x) + backups = nnx.split_rngs(rngs, splits=5) + y = forward_block(module, x) + nnx.restore_rngs(backups) assert y.shape == (5, 1, 3) assert rngs.default.count.value == 2 assert rngs.default.key.value == initial_key - y2 = vectorized_forward_block(module, x) + y2 = forward_block(module, x) assert not jnp.allclose(y, y2) - def test_basic_demo(self): + def test_split_rngs_context_manager(self): class Block(nnx.Module): def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(3, 3, rngs=rngs) self.dropout = nnx.Dropout(0.5, deterministic=False, rngs=rngs) def __call__(self, x: jax.Array) -> jax.Array: - return self.dropout(nnx.relu(self.linear(x))) + x = self.linear(x) + x = nnx.relu(x) + x = self.dropout(x) + return x - @partial(nnx.vmap, axis_size=5) + state_axes = nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}) + + @nnx.vmap(in_axes=(state_axes,), out_axes=state_axes) def create_block(rngs: nnx.Rngs): return Block(rngs) - @partial(nnx.vmap, axis_size=5) - def forward_block(module: Block, x): + rngs = nnx.Rngs(0) + initial_key = rngs.default.key.value + + with nnx.split_rngs(rngs, splits=5): + module = create_block(rngs) + + assert rngs.default.count.value == 1 + assert rngs.default.key.value == initial_key + assert not jnp.allclose( + module.linear.kernel.value[0], + 
module.linear.kernel.value[1], + ) + assert module.linear.kernel.value.shape == (5, 3, 3) + assert module.linear.bias.value.shape == (5, 3) + + x = jnp.ones((5, 1, 3)) + + @nnx.vmap(in_axes=(state_axes, 0)) + def forward_block(module, x): return module(x) + with nnx.split_rngs(module, splits=5): + y = forward_block(module, x) + + assert y.shape == (5, 1, 3) + assert rngs.default.count.value == 2 + assert rngs.default.key.value == initial_key + + with nnx.split_rngs(module, splits=5): + y2 = forward_block(module, x) + + assert not jnp.allclose(y, y2) + + def test_split_rngs_decorator(self): + class Block(nnx.Module): + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(3, 3, rngs=rngs) + self.dropout = nnx.Dropout(0.5, deterministic=False, rngs=rngs) + + def __call__(self, x: jax.Array) -> jax.Array: + x = self.linear(x) + x = nnx.relu(x) + x = self.dropout(x) + return x + + state_axes = nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}) + + @nnx.split_rngs(splits=5) + @nnx.vmap(in_axes=(state_axes,), out_axes=state_axes) + def create_block(rngs: nnx.Rngs): + return Block(rngs) + rngs = nnx.Rngs(0) + initial_key = rngs.default.key.value + module = create_block(rngs) assert rngs.default.count.value == 1 - assert module.linear.kernel.value.shape == (5, 3, 3) - assert module.linear.bias.value.shape == (5, 3) + assert rngs.default.key.value == initial_key assert not jnp.allclose( module.linear.kernel.value[0], module.linear.kernel.value[1], ) + assert module.linear.kernel.value.shape == (5, 3, 3) + assert module.linear.bias.value.shape == (5, 3) x = jnp.ones((5, 1, 3)) + @nnx.split_rngs(splits=5) + @nnx.vmap(in_axes=(state_axes, 0)) + def forward_block(module, x): + self.assertEqual(x.shape, (1, 3)) + return module(x) + y = forward_block(module, x) assert y.shape == (5, 1, 3) assert rngs.default.count.value == 2 + assert rngs.default.key.value == initial_key y2 = forward_block(module, x) - # dropout is working! 
assert not jnp.allclose(y, y2) + def test_state_axes_simple(self): + class Block(nnx.Module): + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(2, 3, rngs=rngs) + self.bn = nnx.BatchNorm(3, rngs=rngs) + self.dropout = nnx.Dropout(0.1, deterministic=False, rngs=rngs) + + def __call__(self, x: jax.Array) -> jax.Array: + return nnx.relu(self.dropout(self.bn(self.linear(x)))) + + state_axes = nnx.StateAxes({(nnx.BatchStat, 'dropout'): 0, ...: None}) + + @nnx.vmap(in_axes=(state_axes,), out_axes=state_axes) + def create_block(rngs: nnx.Rngs): + return Block(rngs) + + rngs = nnx.Rngs(params=0, dropout=1) + nnx.split_rngs(rngs, splits=5, only='dropout') + + module = create_block(rngs) + + assert module.linear.kernel.value.shape == (2, 3) + assert module.bn.scale.value.shape == (3,) + assert module.bn.mean.value.shape == (5, 3) + + @nnx.vmap(in_axes=(state_axes, 0), out_axes=0) + def forward_block(module, x): + return module(x) + + x = jnp.ones((5, 1, 2)) + y = forward_block(module, x) + + assert y.shape == (5, 1, 3) + + def test_split_rngs_decorator_simple(self): + class Block(nnx.Module): + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(2, 3, rngs=rngs) + self.bn = nnx.BatchNorm(3, rngs=rngs) + self.dropout = nnx.Dropout(0.1, deterministic=False, rngs=rngs) + + def __call__(self, x: jax.Array) -> jax.Array: + return nnx.relu(self.dropout(self.bn(self.linear(x)))) + + state_axes = nnx.StateAxes({(nnx.BatchStat, 'dropout'): 0, ...: None}) + + @nnx.split_rngs(splits=5, only='dropout') + @nnx.vmap(in_axes=(state_axes,), out_axes=state_axes) + def create_block(rngs: nnx.Rngs): + return Block(rngs) + + rngs = nnx.Rngs(params=0, dropout=1) + + module = create_block(rngs) + + assert module.linear.kernel.value.shape == (2, 3) + assert module.bn.scale.value.shape == (3,) + assert module.bn.mean.value.shape == (5, 3) + assert module.dropout.rngs is not None + self.assertEqual(module.dropout.rngs.params.key.shape, ()) + self.assertEqual(module.dropout.rngs.dropout.key.shape, ()) + + @nnx.split_rngs(splits=5, only='dropout') + @nnx.vmap(in_axes=(state_axes, 0), out_axes=0) + def forward_block(module: Block, x): + assert module.dropout.rngs is not None + self.assertEqual(module.dropout.rngs.params.key.shape, ()) + self.assertEqual(module.dropout.rngs.dropout.key.shape, ()) + return module(x) + + x = jnp.ones((5, 1, 2)) + y = forward_block(module, x) + + assert module.dropout.rngs is not None + self.assertEqual(module.dropout.rngs.params.key.shape, ()) + self.assertEqual(module.dropout.rngs.dropout.key.shape, ()) + assert y.shape == (5, 1, 3) + + def test_state_axes_super_simple(self): + class Block(nnx.Module): + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(2, 3, rngs=rngs) + self.bn = nnx.BatchNorm(3, rngs=rngs) + self.dropout = nnx.Dropout(0.1, deterministic=False, rngs=rngs) + + def __call__(self, x: jax.Array) -> jax.Array: + return nnx.relu(self.dropout(self.bn(self.linear(x)))) + + @nnx.vmap(in_axes=0, out_axes=0) + def create_block(rngs: nnx.Rngs): + return Block(rngs) + + rngs = nnx.Rngs(0) + nnx.split_rngs(rngs, splits=5) + + module = create_block(rngs) + + assert module.linear.kernel.value.shape == (5, 2, 3) + assert module.bn.scale.value.shape == (5, 3) + assert module.bn.mean.value.shape == (5, 3) + + @nnx.vmap(in_axes=(0, 0), out_axes=0) + def forward_block(module, x): + return module(x) + + x = jnp.ones((5, 1, 2)) + y = forward_block(module, x) + + assert y.shape == (5, 1, 3) + def test_replicate(self): din = 3 dout = 10 @@ -1140,11 +1784,10 @@ 
def __call__(self, x: jax.Array) -> jax.Array: def create_block(rngs: nnx.Rngs): return Block(rngs) - @partial( - nnx.vmap, - state_axes={}, # replicate all state - split_rngs=True, # different rngs for each replica - ) + state_axes = nnx.StateAxes({nnx.RngState: 0, ...: None}) + + @nnx.split_rngs(splits=5) + @partial(nnx.vmap, in_axes=(state_axes, 0), out_axes=0) def forward_block(module: Block, x): return module(x) @@ -1172,63 +1815,183 @@ def forward_block(module: Block, x): assert rngs.default.key.value == initial_key - def test_combinator(self): - class Block(nnx.Module): - def __init__(self, *, rngs: nnx.Rngs): - self.linear = nnx.Linear(3, 3, rngs=rngs) + def test_consistent_aliasing_inputs(self): + class Foo(nnx.Module): + def __init__(self): + self.a = jnp.zeros((5, 5)) - def __call__(self, x: jax.Array) -> jax.Array: - x = self.linear(x) - x = nnx.gelu(x) - return x + m = Foo() - MLP = nnx.Vmap.constructor(Block, state_axes={nnx.Param: 0}, axis_size=5) + @nnx.vmap(in_axes=(0, 1)) + def f(m1, m2): + pass - module = MLP(rngs=nnx.Rngs(0)) + with self.assertRaisesRegex(ValueError, 'Inconsistent aliasing detected'): + f(m, m) - assert not jnp.allclose( - module.vmap_module.linear.kernel.value[0], - module.vmap_module.linear.kernel.value[1], - ) - assert module.vmap_module.linear.kernel.value.shape == (5, 3, 3) - assert module.vmap_module.linear.bias.value.shape == (5, 3) + def test_consistent_aliasing_input_output(self): + class Foo(nnx.Module): + def __init__(self): + self.a = jnp.zeros((2, 3)) - x = jnp.ones((5, 1, 3)) - y = module(x) + m = Foo() - assert y.shape == (5, 1, 3) + @partial(nnx.vmap, in_axes=0, out_axes=1) + def f(m): + return m - def test_combinator_init(self): - class Block(nnx.Module): - def __init__(self, *, graphdef: str, rngs: nnx.Rngs): - self.linear = nnx.Linear(3, 3, rngs=rngs) - self.graphdef = graphdef + with self.assertRaisesRegex(ValueError, 'Inconsistent aliasing detected'): + m2 = f(m) - def __call__(self, x: jax.Array) -> jax.Array: - x = self.linear(x) - x = nnx.gelu(x) - return x + def test_consistent_aliasing_shared(self): + class Shared(nnx.Module): + def __init__(self): + self.a = jnp.zeros((3, 3)) - MLP = nnx.Vmap.constructor(Block, state_axes={nnx.Param: 0}, axis_size=5) + class Foo(nnx.Module): + def __init__(self, shared: Shared): + self.a = shared - module = MLP(graphdef='hello', rngs=nnx.Rngs(0)) + shared = Shared() + m1 = Foo(shared) + m2 = Foo(shared) - assert module.vmap_module.graphdef == 'hello' + @partial(nnx.vmap, in_axes=(0, 1)) + def f(m1, m2): + pass - def test_state_axes(self): + with self.assertRaisesRegex( + ValueError, + r'Inconsistent aliasing detected([\s\S]*)Shared([\s\S]*)a: 0([\s\S]*)a: 1', + ): + f(m1, m2) + @pytest.mark.skip(reason='Enable once jax#19586 resolved') + def test_captured_module_in_return_error(self): class Foo(nnx.Module): + def __init__(self): + self.a = jnp.zeros((5, 5)) + + m = Foo() + + @nnx.vmap(in_axes=0, out_axes=0) + def f(x): + return x, m + + with self.assertRaisesRegex( + ValueError, + r'Trying to extract graph node from different trace level.*Foo', + ): + x = jnp.zeros((5,)) + f(x) + + def test_vmap_and_cond_passthrough(self): + class Broadcast(nnx.Variable[nnx.A]): ... + + class Vectorized(nnx.Variable[nnx.A]): ... 
+ + class Env(nnx.Module): + def __init__(self): + self.broadcast = Broadcast(jnp.array(1)) + self.index = Vectorized(jnp.arange(8)) + self.step = Vectorized(jnp.zeros((8,), jnp.uint32)) + + env = Env() + + @nnx.vmap(in_axes=(nnx.StateAxes({Broadcast: None, Vectorized: 0}),)) + def f(env: Env): + self.assertEqual(env.step.shape, ()) + + def increment(env: Env): + env.step += 1 + + def no_nothing(env: Env): + pass + + is_even = env.index % 2 == 0 + nnx.cond(is_even, increment, no_nothing, env) + + f(env) + + np.testing.assert_array_equal(env.step.value, [1, 0, 1, 0, 1, 0, 1, 0]) + + def test_vmap_and_cond_passthrough_error(self): + class Broadcast(nnx.Variable[nnx.A]): ... + class Vectorized(nnx.Variable[nnx.A]): ... + + class Env(nnx.Module): def __init__(self): - self.param = nnx.Param(jnp.arange(5)) + self.broadcast = Broadcast(jnp.array(1)) + self.index = Vectorized(jnp.arange(8)) + self.step = Vectorized(jnp.zeros((8,), jnp.uint32)) + + env = Env() - foo = Foo() + @nnx.vmap(in_axes=(nnx.StateAxes({Broadcast: None, Vectorized: 0}),)) + def f(env: Env): + self.assertEqual(env.step.shape, ()) - @partial(nnx.vmap, state_axes={...: 0}) - def f(foo: Foo): - assert foo.param.value.shape == () + def increment(env: Env): + env.step += 1 + env.broadcast += 1 - f(foo) + def no_nothing(env: Env): + pass + + is_even = env.index % 2 == 0 + nnx.cond(is_even, increment, no_nothing, env) + + with self.assertRaisesRegex( + ValueError, + r"at vmap.*'broadcast'.*got axis spec None but output was batched on axis 0", + ): + f(env) + + def test_example(self): + class Model(nnx.Module): + def __init__(self, din, dout, *, rngs: nnx.Rngs): + self.linear = nnx.Linear(din, dout, rngs=rngs) + self.dropout = nnx.Dropout(0.5, rngs=rngs) + self.bn = nnx.BatchNorm(dout, rngs=rngs) + + def __call__(self, x): + return nnx.relu(self.dropout(self.bn(self.linear(x)))) + + @nnx.vmap(in_axes=0, out_axes=0) + def initialize_ensamble(key): + rngs = nnx.Rngs(key) + return Model(2, 3, rngs=rngs) + + keys = jax.random.split(jax.random.key(0), 5) + ensamble = initialize_ensamble(keys) + + self.assertEqual(ensamble.linear.kernel.shape, (5, 2, 3)) + + @nnx.vmap(in_axes=(0, None), out_axes=0) + def forward(model, x): + return model(x) + + x = jnp.ones((4, 2)) + y = forward(ensamble, x) + self.assertEqual(y.shape, (5, 4, 3)) + + def test_example_with_vectorization(self): + class LinearEnsemble(nnx.Module): + def __init__(self, num, rngs): + self.w = nnx.Param(jax.random.uniform(rngs(), (num, 2, 3))) + + model = LinearEnsemble(5, rngs=nnx.Rngs(0)) + + @nnx.vmap(in_axes=(0, None), out_axes=0) + def forward(model, x): + self.assertEqual(model.w.shape, (2, 3)) + return jnp.dot(x, model.w.value) + + x = jnp.ones((4, 2)) + y = forward(model, x) + + self.assertEqual(y.shape, (5, 4, 3)) class TestPmap(absltest.TestCase): @@ -1245,38 +2008,36 @@ def __call__(self, x: jax.Array) -> jax.Array: x = self.dropout(x) return x + state_axes = nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}) + + @nnx.split_rngs(splits=1) + @nnx.pmap(in_axes=(state_axes,), out_axes=state_axes, axis_size=1) def create_block(rngs: nnx.Rngs): return Block(rngs) - vectorized_create_block = nnx.pmap( - create_block, state_axes={nnx.Param: 0}, axis_size=1 - ) - rngs = nnx.Rngs(0) initial_key = rngs.default.key.value - module = vectorized_create_block(rngs) + module = create_block(rngs) - assert rngs.default.count.value == 2 + assert rngs.default.count.value == 1 assert rngs.default.key.value == initial_key assert module.linear.kernel.value.shape == (1, 3, 10) assert 
module.linear.bias.value.shape == (1, 10) x = jnp.ones((1, 1, 3)) + @nnx.split_rngs(splits=1) + @nnx.pmap(in_axes=(state_axes, 0), axis_size=1) def forward_block(module, x): return module(x) - vectorized_forward_block = nnx.vmap( - forward_block, state_axes={nnx.Param: 0}, axis_size=1 - ) - - y = vectorized_forward_block(module, x) + y = forward_block(module, x) assert y.shape == (1, 1, 10) - assert rngs.default.count.value == 3 + assert rngs.default.count.value == 2 assert rngs.default.key.value == initial_key - y2 = vectorized_forward_block(module, x) + y2 = forward_block(module, x) assert not jnp.allclose(y, y2) @@ -1289,18 +2050,20 @@ def __init__(self, rngs: nnx.Rngs): def __call__(self, x: jax.Array) -> jax.Array: return self.dropout(nnx.relu(self.linear(x))) - @partial(nnx.pmap, axis_size=1) + @nnx.split_rngs(splits=1) + @nnx.pmap(axis_size=1) def create_block(rngs: nnx.Rngs): return Block(rngs) - @partial(nnx.pmap, axis_size=1) + @nnx.split_rngs(splits=1) + @nnx.pmap(axis_size=1) def forward_block(module: Block, x): return module(x) rngs = nnx.Rngs(0) module = create_block(rngs) - assert rngs.default.count.value == 2 + assert rngs.default.count.value == 1 assert module.linear.kernel.value.shape == (1, 3, 3) assert module.linear.bias.value.shape == (1, 3) @@ -1309,7 +2072,7 @@ def forward_block(module: Block, x): y = forward_block(module, x) assert y.shape == (1, 10, 3) - assert rngs.default.count.value == 3 + assert rngs.default.count.value == 2 y2 = forward_block(module, x) @@ -1331,11 +2094,10 @@ def __call__(self, x: jax.Array) -> jax.Array: def create_block(rngs: nnx.Rngs): return Block(rngs) - @partial( - nnx.pmap, - state_axes={}, # replicate all state - split_rngs=True, # different rngs for each replica - ) + state_axes = nnx.StateAxes({nnx.RngState: 0, ...: None}) + + @nnx.split_rngs(splits=1) + @partial(nnx.pmap, in_axes=(state_axes, 0), out_axes=0, axis_size=1) def forward_block(module: Block, x): return module(x) @@ -1425,6 +2187,122 @@ def reward_0(self: Foo): assert foo.timestep.step == 4 assert foo.timestep.reward == 0.0 + def test_cond_and_vmap(self): + class Env(nnx.Module): + def __init__(self): + self.index = jnp.arange(8) + self.step = jnp.zeros((8,), jnp.uint32) + + env = Env() + model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) + + @nnx.vmap(in_axes=(0, None), out_axes=None) + def f(env: Env, model: nnx.Linear): + self.assertEqual(env.index.shape, ()) + + def increment(env: Env): + env.step += 1 + + def no_nothing(env: Env): + pass + + is_even = env.index % 2 == 0 + nnx.cond(is_even, increment, no_nothing, env) + + f(env, model) + + np.testing.assert_array_equal(env.step, [1, 0, 1, 0, 1, 0, 1, 0]) + + +class TestSplitMergeInputs(absltest.TestCase): + def test_split_inputs(self): + class StatefulLinear(nnx.Linear): + def __init__(self, din: int, dout: int, rngs: nnx.Rngs): + super().__init__(din, dout, rngs=rngs) + self.counter = jnp.array(0, jnp.uint32) + + def __call__(self, x): + self.counter += 1 + return super().__call__(x) + + model = StatefulLinear(3, 4, rngs=nnx.Rngs(0)) + + @nnx.split_inputs + @jax.jit + @nnx.merge_inputs + def forward(model, x): + return model(x) + + x = jnp.ones((2, 3)) + y = forward(model, x) + + self.assertEqual(model.counter, 1) + + def test_split_inputs_cond(self): + class Counter(nnx.Linear): + def __init__(self): + self.count = jnp.array(0, jnp.uint32) + + def increment(self): + self.count += 1 + + counter = Counter() + + @nnx.merge_inputs + def increment(counter: Counter): + counter.increment() + + @nnx.merge_inputs + def 
no_nothing(counter: Counter): + pass + + nnx.split_inputs(jax.lax.cond)(True, increment, no_nothing, counter) + + self.assertEqual(counter.count, 1) + + nnx.split_inputs(jax.lax.cond)(False, increment, no_nothing, counter) + + self.assertEqual(counter.count, 1) + + def test_split_inputs_vmap(self): + class EnvState(nnx.Variable[nnx.A]): + pass + + class Env(nnx.Object): + def __init__(self): + self.index = EnvState(jnp.arange(8)) + self.step = EnvState(jnp.zeros((8,), jnp.uint32)) + + env = Env() + model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) + + # internally merge_inputs returns (args, out) + in_axes = (0, None) + out_axes = (in_axes, None) + + @nnx.split_inputs + @partial(jax.vmap, in_axes=in_axes, out_axes=out_axes) + @nnx.merge_inputs + def f(env: Env, model: nnx.Linear): + self.assertEqual(env.index.value.shape, ()) + + @nnx.merge_inputs + def increment(env: Env): + env.step.value += 1 + + @nnx.merge_inputs + def no_nothing(env: Env): + pass + + is_even = env.index.value % 2 == 0 + nnx.split_inputs(jax.lax.cond)(is_even, increment, no_nothing, env) + + f(env, model) + + np.testing.assert_array_equal( + env.step.value, np.array([1, 0, 1, 0, 1, 0, 1, 0], np.uint32) + ) + class TestSplitMergeInputs(absltest.TestCase): def test_split_inputs(self): diff --git a/flax/nnx/tests/variable_test.py b/flax/nnx/tests/variable_test.py index 5b3e899490..af84037856 100644 --- a/flax/nnx/tests/variable_test.py +++ b/flax/nnx/tests/variable_test.py @@ -28,7 +28,7 @@ def test_pytree(self): r1 = nnx.VariableState(nnx.Param, 1) assert r1.value == 1 - r2 = jax.tree_util.tree_map(lambda x: x + 1, r1) + r2 = jax.tree.map(lambda x: x + 1, r1) assert r1.value == 1 assert r2.value == 2 diff --git a/flax/typing.py b/flax/typing.py index aa4cc00cd3..e80f8f4ee2 100644 --- a/flax/typing.py +++ b/flax/typing.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections import deque +from functools import partial from typing import ( Any, Generic, @@ -27,6 +29,7 @@ from flax.core import FrozenDict import dataclasses +import jax.tree_util as jtu # General @@ -124,3 +127,29 @@ class Out(Generic[T]): PartitionSpecPytree = Any # pylint: disable=invalid-name Sharding = tuple[Optional[str], ...] + +A = TypeVar('A') + + +class PytreeDeque(deque[A]): + pass + + +def _pytree_deque_flatten(xs: PytreeDeque, *, with_path: bool): + if with_path: + nodes = tuple((jtu.SequenceKey(i), x) for i, x in enumerate(xs)) + return nodes, () + else: + return xs, () + + +def _pytree_deque_unflatten(_, nodes): + return PytreeDeque(nodes) + + +jtu.register_pytree_with_keys( + PytreeDeque, + partial(_pytree_deque_flatten, with_path=True), + _pytree_deque_unflatten, + flatten_func=partial(_pytree_deque_flatten, with_path=False), +)
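Note (not part of the patch): the flax/typing.py hunk above registers PytreeDeque as a JAX pytree with key paths, but none of the tests in this diff exercise it. The sketch below is a minimal, standalone illustration of what that registration provides; it re-declares an equivalent class and registration using only jax so it runs without applying the patch, and the helper names here (_flatten, _unflatten, q) are illustrative, not taken from the Flax source.

from collections import deque
from functools import partial

import jax
import jax.tree_util as jtu


class PytreeDeque(deque):
  """A deque that JAX treats as a pytree node instead of a single leaf."""


def _flatten(xs, *, with_path):
  # With keys: pair each element with its positional SequenceKey, no aux data.
  if with_path:
    return tuple((jtu.SequenceKey(i), x) for i, x in enumerate(xs)), ()
  # Without keys: the children are just the elements themselves.
  return list(xs), ()


def _unflatten(_, children):
  return PytreeDeque(children)


jtu.register_pytree_with_keys(
  PytreeDeque,
  partial(_flatten, with_path=True),
  _unflatten,
  flatten_func=partial(_flatten, with_path=False),
)

q = PytreeDeque([1, 2, 3])

# tree.map recurses into the deque and rebuilds a PytreeDeque from the results.
doubled = jax.tree.map(lambda x: x * 2, q)
assert isinstance(doubled, PytreeDeque) and list(doubled) == [2, 4, 6]

# Key paths report the position of each element inside the deque.
path_leaves, _ = jtu.tree_flatten_with_path(q)
assert path_leaves[0] == ((jtu.SequenceKey(0),), 1)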