diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py index 0ed7392b5a..56b3b66896 100644 --- a/flax/nnx/__init__.py +++ b/flax/nnx/__init__.py @@ -25,6 +25,7 @@ from .nnx import compat as compat from .nnx import traversals as traversals from .nnx import filterlib as filterlib +from .nnx import transforms as transforms from .nnx.filterlib import WithTag as WithTag from .nnx.filterlib import PathContains as PathContains from .nnx.filterlib import OfType as OfType @@ -103,6 +104,8 @@ from .nnx.rnglib import RngCount as RngCount from .nnx.rnglib import ForkStates as ForkStates from .nnx.rnglib import fork as fork +from .nnx.rnglib import split_rngs as split_rngs +from .nnx.rnglib import restore_rngs as restore_rngs from .nnx.spmd import PARTITION_NAME as PARTITION_NAME from .nnx.spmd import get_partition_spec as get_partition_spec from .nnx.spmd import get_named_sharding as get_named_sharding @@ -122,8 +125,10 @@ from .nnx.transforms.looping import Scan as Scan from .nnx.transforms.parallelization import Vmap as Vmap from .nnx.transforms.parallelization import Pmap as Pmap -from .nnx.transforms.transforms import grad as grad +from .nnx.transforms.general import split_inputs as split_inputs +from .nnx.transforms.general import merge_inputs as merge_inputs from .nnx.transforms.transforms import jit as jit +from .nnx.transforms.transforms import grad as grad from .nnx.transforms.transforms import remat as remat from .nnx.transforms.looping import scan as scan from .nnx.transforms.transforms import value_and_grad as value_and_grad @@ -131,6 +136,8 @@ from .nnx.transforms.parallelization import pmap as pmap from .nnx.transforms.transforms import eval_shape as eval_shape from .nnx.transforms.transforms import cond as cond +from .nnx.transforms.experimental import vmap as experimental_vmap +from .nnx.transforms.experimental import StateAxes as StateAxes from .nnx.variables import EMPTY as EMPTY from .nnx.variables import A as A from .nnx.variables import BatchStat as BatchStat diff --git a/flax/nnx/nnx/extract.py b/flax/nnx/nnx/extract.py new file mode 100644 index 0000000000..be99d83f00 --- /dev/null +++ b/flax/nnx/nnx/extract.py @@ -0,0 +1,202 @@ +import abc +import typing as tp + +import jax +from jax._src.tree_util import broadcast_prefix + +from flax import struct +from flax.nnx.nnx.state import State +from flax.typing import PathParts +from flax.nnx.nnx import graph + + +class Missing: + pass + + +MISSING = Missing() +A = tp.TypeVar('A') +E = tp.TypeVar('E', bound='Extractable') +Index = int + + +class Extractable(abc.ABC): + @property + @abc.abstractmethod + def index(self) -> Index: ... + + +class ExtractableStates(Extractable): + @property + @abc.abstractmethod + def states(self) -> tp.Iterable[State]: ... + + @property + @abc.abstractmethod + def graphdef(self) -> graph.GraphDef[tp.Any]: ... + + +class ExtractionIndex(struct.PyTreeNode, Extractable): + """Index of a graph node in a Pytree structure.""" + + _index: Index = struct.field(pytree_node=False) + + @property + def index(self) -> Index: + return self._index + + +@tp.overload +def extract_graph_nodes(pytree: A, /) -> tuple[A, tuple[tp.Any, ...]]: ... + + +@tp.overload +def extract_graph_nodes( + pytree: A, /, *, prefix: tp.Any +) -> tuple[A, tuple[tp.Any, ...], tuple[tp.Any, ...]]: ... 
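+# A minimal sketch of the round-trip these helpers provide (illustrative
+# only; ``node`` stands in for any graph node, e.g. an ``nnx.Module``):
+#
+#   pytree, nodes = extract_graph_nodes({'a': node, 'b': 1.0})
+#   # pytree['a'] is now ExtractionIndex(0) and nodes == (node,)
+#   restored = insert_graph_nodes(pytree, nodes)
+#   assert restored['a'] is node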
+ + +def extract_graph_nodes( + pytree: A, /, *, prefix: tp.Any = MISSING +) -> ( + tuple[A, tuple[tp.Any, ...]] + | tuple[A, tuple[tp.Any, ...], tuple[tp.Any, ...]] +): + """Extracts all graph nodes from a pytree.""" + nodes = graph.RefMap[tp.Any, Index]() + node_prefixes = [] + leaves = [] + + prefix_leaves = broadcast_prefix( + prefix, + pytree, + is_leaf=lambda x: x is None, + ) + key_leaves, treedef = jax.tree_util.tree_flatten_with_path(pytree) + + assert len(key_leaves) == len(prefix_leaves) + + for (keypath, leaf), prefix_leaf in zip(key_leaves, prefix_leaves): + if graph.is_graph_node(leaf): + if leaf not in nodes: + index = nodes[leaf] = len(nodes) + node_prefixes.append(prefix_leaf) + else: + index = nodes[leaf] + # check consistent aliasing + if prefix_leaf != node_prefixes[index]: + path_str = jax.tree_util.keystr(keypath) + raise ValueError( + f'Inconsistent aliasing detected. Node {type(leaf)} at path {path_str} ' + f'has different prefixes: {prefix_leaf} and {node_prefixes[index]}.' + ) + leaves.append(ExtractionIndex(index)) + else: + leaves.append(leaf) + + pytree_out = jax.tree.unflatten(treedef, leaves) + + if prefix is MISSING: + return pytree_out, tuple(nodes) + else: + return pytree_out, tuple(nodes), tuple(node_prefixes) + + +def insert_graph_nodes(pytree: A, nodes: tuple[tp.Any, ...], /) -> A: + """Inserts graph nodes into a pytree.""" + + def _maybe_insert(x): + if isinstance(x, Extractable): + return nodes[x.index] + return x + + return jax.tree_util.tree_map( + _maybe_insert, pytree, is_leaf=lambda x: isinstance(x, Extractable) + ) + + +def extract_indexes( + pytree, /, types: tuple[type[E], ...] | type[E] = Extractable +) -> tuple[E, ...]: + """Extracts all indexes from a pytree.""" + indexes: list[E] = [] + for x in jax.tree.leaves( + pytree, is_leaf=lambda x: isinstance(x, Extractable) + ): + if isinstance(x, Extractable): + if not isinstance(x, types): + raise ValueError(f'Expected Extractable of type {types}, got {type(x)}') + indexes.append(x) + return tuple(indexes) + + +def replace_indexes( + pytree: A, + replace_fn: tp.Callable[[Extractable], tp.Any], + /, + clear: bool = False, +) -> A: + def _replace_map_fn(x): + if isinstance(x, Extractable): + return replace_fn(x) + elif clear: + return None + return x + + return jax.tree_util.tree_map( + _replace_map_fn, pytree, is_leaf=lambda x: isinstance(x, Extractable) + ) + + +def merge_extractable_states( + extractable_states: tp.Sequence[ExtractableStates], / +): + if len(extractable_states) == 0: + raise ValueError('Expected at least one ExtractableStates object') + + graphdef = extractable_states[0].graphdef + flat_state = [] + + for extractable_state in extractable_states: + flat_state.extend( + ((extractable_state.index, *path), value) + for state in extractable_state.states + for path, value in state.flat_state().items() + ) + + state = State.from_flat_path(flat_state) + return graphdef, state + + +def check_consistent_aliasing( + nodes: tuple[tp.Any, ...], prefixes: tuple[tp.Any, ...] 
+): + node_prefixes = graph.RefMap[tp.Any, list[tuple[PathParts, tp.Any]]]() + + # collect all paths and prefixes for each node + for node, prefix in zip(nodes, prefixes): + for path, value in graph.iter_graph(node): + if graph.is_graph_node(value): + if value in node_prefixes: + paths_prefixes = node_prefixes[value] + paths_prefixes.append((path, prefix)) + else: + node_prefixes[value] = [(path, prefix)] + + # check for inconsistent aliasing + node_msgs = [] + for node, paths_prefixes in node_prefixes.items(): + unique_prefixes = {prefix for _, prefix in paths_prefixes} + if len(unique_prefixes) > 1: + path_prefix_repr = '\n'.join( + f' {"/".join(map(str,path)) if path else ""}: {prefix}' + for path, prefix in paths_prefixes + ) + nodes_msg = f'Node: {type(node)}\n{path_prefix_repr}' + node_msgs.append(nodes_msg) + + if node_msgs: + raise ValueError( + 'Inconsistent aliasing detected. The following nodes have different prefixes:\n' + + '\n'.join(node_msgs) + ) diff --git a/flax/nnx/nnx/graph.py b/flax/nnx/nnx/graph.py index 50d20c94e6..e441859cc6 100644 --- a/flax/nnx/nnx/graph.py +++ b/flax/nnx/nnx/graph.py @@ -14,7 +14,6 @@ from __future__ import annotations -from collections import defaultdict import dataclasses import enum import functools @@ -61,8 +60,8 @@ def is_state_leaf(x: tp.Any) -> tpe.TypeGuard[StateLeaf]: @dataclasses.dataclass class GraphContext(threading.local): - update_context_stacks: defaultdict[str, list[UpdateContext]] = ( - dataclasses.field(default_factory=lambda: defaultdict(list)) + update_context_stacks: dict[str, list[UpdateContext]] = dataclasses.field( + default_factory=dict ) @@ -1021,21 +1020,27 @@ class UpdateContextManager: def __enter__(self): ctx = UpdateContext(self.tag, None, None) - GRAPH_CONTEXT.update_context_stacks[self.tag].append(ctx) + if self.tag not in GRAPH_CONTEXT.update_context_stacks: + GRAPH_CONTEXT.update_context_stacks[self.tag] = [ctx] + else: + GRAPH_CONTEXT.update_context_stacks[self.tag].append(ctx) return ctx def __exit__(self, *args): - stack = GRAPH_CONTEXT.update_context_stacks[self.tag] - if not stack: + if self.tag not in GRAPH_CONTEXT.update_context_stacks: raise RuntimeError( f'No update context found for tag {self.tag!r}, this is a bug.' 
) + stack = GRAPH_CONTEXT.update_context_stacks[self.tag] - ctx = GRAPH_CONTEXT.update_context_stacks[self.tag].pop() + ctx = stack.pop() # clear references ctx.refmap = None ctx.idxmap = None + if not stack: + del GRAPH_CONTEXT.update_context_stacks[self.tag] + def __call__(self, f: F) -> F: @functools.wraps(f) def update_context_manager_wrapper(*args, **kwargs): @@ -1142,10 +1147,9 @@ def update_context(tag: str): def current_update_context(tag: str) -> UpdateContext: """Returns the current active :class:`UpdateContext` for the given tag.""" - stack = GRAPH_CONTEXT.update_context_stacks[tag] - if not stack: + if tag not in GRAPH_CONTEXT.update_context_stacks: raise ValueError(f'No update context found for tag {tag!r}.') - return stack[-1] + return GRAPH_CONTEXT.update_context_stacks[tag][-1] # -------------------------------------------------------- @@ -1595,50 +1599,6 @@ class Static(tp.Generic[A]): jax.tree_util.register_static(Static) -# --------------------------------------------------------- -# insert/extract_graph_nodes API -# --------------------------------------------------------- - - -@dataclasses.dataclass(frozen=True) -class GraphNodeIndex: - """Index of a graph node in a Pytree structure.""" - - index: Index - - -jax.tree_util.register_static(GraphNodeIndex) - - -def extract_graph_nodes(pytree: A, /) -> tuple[A, tuple[tp.Any, ...]]: - """Extracts all graph nodes from a pytree.""" - nodes = RefMap[tp.Any, Index]() - - def _maybe_extract(x): - if is_graph_node(x): - if x not in nodes: - index = nodes[x] = len(nodes) - else: - index = nodes[x] - return GraphNodeIndex(index) - return x - - return jax.tree_util.tree_map(_maybe_extract, pytree), tuple(nodes) - - -def insert_graph_nodes(pytree: A, nodes: tuple[tp.Any, ...], /) -> A: - """Inserts graph nodes into a pytree.""" - - def _maybe_insert(x): - if isinstance(x, GraphNodeIndex): - return nodes[x.index] - return x - - return jax.tree_util.tree_map( - _maybe_insert, pytree, is_leaf=lambda x: isinstance(x, GraphNodeIndex) - ) - - # --------------------------------------------------------- # Pytree # --------------------------------------------------------- diff --git a/flax/nnx/nnx/rnglib.py b/flax/nnx/nnx/rnglib.py index 8c65b10978..86c9d0fe4b 100644 --- a/flax/nnx/nnx/rnglib.py +++ b/flax/nnx/nnx/rnglib.py @@ -53,15 +53,13 @@ class Missing: class RngState(Variable[jax.Array]): - pass + tag: str -class RngCount(RngState): - tag: str +class RngCount(RngState): ... -class RngKey(RngState): - tag: str +class RngKey(RngState): ... NotKey = filterlib.All(RngState, filterlib.Not(RngKey)) @@ -277,15 +275,49 @@ def split_key(key: tp.Any) -> jax.Array: return ForkStates(split_keys, split_counts, broadcast_keys, broadcast_counts) +StreamBackup = ( + tuple[RngStream, jax.Array, jax.Array] | tuple[RngStream, jax.Array] +) + + +def split_rngs( + node, + /, + num_splits: int | tuple[int | None, ...], + filter: filterlib.Filter = ..., +): + predicate = filterlib.to_predicate(filter) + _num_splits: int | tuple[int, ...] 
+ if isinstance(num_splits, int): + _num_splits = num_splits + else: + _num_splits = tuple(x if x is not None else 1 for x in num_splits) + backups: list[StreamBackup] = [] + for path, stream in graph.iter_graph(node): + if ( + isinstance(stream, RngStream) + and predicate((*path, 'key'), stream.key) + and predicate((*path, 'count'), stream.count) + ): + key = stream() + backups.append((stream, stream.key.value, stream.count.value)) + stream.key.value = jax.random.split(key, _num_splits) + stream.count.value = jnp.zeros(stream.key.value.shape, dtype=jnp.uint32) + + return backups + def backup_keys(node: tp.Any, /): - backups: list[tuple[RngStream, jax.Array]] = [] + backups: list[StreamBackup] = [] for _, stream in graph.iter_graph(node): if isinstance(stream, RngStream): backups.append((stream, stream.key.value)) return backups -def restore_keys(backups: list[tuple[RngStream, jax.Array]], /): - for stream, key in backups: - stream.key.value = key \ No newline at end of file +def restore_rngs(backups: list[StreamBackup], /): + for backup in backups: + stream = backup[0] + stream.key.value = backup[1] # key + if len(backup) == 3: + stream.count.value = backup[2] # count \ No newline at end of file diff --git a/flax/nnx/nnx/state.py b/flax/nnx/nnx/state.py index 94f92e82f6..efb9989f35 100644 --- a/flax/nnx/nnx/state.py +++ b/flax/nnx/nnx/state.py @@ -160,7 +160,13 @@ def flat_state(self) -> FlatState[V]: return traversals.flatten_mapping(self._mapping) @classmethod - def from_flat_path(cls, flat_state: tp.Mapping[PathParts, V], /) -> State: + def from_flat_path( + cls, + flat_state: tp.Mapping[PathParts, V] | tp.Iterable[tuple[PathParts, V]], + /, + ) -> State: + if not isinstance(flat_state, tp.Mapping): + flat_state = dict(flat_state) nested_state = traversals.unflatten_mapping(flat_state) return cls(nested_state) diff --git a/flax/nnx/nnx/transforms/__init__.py b/flax/nnx/nnx/transforms/__init__.py index af25fe62d9..73487a1bcd 100644 --- a/flax/nnx/nnx/transforms/__init__.py +++ b/flax/nnx/nnx/transforms/__init__.py @@ -12,3 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. +from . import experimental as experimental \ No newline at end of file diff --git a/flax/nnx/nnx/transforms/experimental.py b/flax/nnx/nnx/transforms/experimental.py new file mode 100644 index 0000000000..bde1634d60 --- /dev/null +++ b/flax/nnx/nnx/transforms/experimental.py @@ -0,0 +1,419 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# pytype: skip-file +from __future__ import annotations + +import dataclasses +import functools +import typing as tp + +from flax import struct +from flax.core.frozen_dict import FrozenDict +from flax.nnx.nnx import ( + extract, + filterlib, + graph, + spmd, +) +from flax.nnx.nnx.module import GraphDef, Module +from flax.nnx.nnx.state import State +from flax.typing import Leaf +import jax +import jax.core +import jax.stages + +A = tp.TypeVar('A') +C = tp.TypeVar('C') +B = tp.TypeVar('B') +F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any]) +G = tp.TypeVar('G', bound=tp.Callable[..., tp.Any]) +M = tp.TypeVar('M', bound=Module) +MA = tp.TypeVar('MA', bound=Module) +N = tp.TypeVar('N', bound=Module) +StrInt = tp.TypeVar('StrInt', str, int) +AxisName = tp.Hashable +Leaves = tp.List[Leaf] +Index = int + +class Missing: + pass + + +MISSING = Missing() + +# ------------------------------- +# vmap +# ------------------------------- + + +class VmapArgState(extract.ExtractionIndex, extract.ExtractableStates): + _graphdef: GraphDef[tp.Any] = struct.field(pytree_node=False) + state: State = struct.field(pytree_node=True) + filter: filterlib.Predicate = struct.field(pytree_node=False) + axis: Index | None = struct.field(pytree_node=False) + + @property + def states(self) -> tp.Iterable[State]: + yield self.state + + @property + def graphdef(self) -> GraphDef[tp.Any]: + return self._graphdef + + @property + def arg_states(self) -> tp.Sequence[VmapArgState]: + return (self,) + + +@struct.dataclass +class VmapStates(tp.Generic[A], extract.ExtractableStates): + arg_states: tuple[A, ...] = struct.field(pytree_node=True) + + @property + def index(self) -> int: + first = self.arg_states[0] + if not isinstance(first, VmapArgState): + raise RuntimeError( + f'Expected type VmapArgState, got {type(first)}, this is a bug.' + ) + return first.index + + @property + def graphdef(self) -> GraphDef[tp.Any]: + first = self.arg_states[0] + if not isinstance(first, VmapArgState): + raise RuntimeError( + f'Expected type VmapArgState, got {type(first)}, this is a bug.' + ) + return first.graphdef + + @property + def states(self) -> tp.Iterable[State]: + for arg_state in self.arg_states: + if not isinstance(arg_state, VmapArgState): + raise RuntimeError( + f'Expected type VmapArgState, got {type(arg_state)}, this is a bug.' + ) + yield arg_state.state + + +class StateAxes: + def __init__( + self, + filter_axes: tp.Mapping[filterlib.Filter, Index | None] + | tp.Iterable[tuple[filterlib.Filter, Index | None]], + /, + ): + iterable = ( + filter_axes.items() + if isinstance(filter_axes, tp.Mapping) + else filter_axes + ) + self.filters: tuple[filterlib.Filter, ...] = tuple( + filter for filter, _ in iterable + ) + self.axes: tuple[Index | None, ...] 
= tuple(axis for _, axis in iterable) + + def __repr__(self): + return f'StateAxes({dict(zip(self.filters, self.axes))})' + + def __eq__(self, other): + return ( + isinstance(other, StateAxes) + and self.filters == other.filters + and self.axes == other.axes + ) + + def __hash__(self): + return hash((self.filters, self.axes)) + + +@dataclasses.dataclass(frozen=True) +class VmapInputs: + f: tp.Callable[..., tp.Any] + transform_metadata: tp.Mapping[str, tp.Any] + in_axes: tp.Any + out_axes: tp.Any + + +def _index_to_state( + x: extract.Extractable, + *, + graphdef: GraphDef, + states: State, + axes: tuple[tp.Any, ...], +): + node_axis = axes[x.index] + node_state = states[x.index] if x.index in states else State({}) + assert isinstance(node_state, State) + if isinstance(node_axis, StateAxes): + substates = node_state.split(*node_axis.filters) + return VmapStates( + tuple( + VmapArgState(x.index, graphdef, substate, filter, axis) + for substate, filter, axis in zip( + substates, node_axis.filters, node_axis.axes + ) + ), + ) + else: + return VmapArgState( + x.index, graphdef, node_state, filterlib.Everything(), node_axis + ) + + +def vmap_fn( + args: tuple[tp.Any, ...], + vmap_inputs: VmapInputs, +): + f = vmap_inputs.f + transform_metadata = vmap_inputs.transform_metadata + extracted_states: tuple[VmapStates | VmapArgState, ...] = ( + extract.extract_indexes(args, types=(VmapStates, VmapArgState)) + ) + ctx = graph.current_update_context('vmap') + + # remove metadata axis name from Variable.sharding + def remove_axis_fn(arg_state): + if ( + isinstance(arg_state, VmapArgState) + and arg_state.axis is not None + and spmd.PARTITION_NAME in transform_metadata + ): + state = arg_state.state + state = spmd.remove_axis(state, arg_state.axis, transform_metadata) + return arg_state.replace(state=state) + return arg_state + + extracted_states = jax.tree.map( + remove_axis_fn, extracted_states, is_leaf=lambda x: isinstance(x, VmapArgState) + ) + + if extracted_states: + graphdef, states = extract.merge_extractable_states(extracted_states) + inputs_graph_nodes = ctx.merge(graphdef, states) + args = extract.insert_graph_nodes(args, inputs_graph_nodes) + else: + inputs_graph_nodes = () + + out = f(*args) + + (args_out, out), output_nodes, output_node_axis = extract.extract_graph_nodes( + (args, out), prefix=(vmap_inputs.in_axes, vmap_inputs.out_axes) + ) + extract.check_consistent_aliasing(output_nodes, output_node_axis) + + graphdef_out, states_out = ctx.split(output_nodes) + + # add metadata axis name to Variable.sharding + if spmd.PARTITION_NAME in transform_metadata: + for index in states_out: + assert isinstance(index, int) + if output_node_axis[index] is not None: + states_out[index] = spmd.add_axis( + states_out[index], output_node_axis[index], transform_metadata + ) + + replace_fn = functools.partial( + _index_to_state, + graphdef=graphdef_out, + states=states_out, + axes=output_node_axis, + ) + out = extract.replace_indexes(out, replace_fn) + args_out = extract.replace_indexes(args_out, replace_fn, clear=True) + + return args_out, out + + +@tp.overload +def vmap( + *, + in_axes: int | None | tp.Sequence[tp.Any] = 0, + out_axes: tp.Any = 0, + axis_name: AxisName | None = None, + axis_size: int | None = None, + spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None, + # nnx specific + transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), +) -> tp.Callable[[F], F]: ...
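+# Usage sketch (illustrative; mirrors the tests added in this PR):
+# ``StateAxes`` maps state filters to vmap axes, so different substates of a
+# single module can be vectorized or broadcast independently:
+#
+#   state_axes = StateAxes({(nnx.Param, nnx.RngState): 0, ...: None})
+#
+#   @vmap(in_axes=(state_axes, 0), out_axes=0)
+#   def forward_block(module, x):
+#     return module(x)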
+ + +@tp.overload +def vmap( + f: F, + *, + in_axes: int | None | tp.Sequence[tp.Any] = 0, + out_axes: tp.Any = 0, + axis_name: AxisName | None = None, + axis_size: int | None = None, + spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None, + # nnx specific + transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), +) -> F: ... + + +def vmap( + f: F | Missing = MISSING, + *, + in_axes: int | None | tp.Sequence[tp.Any] = 0, + out_axes: tp.Any = 0, + axis_name: AxisName | None = None, + axis_size: int | None = None, + spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None, + # nnx specific + transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), +) -> F | tp.Callable[[F], F]: + """Vectorizing map. Creates a function which maps ``f`` over argument axes. + + Args: + f: Function to be mapped over additional axes. + in_axes: An integer, None, or sequence of values specifying which input + array axes to map over. + + If each positional argument to ``f`` is an array, then ``in_axes`` can + be an integer, a None, or a tuple of integers and Nones with length equal + to the number of positional arguments to ``f``. An integer or ``None`` + indicates which array axis to map over for all arguments (with ``None`` + indicating not to map any axis), and a tuple indicates which axis to map + for each corresponding positional argument. Axis integers must be in the + range ``[-ndim, ndim)`` for each array, where ``ndim`` is the number of + dimensions (axes) of the corresponding input array. + + If the positional arguments to ``f`` are container (pytree) types, ``in_axes`` + must be a sequence with length equal to the number of positional arguments to + ``f``, and for each argument the corresponding element of ``in_axes`` can + be a container with a matching pytree structure specifying the mapping of its + container elements. In other words, ``in_axes`` must be a container tree prefix + of the positional argument tuple passed to ``f``. See this link for more detail: + https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees + + Either ``axis_size`` must be provided explicitly, or at least one + positional argument must have ``in_axes`` not None. The sizes of the + mapped input axes for all mapped positional arguments must all be equal. + + Arguments passed as keywords are always mapped over their leading axis + (i.e. axis index 0). + + See below for examples. + + out_axes: An integer, None, or (nested) standard Python container + (tuple/list/dict) thereof indicating where the mapped axis should appear + in the output. All outputs with a mapped axis must have a non-None + ``out_axes`` specification. Axis integers must be in the range ``[-ndim, + ndim)`` for each output array, where ``ndim`` is the number of dimensions + (axes) of the array returned by the :func:`vmap`-ed function, which is one + more than the number of dimensions (axes) of the corresponding array + returned by ``f``. + axis_name: Optional, a hashable Python object used to identify the mapped + axis so that parallel collectives can be applied. + axis_size: Optional, an integer indicating the size of the axis to be + mapped. If not provided, the mapped axis size is inferred from arguments. + + Returns: + Batched/vectorized version of ``f`` with arguments that correspond to + those of ``f``, but with extra array axes at positions indicated by + ``in_axes``, and a return value that corresponds to that of ``f``, but + with extra array axes at positions indicated by ``out_axes``. 
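+
+  Example (an illustrative sketch mirroring the tests added in this PR)::
+
+    from flax import nnx
+
+    rngs = nnx.Rngs(0)
+    backups = nnx.split_rngs(rngs, 5)
+
+    @vmap(in_axes=0, out_axes=0, axis_size=5)
+    def create_block(rngs: nnx.Rngs):
+      return nnx.Linear(2, 3, rngs=rngs)
+
+    block = create_block(rngs)  # block.kernel.value.shape == (5, 2, 3)
+    nnx.restore_rngs(backups)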
+ + + """ + if isinstance(f, Missing): + return functools.partial( + vmap, + in_axes=in_axes, + out_axes=out_axes, + axis_name=axis_name, + axis_size=axis_size, + spmd_axis_name=spmd_axis_name, + transform_metadata=transform_metadata, + ) + + jax_in_axes = jax.tree.map( + lambda x: VmapStates(x.axes) if isinstance(x, StateAxes) else x, + in_axes, + ) + jax_out_axes = jax.tree.map( + lambda x: VmapStates(x.axes) if isinstance(x, StateAxes) else x, + out_axes, + ) + + vmapped_fn = jax.vmap( + vmap_fn, + in_axes=( + jax_in_axes, # args + None, # vmap_inputs + ), + out_axes=( + jax_in_axes, # args_out + jax_out_axes, # out_axes + ), + axis_name=axis_name, + axis_size=axis_size, + spmd_axis_name=spmd_axis_name, + ) + + @functools.wraps(f) + @graph.update_context('vmap') + def vmap_wrapper(*args): + ctx = graph.current_update_context('vmap') + + args, input_graph_nodes, input_node_axis = extract.extract_graph_nodes( + args, prefix=in_axes + ) + extract.check_consistent_aliasing(input_graph_nodes, input_node_axis) + graphdef, states = ctx.split(input_graph_nodes) + args = extract.replace_indexes( + args, + functools.partial( + _index_to_state, graphdef=graphdef, states=states, axes=input_node_axis + ), + ) + + args_out, out = vmapped_fn( + args, + VmapInputs( + f, + transform_metadata, + in_axes, + out_axes, + ), + ) + + extracted_states_out = extract.extract_indexes( + (args_out, out), types=(VmapStates, VmapArgState) + ) + if extracted_states_out: + graphdef_out, states_out = extract.merge_extractable_states( + extracted_states_out + ) + output_nodes = ctx.merge(graphdef_out, states_out) + out = extract.insert_graph_nodes(out, output_nodes) + + return out + + return vmap_wrapper # type: ignore + diff --git a/flax/nnx/nnx/transforms/general.py b/flax/nnx/nnx/transforms/general.py new file mode 100644 index 0000000000..6ed70a8c4a --- /dev/null +++ b/flax/nnx/nnx/transforms/general.py @@ -0,0 +1,165 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import functools +import typing as tp + +from flax import struct +from flax.nnx.nnx import ( + extract, + graph, +) +from flax.nnx.nnx.module import GraphDef +from flax.nnx.nnx.state import State + +A = tp.TypeVar('A') +F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any]) + +class Missing: + pass + + +MISSING = Missing() + +# ------------------------------- +# (split|merge)_inputs +# ------------------------------- + + +class ArgState(extract.ExtractionIndex, extract.ExtractableStates): + _graphdef: GraphDef[tp.Any] = struct.field(pytree_node=False) + state: State = struct.field(pytree_node=True) + + @property + def graphdef(self) -> GraphDef[tp.Any]: + return self._graphdef + + @property + def states(self) -> tp.Iterable[State]: + yield self.state + +@tp.overload +def split_inputs( + *, ctx_tag: str = 'split_merge_inputs' +) -> tp.Callable[[F], F]: ... +@tp.overload +def split_inputs(f: F, *, ctx_tag: str = 'split_merge_inputs') -> F: ... +def split_inputs( + f: F | Missing = MISSING, *, ctx_tag: str = 'split_merge_inputs' +) -> F | tp.Callable[[F], F]: + if isinstance(f, Missing): + return functools.partial(split_inputs, ctx_tag=ctx_tag) + + @graph.update_context(ctx_tag) + @functools.wraps(f) + def split_inputs_wrapper(*args): + ctx = graph.current_update_context(ctx_tag) + args, input_graph_nodes = extract.extract_graph_nodes(args) + graphdef, states = ctx.split(input_graph_nodes) + + args = extract.replace_indexes( + args, + lambda x: ArgState( + x.index, + graphdef, + states[x.index], # type: ignore + ), + ) + + args_out, out = f(*args) + arg_states_out = extract.extract_indexes((args_out, out), types=ArgState) + + if arg_states_out: + graphdef_out, states_out = extract.merge_extractable_states( + arg_states_out + ) + output_nodes = ctx.merge(graphdef_out, states_out) + out = extract.insert_graph_nodes(out, output_nodes) + + return out + + return split_inputs_wrapper # type: ignore + +@tp.overload +def merge_inputs( + *, + ctx_tag: str = 'split_merge_inputs', +) -> tp.Callable[[F], F]: ... +@tp.overload +def merge_inputs( + f: F, + *, + ctx_tag: str = 'split_merge_inputs', +) -> F: ... 
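+# Usage sketch (condensed from the tests added in this PR): sandwich a raw
+# JAX transform between ``split_inputs`` and ``merge_inputs`` so that graph
+# nodes can cross the transform boundary as jax-compatible pytrees:
+#
+#   @nnx.split_inputs
+#   @jax.jit
+#   @nnx.merge_inputs
+#   def forward(model, x):
+#     return model(x)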
+def merge_inputs( + f: F | Missing = MISSING, + *, + ctx_tag: str = 'split_merge_inputs', +) -> F | tp.Callable[[F], F]: + if isinstance(f, Missing): + return functools.partial(merge_inputs, ctx_tag=ctx_tag) + + @functools.wraps(f) + def merge_inputs_wrapper(*args): + input_args = args + ctx = graph.current_update_context(ctx_tag) + arg_states = extract.extract_indexes(args, types=ArgState) + + if arg_states: + graphdef = arg_states[0].graphdef + states: graph.GraphState = State( + {x.index: x.state.raw_mapping for x in arg_states} + ) + inputs_graph_nodes = ctx.merge(graphdef, states) + args = extract.insert_graph_nodes(args, inputs_graph_nodes) + else: + inputs_graph_nodes = () + + out = f(*args) + + (_, out), output_graph_nodes = extract.extract_graph_nodes( + (inputs_graph_nodes, out) + ) + + graphdef_out, states_out = ctx.split(output_graph_nodes) + + def replace_index(x: extract.ExtractionIndex): + return ArgState( + x.index, + graphdef_out, + states_out[x.index], # type: ignore + ) + + out = extract.replace_indexes(out, replace_index) + input_args_out = extract.replace_indexes( + input_args, replace_index, clear=True + ) + + return input_args_out, out + + return merge_inputs_wrapper # type: ignore diff --git a/flax/nnx/nnx/transforms/looping.py b/flax/nnx/nnx/transforms/looping.py index 3c205bf1f1..657d070ba6 100644 --- a/flax/nnx/nnx/transforms/looping.py +++ b/flax/nnx/nnx/transforms/looping.py @@ -34,7 +34,7 @@ from flax import struct from flax.core.frozen_dict import FrozenDict -from flax.nnx.nnx import filterlib, graph, rnglib, spmd +from flax.nnx.nnx import extract, filterlib, graph, rnglib, spmd from flax.nnx.nnx.module import GraphDef, Module from flax.nnx.nnx.proxy_caller import DelayedAccessor from flax.nnx.nnx.state import State @@ -60,6 +60,12 @@ Leaves = tp.List[Leaf] Index = int +class Missing: + pass + + +MISSING = Missing() + # ------------------------------- # scan # ------------------------------- @@ -254,7 +260,7 @@ def scan_fn( input_graph_nodes = ctx.merge( graphdef, *scan_states, carry_state, split_rng_state, broadcast_rng_state ) - (args, kwargs) = graph.insert_graph_nodes((args, kwargs), input_graph_nodes) + (args, kwargs) = extract.insert_graph_nodes((args, kwargs), input_graph_nodes) out = f(*args, **kwargs) @@ -271,10 +277,9 @@ def scan_fn( carry_arg_out = out scan_args_out = None - ( - (carry_arg_out, scan_args_out), - output_graph_nodes, - ) = graph.extract_graph_nodes((carry_arg_out, scan_args_out)) + ((carry_arg_out, scan_args_out), output_graph_nodes) = ( + extract.extract_graph_nodes((carry_arg_out, scan_args_out)) + ) # split module state ( @@ -330,7 +335,25 @@ def _extract_carry_state(state: State, /): return carry_out, scan_out - +@tp.overload +def scan( + *, + length: int | None = None, + reverse: bool = False, + unroll: int | bool = 1, + _split_transpose: bool = False, + # extended api + in_axes: int | None | tp.Sequence[tp.Any] = 0, + in_axes_kwargs: tp.Any = 0, + out_axes: tp.Any = 0, + carry_argnum: int = 0, + # nnx specific + state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}), + split_rngs: filterlib.Filter = ..., + transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), + scan_output: bool = True, +) -> tp.Callable[[F], F]: ... +@tp.overload def scan( f: F, *, @@ -348,12 +371,35 @@ def scan( split_rngs: filterlib.Filter = ..., transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), scan_output: bool = True, -) -> F: +) -> F: ...
+def scan( + f: F | Missing = MISSING, + *, + length: int | None = None, + reverse: bool = False, + unroll: int | bool = 1, + _split_transpose: bool = False, + # extended api + in_axes: int | None | tp.Sequence[tp.Any] = 0, + in_axes_kwargs: tp.Any = 0, + out_axes: tp.Any = 0, + carry_argnum: int = 0, + # nnx specific + state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}), + split_rngs: filterlib.Filter = ..., + transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), + scan_output: bool = True, +) -> F | tp.Callable[[F], F]: + if isinstance(f, Missing): + return functools.partial( + scan, + length=length, + reverse=reverse, + unroll=unroll, + _split_transpose=_split_transpose, + in_axes=in_axes, + in_axes_kwargs=in_axes_kwargs, + out_axes=out_axes, + carry_argnum=carry_argnum, + state_axes=state_axes, + split_rngs=split_rngs, + transform_metadata=transform_metadata, + scan_output=scan_output, + ) + @functools.wraps(f) @graph.update_context('scan') def scan_apply_wrapper(*args, **kwargs): # extract nodes - (args, kwargs), input_graph_nodes = graph.extract_graph_nodes( + (args, kwargs), input_graph_nodes = extract.extract_graph_nodes( (args, kwargs) ) input_rng_streams = rnglib.backup_keys(input_graph_nodes) @@ -465,11 +511,11 @@ def scan_apply_wrapper(*args, **kwargs): broadcast_rng_state_out, ) - carry_arg_out, scan_args_out = graph.insert_graph_nodes( + carry_arg_out, scan_args_out = extract.insert_graph_nodes( (carry_arg_out, scan_args_out), output_graph_nodes ) - rnglib.restore_keys(input_rng_streams) + rnglib.restore_rngs(input_rng_streams) if scan_output: scan_args_out = tp.cast(B, scan_args_out) diff --git a/flax/nnx/nnx/transforms/parallelization.py b/flax/nnx/nnx/transforms/parallelization.py index 3e0fe74ba8..099fa7d16c 100644 --- a/flax/nnx/nnx/transforms/parallelization.py +++ b/flax/nnx/nnx/transforms/parallelization.py @@ -40,6 +40,7 @@ from flax import struct from flax.core.frozen_dict import FrozenDict from flax.nnx.nnx import ( + extract, filterlib, graph, rnglib, @@ -68,6 +69,12 @@ AxesValue = tp.Union[int, None] SplitPattern = tp.Union[AxesValue, tuple[AxesValue, ...]] +class Missing: + pass + + +MISSING = Missing() + # ------------------------------- # vmap # ------------------------------- @@ -194,11 +201,11 @@ def vmap_fn( broadcast_counts, ) - (args, kwargs) = graph.insert_graph_nodes((args, kwargs), input_graph_nodes) + (args, kwargs) = extract.insert_graph_nodes((args, kwargs), input_graph_nodes) out = f(*args, **kwargs) - out, output_graph_nodes = graph.extract_graph_nodes(out) + out, output_graph_nodes = extract.extract_graph_nodes(out) # split module state ( @@ -231,7 +238,21 @@ def vmap_fn( out, ) - +@tp.overload +def vmap( + *, + in_axes: int | None | tp.Sequence[tp.Any] = 0, + out_axes: tp.Any = 0, + axis_name: AxisName | None = None, + axis_size: int | None = None, + spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None, + # nnx specific + in_axes_kwargs: tp.Any = 0, + state_axes: tp.Mapping[filterlib.Filter, int | None] = FrozenDict({...: 0}), + split_rngs: filterlib.Filter = ..., + transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), +) -> tp.Callable[[F], F]: ... +@tp.overload def vmap( f: F, *, @@ -245,9 +266,36 @@ def vmap( state_axes: tp.Mapping[filterlib.Filter, int | None] = FrozenDict({...: 0}), split_rngs: filterlib.Filter = ..., transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), -) -> F: - vectorized_states_axes = list(state_axes.values()) +) -> F: ... +def vmap( + f: F | Missing = MISSING, + *, + in_axes: int | None | tp.Sequence[tp.Any] = 0, + out_axes: tp.Any = 0, + axis_name: AxisName | None = None, + axis_size: int | None = None, + spmd_axis_name: AxisName | tuple[AxisName, ...]
| None = None, + # nnx specific + in_axes_kwargs: tp.Any = 0, + state_axes: tp.Mapping[filterlib.Filter, int | None] = FrozenDict({...: 0}), + split_rngs: filterlib.Filter = ..., + transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), +) -> F | tp.Callable[[F], F]: + if isinstance(f, Missing): + return functools.partial( + vmap, + in_axes=in_axes, + out_axes=out_axes, + axis_name=axis_name, + axis_size=axis_size, + spmd_axis_name=spmd_axis_name, + in_axes_kwargs=in_axes_kwargs, + state_axes=state_axes, + split_rngs=split_rngs, + transform_metadata=transform_metadata, + ) + vectorized_states_axes = list(state_axes.values()) vmapped_fn = jax.vmap( vmap_fn, in_axes=( @@ -283,7 +331,7 @@ def vmap( def vmap_wrapper(*args, **kwargs): ctx = graph.current_update_context('vmap') - (args, kwargs), input_graph_nodes = graph.extract_graph_nodes( + (args, kwargs), input_graph_nodes = extract.extract_graph_nodes( (args, kwargs) ) input_rng_streams = _backup_vmap_keys(input_graph_nodes) @@ -359,7 +407,7 @@ def vmap_wrapper(*args, **kwargs): split_keys_out, ) - out = graph.insert_graph_nodes(out, output_graph_nodes) + out = extract.insert_graph_nodes(out, output_graph_nodes) _restore_vmap_keys(input_rng_streams, split_rngs) @@ -519,11 +567,11 @@ def pmap_fn( broadcast_counts, ) - (args, kwargs) = graph.insert_graph_nodes((args, kwargs), input_graph_nodes) + (args, kwargs) = extract.insert_graph_nodes((args, kwargs), input_graph_nodes) out = f(*args, **kwargs) - out, output_graph_nodes = graph.extract_graph_nodes(out) + out, output_graph_nodes = extract.extract_graph_nodes(out) # split module state ( @@ -560,11 +608,47 @@ def pmap_fn( out, ) - +@tp.overload +def pmap( + *, + axis_name: AxisName | None = None, + in_axes: tp.Any = 0, + out_axes: tp.Any = 0, + static_broadcasted_argnums: int | tp.Iterable[int] = (), + devices: tp.Sequence[jax.Device] | None = None, # noqa: F811 + backend: str | None = None, + axis_size: int | None = None, + donate_argnums: int | tp.Iterable[int] = (), + global_arg_shapes: tuple[tuple[int, ...], ...] | None = None, + # nnx specific + in_axes_kwargs: tp.Any = 0, + state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}), + split_rngs: filterlib.Filter = ..., + transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), +) -> tp.Callable[[F], F]: ... +@tp.overload def pmap( f: F, + *, axis_name: AxisName | None = None, + in_axes: tp.Any = 0, + out_axes: tp.Any = 0, + static_broadcasted_argnums: int | tp.Iterable[int] = (), + devices: tp.Sequence[jax.Device] | None = None, # noqa: F811 + backend: str | None = None, + axis_size: int | None = None, + donate_argnums: int | tp.Iterable[int] = (), + global_arg_shapes: tuple[tuple[int, ...], ...] | None = None, + # nnx specific + in_axes_kwargs: tp.Any = 0, + state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}), + split_rngs: filterlib.Filter = ..., + transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), +) -> F: ... 
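+# Note: with the overloads above, ``pmap`` (like ``vmap``) also works as a
+# decorator factory; this is what lets the ``functools.partial(pmap, ...)``
+# call sites later in this file collapse to, e.g.:
+#
+#   @pmap(axis_name=axis_name, in_axes=in_axes, out_axes=out_axes)
+#   def pmap_init(*args, **kwargs):
+#     ...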
+def pmap( + f: F | Missing = MISSING, *, + axis_name: AxisName | None = None, in_axes: tp.Any = 0, out_axes: tp.Any = 0, static_broadcasted_argnums: int | tp.Iterable[int] = (), @@ -578,7 +662,24 @@ def pmap( state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}), split_rngs: filterlib.Filter = ..., transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), -) -> F: +) -> F | tp.Callable[[F], F]: + if isinstance(f, Missing): + return functools.partial( + pmap, + axis_name=axis_name, + in_axes=in_axes, + out_axes=out_axes, + static_broadcasted_argnums=static_broadcasted_argnums, + devices=devices, + backend=backend, + axis_size=axis_size, + donate_argnums=donate_argnums, + global_arg_shapes=global_arg_shapes, + in_axes_kwargs=in_axes_kwargs, + state_axes=state_axes, + split_rngs=split_rngs, + transform_metadata=transform_metadata, + ) if static_broadcasted_argnums: raise NotImplementedError( 'static_broadcasted_argnums is not yet supported in nnx.pmap' @@ -625,7 +726,7 @@ def pmap( def pmap_wrapper(*args, **kwargs): ctx = graph.current_update_context('pmap') - (args, kwargs), input_graph_nodes = graph.extract_graph_nodes( + (args, kwargs), input_graph_nodes = extract.extract_graph_nodes( (args, kwargs) ) input_rng_streams = rnglib.backup_keys(input_graph_nodes) @@ -701,9 +802,9 @@ def pmap_wrapper(*args, **kwargs): split_keys_out, ) - out = graph.insert_graph_nodes(out, output_graph_nodes) + out = extract.insert_graph_nodes(out, output_graph_nodes) - rnglib.restore_keys(input_rng_streams) + rnglib.restore_rngs(input_rng_streams) return out @@ -776,8 +877,7 @@ def __init__( ): self.module_constructor = module_constructor - @functools.partial( - pmap, + @pmap( axis_name=axis_name, in_axes=None, out_axes=None, @@ -797,8 +897,7 @@ def pmap_init(*args, **kwargs): self.pmap_module = pmap_init(*module_init_args, **module_init_kwargs) - @functools.partial( - pmap, + @pmap( axis_name=axis_name, in_axes=in_axes, out_axes=out_axes, diff --git a/flax/nnx/nnx/transforms/transforms.py b/flax/nnx/nnx/transforms/transforms.py index a9471b48b2..0317abbd81 100644 --- a/flax/nnx/nnx/transforms/transforms.py +++ b/flax/nnx/nnx/transforms/transforms.py @@ -34,6 +34,7 @@ import typing as tp from flax.nnx.nnx import ( + extract, filterlib, graph, spmd, @@ -45,6 +46,7 @@ DelayedAccessor, ) from flax.nnx.nnx.state import State +from flax.nnx.nnx.transforms import general from flax.typing import Leaf import jax import jax.core @@ -63,6 +65,12 @@ Leaves = tp.List[Leaf] Index = int +class Missing: + pass + + +MISSING = Missing() + def _normalize_sequence( x: StrInt | tp.Iterable[StrInt] | None, / @@ -153,11 +161,11 @@ def jit_fn( input_graph_nodes = ctx.merge(graphdef, state) - (args, kwargs) = graph.insert_graph_nodes((args, kwargs), input_graph_nodes) + (args, kwargs) = extract.insert_graph_nodes((args, kwargs), input_graph_nodes) out = f(*args, **kwargs) - out, output_graph_nodes = graph.extract_graph_nodes(out) + out, output_graph_nodes = extract.extract_graph_nodes(out) graphdef, state = ctx.split((input_graph_nodes, output_graph_nodes)) @@ -167,6 +175,25 @@ def jit_fn( return out, state, graphdef +@tp.overload +def jit( + *, + in_shardings: tp.Any = UNSPECIFIED, + out_shardings: tp.Any = UNSPECIFIED, + static_argnums: int | tp.Sequence[int] | None = None, + static_argnames: str | tp.Iterable[str] | None = None, + donate_argnums: int | tp.Sequence[int] | None = None, + donate_argnames: str | tp.Iterable[str] | None = None, + keep_unused: bool = False, + device: tp.Optional[jax.Device] = None, + 
backend: tp.Optional[str] = None, + inline: bool = False, + abstracted_axes: tp.Optional[tp.Any] = None, + # nnx specific + donate_state: bool = False, + constrain_state: bool | tp.Callable[[State], State] = False, +) -> tp.Callable[[F], F]: ... +@tp.overload def jit( fun: F, *, @@ -184,7 +211,25 @@ def jit( # nnx specific donate_state: bool = False, constrain_state: bool | tp.Callable[[State], State] = False, -) -> F: +) -> F: ... +def jit( + fun: F | Missing = MISSING, + *, + in_shardings: tp.Any = UNSPECIFIED, + out_shardings: tp.Any = UNSPECIFIED, + static_argnums: int | tp.Sequence[int] | None = None, + static_argnames: str | tp.Iterable[str] | None = None, + donate_argnums: int | tp.Sequence[int] | None = None, + donate_argnames: str | tp.Iterable[str] | None = None, + keep_unused: bool = False, + device: tp.Optional[jax.Device] = None, + backend: tp.Optional[str] = None, + inline: bool = False, + abstracted_axes: tp.Optional[tp.Any] = None, + # nnx specific + donate_state: bool = False, + constrain_state: bool | tp.Callable[[State], State] = False, +) -> F | tp.Callable[[F], F]: """ Lifted version of ``jax.jit`` that can handle Modules / graph nodes as arguments. @@ -313,6 +358,23 @@ def jit( A wrapped version of ``fun``, set up for just-in-time compilation. """ + if isinstance(fun, Missing): + return functools.partial( + jit, + in_shardings=in_shardings, + out_shardings=out_shardings, + static_argnums=static_argnums, + static_argnames=static_argnames, + donate_argnums=donate_argnums, + donate_argnames=donate_argnames, + keep_unused=keep_unused, + device=device, + backend=backend, + inline=inline, + abstracted_axes=abstracted_axes, + donate_state=donate_state, + constrain_state=constrain_state, + ) _static_argnums = _normalize_sequence(static_argnums) _static_argnames = _normalize_sequence(static_argnames) _donate_argnums = _normalize_sequence(donate_argnums) @@ -352,7 +414,7 @@ def jit( @graph.update_context('jit') def jit_wrapper(*args, **kwargs): ctx = graph.current_update_context('jit') - (args, kwargs), input_graph_nodes = graph.extract_graph_nodes( + (args, kwargs), input_graph_nodes = extract.extract_graph_nodes( (args, kwargs) ) graphdef, state = ctx.split(input_graph_nodes) @@ -365,7 +427,7 @@ def jit_wrapper(*args, **kwargs): input_graph_nodes, output_graph_nodes = ctx.merge( output_graphdef, output_state ) - out = graph.insert_graph_nodes(out, output_graph_nodes) + out = extract.insert_graph_nodes(out, output_graph_nodes) return out jit_wrapper.inner = jitted_fn # type: ignore @@ -507,11 +569,11 @@ def grad_fn(*args): args[i] = arg # add other nodes to the args - args = graph.insert_graph_nodes(args, input_nodes) + args = extract.insert_graph_nodes(args, input_nodes) out = f(*args) - out, out_nodes = graph.extract_graph_nodes(out) + out, out_nodes = extract.extract_graph_nodes(out) graphdef_out, state_out = ctx.split((input_nodes, out_nodes)) @@ -543,7 +605,7 @@ def grad_wrapper(*args): for i, arg in enumerate(args) if i in _argnums and graph.is_node(arg) } - args, input_nodes = graph.extract_graph_nodes(args) + args, input_nodes = extract.extract_graph_nodes(args) args = list(args) def only_diff(path: tuple, value: tp.Any) -> bool: @@ -590,7 +652,7 @@ def only_diff(path: tuple, value: tp.Any) -> bool: input_nodes, out_nodes = ctx.merge(graphdef_out, state_out) - out = graph.insert_graph_nodes(out, out_nodes) + out = extract.insert_graph_nodes(out, out_nodes) return out return grad_wrapper @@ -859,15 +921,15 @@ def remat_apply( args: tuple[tp.Any, ...], ): ctx = 
graph.current_update_context('remat') - args, input_nodes = graph.extract_graph_nodes(args) + args, input_nodes = extract.extract_graph_nodes(args) graphdef, state = ctx.split(input_nodes) def _remat_fn(state: State, *args): input_nodes = ctx.merge(graphdef, state) - args = graph.insert_graph_nodes(args, input_nodes) + args = extract.insert_graph_nodes(args, input_nodes) out = f(*args) - out, output_nodes = graph.extract_graph_nodes(out) + out, output_nodes = extract.extract_graph_nodes(out) new_graphdef, new_state = ctx.split((input_nodes, output_nodes)) return (new_graphdef, new_state), out @@ -879,18 +941,40 @@ def _remat_fn(state: State, *args): )(state, *args) _, output_nodes = ctx.merge(new_graphdef, new_state) - out = graph.insert_graph_nodes(out, output_nodes) + out = extract.insert_graph_nodes(out, output_nodes) return out - +@tp.overload +def remat( + *, + prevent_cse: bool = True, + static_argnums: int | tuple[int, ...] = (), + policy: tp.Callable[..., bool] | None = None, +) -> tp.Callable[[F], F]: ... +@tp.overload def remat( f: F, *, prevent_cse: bool = True, static_argnums: int | tuple[int, ...] = (), policy: tp.Callable[..., bool] | None = None, -) -> F: +) -> F: ... +def remat( + f: F | Missing = MISSING, + *, + prevent_cse: bool = True, + static_argnums: int | tuple[int, ...] = (), + policy: tp.Callable[..., bool] | None = None, +) -> F | tp.Callable[[F], F]: + if isinstance(f, Missing): + return functools.partial( + remat, + prevent_cse=prevent_cse, + static_argnums=static_argnums, + policy=policy, + ) + options = RematOptions( prevent_cse=prevent_cse, static_argnums=static_argnums, @@ -914,15 +998,15 @@ def eval_shape( *args: tp.Any, **kwargs: tp.Any, ) -> A: - (args, kwargs), input_nodes = graph.extract_graph_nodes((args, kwargs)) + (args, kwargs), input_nodes = extract.extract_graph_nodes((args, kwargs)) graphdef, state = graph.split(input_nodes) @functools.wraps(f) def _eval_shape_fn(state: State, *args, **kwargs): input_nodes = graph.merge(graphdef, state) - args, kwargs = graph.insert_graph_nodes((args, kwargs), input_nodes) + args, kwargs = extract.insert_graph_nodes((args, kwargs), input_nodes) out = f(*args, **kwargs) - out, output_nodes = graph.extract_graph_nodes(out) + out, output_nodes = extract.extract_graph_nodes(out) graphdef_out, state_out = graph.split(output_nodes) return graphdef_out, state_out, out @@ -931,7 +1015,7 @@ def _eval_shape_fn(state: State, *args, **kwargs): ) output_nodes = graph.merge(graphdef_out, state_out) - out = graph.insert_graph_nodes(out, output_nodes) + out = extract.insert_graph_nodes(out, output_nodes) return out @@ -939,47 +1023,7 @@ def _eval_shape_fn(state: State, *args, **kwargs): # cond # ------------------------------- - -@dataclasses.dataclass(frozen=True) -class CondStaticInputs(tp.Generic[A]): - true_fun: tp.Callable[..., A] - false_fun: tp.Callable[..., A] - - -jax.tree_util.register_static(CondStaticInputs) - - -def _cond_fun( - is_true: bool, - static_inputs: CondStaticInputs[A], - graphdef: GraphDef[tuple[tp.Any, ...]], - state: State, -): - ctx = graph.current_update_context('cond') - fn = static_inputs.true_fun if is_true else static_inputs.false_fun - operands = ctx.merge(graphdef, state) - out = fn(*operands) - graphdef_out, state_out = ctx.split((operands, out)) - return graphdef_out, state_out - - -def _cond_true_fun( - static_inputs: CondStaticInputs[A], - graphdef: GraphDef[tuple[tp.Any, ...]], - state: State, -): - return _cond_fun(True, static_inputs, graphdef, state) - - -def _cond_false_fun( - 
static_inputs: CondStaticInputs[A], - graphdef: GraphDef[tuple[tp.Any, ...]], - state: State, -): - return _cond_fun(False, static_inputs, graphdef, state) - - -@graph.update_context('cond') +@general.split_inputs(ctx_tag='cond') def cond( pred, true_fun: tp.Callable[..., A], @@ -987,16 +1031,10 @@ def cond( *operands, **kwargs, ) -> A: - ctx: graph.UpdateContext = graph.current_update_context('cond') - graphdef, state = ctx.split(operands) - graphdef_out, state_out = jax.lax.cond( + return jax.lax.cond( pred, - _cond_true_fun, - _cond_false_fun, - CondStaticInputs(true_fun=true_fun, false_fun=false_fun), - graphdef, - state, + general.merge_inputs(true_fun, ctx_tag='cond'), + general.merge_inputs(false_fun, ctx_tag='cond'), + *operands, **kwargs, ) - _operands_out, out = ctx.merge(graphdef_out, state_out) - return out diff --git a/flax/nnx/tests/experimental_test.py b/flax/nnx/tests/experimental_test.py new file mode 100644 index 0000000000..c2d8e4527f --- /dev/null +++ b/flax/nnx/tests/experimental_test.py @@ -0,0 +1,293 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial +from absl.testing import absltest +import jax +import jax.numpy as jnp +import numpy as np + +from flax import nnx + + +class TestExperimentalVmap(absltest.TestCase): + def test_basic(self): + @partial(nnx.experimental_vmap, in_axes=0, out_axes=0, axis_size=5) + def create_block(rngs: nnx.Rngs): + return nnx.Linear(2, 3, rngs=rngs) + + rngs = nnx.Rngs(0) + backups = nnx.split_rngs(rngs, 5) + + block = create_block(rngs) + nnx.restore_rngs(backups) + + self.assertEqual(block.kernel.value.shape, (5, 2, 3)) + self.assertEqual(rngs.default.count.value, 1) + + @partial(nnx.experimental_vmap, in_axes=(0, 1), out_axes=1) + def forward(block: nnx.Linear, x): + self.assertEqual(block.kernel.value.shape, (2, 3)) + self.assertEqual(block.bias.value.shape, (3,)) + self.assertEqual(x.shape, (2,)) + return block(x) + + x = jax.random.uniform(rngs(), (2, 5)) + y = forward(block, x) + + self.assertEqual(y.shape, (3, 5)) + + def test_state_axes(self): + class Block(nnx.Module): + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(3, 3, rngs=rngs) + self.dropout = nnx.Dropout(0.5, deterministic=False, rngs=rngs) + + def __call__(self, x: jax.Array) -> jax.Array: + x = self.linear(x) + x = nnx.relu(x) + x = self.dropout(x) + return x + + @nnx.experimental_vmap( + in_axes=0, + out_axes=nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}), + ) + def create_block(rngs: nnx.Rngs): + rngs = nnx.clone(rngs) + return Block(rngs) + + rngs = nnx.Rngs(0) + initial_key = rngs.default.key.value + + backups = nnx.split_rngs(rngs, 5) + module = create_block(rngs) + nnx.restore_rngs(backups) + + assert rngs.default.count.value == 1 + assert rngs.default.key.value == initial_key + assert not jnp.allclose( + module.linear.kernel.value[0], + module.linear.kernel.value[1], + ) + assert module.linear.kernel.value.shape == (5, 3, 3) + assert module.linear.bias.value.shape == (5, 
3) + + x = jnp.ones((5, 1, 3)) + + @nnx.experimental_vmap( + in_axes=(nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}), 0), + ) + def forward_block(module, x): + return module(x) + + backups = nnx.split_rngs(rngs, 5) + y = forward_block(module, x) + nnx.restore_rngs(backups) + + assert y.shape == (5, 1, 3) + assert rngs.default.count.value == 2 + assert rngs.default.key.value == initial_key + + y2 = forward_block(module, x) + + assert not jnp.allclose(y, y2) + + def test_state_axes_simple(self): + class Block(nnx.Module): + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(2, 3, rngs=rngs) + self.bn = nnx.BatchNorm(3, rngs=rngs) + self.dropout = nnx.Dropout(0.1, deterministic=False, rngs=rngs) + + def __call__(self, x: jax.Array) -> jax.Array: + return nnx.relu(self.dropout(self.bn(self.linear(x)))) + + state_axes = nnx.StateAxes({(nnx.BatchStat, 'dropout'): 0, ...: None}) + + @nnx.experimental_vmap(in_axes=(state_axes,), out_axes=state_axes) + def create_block(rngs: nnx.Rngs): + return Block(rngs) + + rngs = nnx.Rngs(params=0, dropout=1) + nnx.split_rngs(rngs, 5, 'dropout') + + module = create_block(rngs) + + assert module.linear.kernel.value.shape == (2, 3) + assert module.bn.scale.value.shape == (3,) + assert module.bn.mean.value.shape == (5, 3) + + @nnx.experimental_vmap(in_axes=(state_axes, 0), out_axes=0) + def forward_block(module, x): + return module(x) + + x = jnp.ones((5, 1, 2)) + y = forward_block(module, x) + + assert y.shape == (5, 1, 3) + + def test_state_axes_super_simple(self): + class Block(nnx.Module): + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(2, 3, rngs=rngs) + self.bn = nnx.BatchNorm(3, rngs=rngs) + self.dropout = nnx.Dropout(0.1, deterministic=False, rngs=rngs) + + def __call__(self, x: jax.Array) -> jax.Array: + return nnx.relu(self.dropout(self.bn(self.linear(x)))) + + @nnx.experimental_vmap(in_axes=0, out_axes=0) + def create_block(rngs: nnx.Rngs): + return Block(rngs) + + rngs = nnx.Rngs(0) + nnx.split_rngs(rngs, 5) + + module = create_block(rngs) + + assert module.linear.kernel.value.shape == (5, 2, 3) + assert module.bn.scale.value.shape == (5, 3) + assert module.bn.mean.value.shape == (5, 3) + + @nnx.experimental_vmap(in_axes=(0, 0), out_axes=0) + def forward_block(module, x): + return module(x) + + x = jnp.ones((5, 1, 2)) + y = forward_block(module, x) + + assert y.shape == (5, 1, 3) + + def test_consistent_aliasing_inputs(self): + class Foo(nnx.Module): + def __init__(self): + self.a = jnp.zeros((5, 5)) + + m = Foo() + + @nnx.experimental_vmap(in_axes=(0, 1)) + def f(m1, m2): + pass + + with self.assertRaisesRegex(ValueError, 'Inconsistent aliasing detected'): + f(m, m) + + def test_consistent_aliasing_input_output(self): + class Foo(nnx.Module): + def __init__(self): + self.a = jnp.zeros((2, 3)) + + m = Foo() + + @partial(nnx.experimental_vmap, in_axes=0, out_axes=1) + def f(m): + return m + + with self.assertRaisesRegex(ValueError, 'Inconsistent aliasing detected'): + m2 = f(m) + + def test_consistent_aliasing_shared(self): + class Shared(nnx.Module): + def __init__(self): + self.a = jnp.zeros((3, 3)) + + class Foo(nnx.Module): + def __init__(self, shared: Shared): + self.a = shared + + shared = Shared() + m1 = Foo(shared) + m2 = Foo(shared) + + @partial(nnx.experimental_vmap, in_axes=(0, 1)) + def f(m1, m2): + pass + + with self.assertRaisesRegex( + ValueError, + r'Inconsistent aliasing detected([\s\S]*)Shared([\s\S]*)a: 0([\s\S]*)a: 1', + ): + f(m1, m2) + + def test_vmap_and_cond_passthrough(self): + class 
Broadcast(nnx.Variable[nnx.A]): ... + + class Vectorized(nnx.Variable[nnx.A]): ... + + class Env(nnx.Module): + def __init__(self): + self.broadcast = Broadcast(jnp.array(1)) + self.index = Vectorized(jnp.arange(8)) + self.step = Vectorized(jnp.zeros((8,), jnp.uint32)) + + env = Env() + + @nnx.experimental_vmap( + in_axes=(nnx.StateAxes({Broadcast: None, Vectorized: 0}),) + ) + def f(env: Env): + self.assertEqual(env.step.shape, ()) + + def increment(env: Env): + env.step += 1 + + def no_nothing(env: Env): + pass + + is_even = env.index % 2 == 0 + nnx.cond(is_even, increment, no_nothing, env) + + f(env) + + np.testing.assert_array_equal(env.step.value, [1, 0, 1, 0, 1, 0, 1, 0]) + + def test_vmap_and_cond_passthrough_error(self): + class Broadcast(nnx.Variable[nnx.A]): ... + + class Vectorized(nnx.Variable[nnx.A]): ... + + class Env(nnx.Module): + def __init__(self): + self.broadcast = Broadcast(jnp.array(1)) + self.index = Vectorized(jnp.arange(8)) + self.step = Vectorized(jnp.zeros((8,), jnp.uint32)) + + env = Env() + + @nnx.experimental_vmap( + in_axes=(nnx.StateAxes({Broadcast: None, Vectorized: 0}),) + ) + def f(env: Env): + self.assertEqual(env.step.shape, ()) + + def increment(env: Env): + env.step += 1 + env.broadcast += 1 + + def no_nothing(env: Env): + pass + + is_even = env.index % 2 == 0 + nnx.cond(is_even, increment, no_nothing, env) + + with self.assertRaisesRegex( + ValueError, + r"at vmap.*'broadcast'.*got axis spec None but output was batched on axis 0", + ): + f(env) + + +if __name__ == '__main__': + absltest.main() \ No newline at end of file diff --git a/flax/nnx/tests/transforms_test.py b/flax/nnx/tests/transforms_test.py index 31903d246f..4537752cb7 100644 --- a/flax/nnx/tests/transforms_test.py +++ b/flax/nnx/tests/transforms_test.py @@ -1424,6 +1424,122 @@ def reward_0(self: Foo): assert foo.timestep.step == 4 assert foo.timestep.reward == 0.0 + def test_cond_and_vmap(self): + class Env(nnx.Module): + def __init__(self): + self.index = jnp.arange(8) + self.step = jnp.zeros((8,), jnp.uint32) + + env = Env() + model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) + + @nnx.experimental_vmap(in_axes=(0, None), out_axes=None) + def f(env: Env, model: nnx.Linear): + self.assertEqual(env.index.shape, ()) + + def increment(env: Env): + env.step += 1 + + def no_nothing(env: Env): + pass + + is_even = env.index % 2 == 0 + nnx.cond(is_even, increment, no_nothing, env) + + f(env, model) + + np.testing.assert_array_equal(env.step, [1, 0, 1, 0, 1, 0, 1, 0]) + + +class TestSplitMergeInputs(absltest.TestCase): + def test_split_inputs(self): + class StatefulLinear(nnx.Linear): + def __init__(self, din: int, dout: int, rngs: nnx.Rngs): + super().__init__(din, dout, rngs=rngs) + self.counter = jnp.array(0, jnp.uint32) + + def __call__(self, x): + self.counter += 1 + return super().__call__(x) + + model = StatefulLinear(3, 4, rngs=nnx.Rngs(0)) + + @nnx.split_inputs + @jax.jit + @nnx.merge_inputs + def forward(model, x): + return model(x) + + x = jnp.ones((2, 3)) + y = forward(model, x) + + self.assertEqual(model.counter, 1) + + def test_split_inputs_cond(self): + class Counter(nnx.Object): + def __init__(self): + self.count = jnp.array(0, jnp.uint32) + + def increment(self): + self.count += 1 + + counter = Counter() + + @nnx.merge_inputs + def increment(counter: Counter): + counter.increment() + + @nnx.merge_inputs + def no_nothing(counter: Counter): + pass + + nnx.split_inputs(jax.lax.cond)(True, increment, no_nothing, counter) + + self.assertEqual(counter.count, 1) + + 
nnx.split_inputs(jax.lax.cond)(False, increment, no_nothing, counter) + + self.assertEqual(counter.count, 1) + + def test_split_inputs_vmap(self): + class EnvState(nnx.Variable[nnx.A]): + pass + + class Env(nnx.Object): + def __init__(self): + self.index = EnvState(jnp.arange(8)) + self.step = EnvState(jnp.zeros((8,), jnp.uint32)) + + env = Env() + model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) + + # internally merge_inputs returns (args, out) + in_axes = (0, None) + out_axes = (in_axes, None) + + @nnx.split_inputs + @partial(jax.vmap, in_axes=in_axes, out_axes=out_axes) + @nnx.merge_inputs + def f(env: Env, model: nnx.Linear): + self.assertEqual(env.index.value.shape, ()) + + @nnx.merge_inputs + def increment(env: Env): + env.step.value += 1 + + @nnx.merge_inputs + def no_nothing(env: Env): + pass + + is_even = env.index.value % 2 == 0 + nnx.split_inputs(jax.lax.cond)(is_even, increment, no_nothing, env) + + f(env, model) + + np.testing.assert_array_equal( + env.step.value, np.array([1, 0, 1, 0, 1, 0, 1, 0], np.uint32) + ) + if __name__ == '__main__': absltest.main()
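The rewritten `nnx.cond` above is just this split/merge pattern applied to `jax.lax.cond`; a condensed, illustrative sketch (adapted from the tests added in this PR):

    import jax
    import jax.numpy as jnp
    from flax import nnx

    class Counter(nnx.Object):
      def __init__(self):
        self.count = jnp.array(0, jnp.uint32)

    counter = Counter()

    @nnx.merge_inputs
    def increment(counter: Counter):
      counter.count += 1

    @nnx.merge_inputs
    def do_nothing(counter: Counter):
      pass

    # graph nodes cross the cond boundary as (graphdef, state) pytrees
    nnx.split_inputs(jax.lax.cond)(True, increment, do_nothing, counter)
    assert counter.count == 1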