diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py
index 1a49c6655f..5468a5a987 100644
--- a/flax/nnx/graph.py
+++ b/flax/nnx/graph.py
@@ -594,7 +594,7 @@ def _get_children():
           raise ValueError(
             f'Expected a Variable type for {key!r}, but got {type(variable)}.'
           )
-        variable.copy_from_state(value)
+        variable.update_from_state(value)
       else:  # if it doesn't, create a new variable
         assert isinstance(value, VariableState)
         variable = value.to_variable()
@@ -729,7 +729,7 @@ def _graph_update_dynamic(node: tp.Any, state: tp.Mapping[Key, tp.Any]):
           f'Trying to update a non-Variable attribute {key!r} with a Variable: '
           f'{value!r}'
         )
-      current_value.copy_from_state(value)
+      current_value.update_from_state(value)
     elif is_state_leaf(value):
       # case 4: state field is being updated
       if isinstance(node_impl, PytreeNodeImpl):
diff --git a/flax/nnx/object.py b/flax/nnx/object.py
index addee066d6..9e14155108 100644
--- a/flax/nnx/object.py
+++ b/flax/nnx/object.py
@@ -204,7 +204,7 @@ def _graph_node_set_key(self, key: Key, value: tp.Any):
     and isinstance(variable := getattr(self, key), Variable)
     and isinstance(value, VariableState)
   ):
-    variable.copy_from_state(value)
+    variable.update_from_state(value)
   else:
     setattr(self, key, value)
 
diff --git a/flax/nnx/spmd.py b/flax/nnx/spmd.py
index 909f5d3b07..e18003276b 100644
--- a/flax/nnx/spmd.py
+++ b/flax/nnx/spmd.py
@@ -17,7 +17,7 @@
 
 import jax
 from jax.interpreters import pxla
-from jax.sharding import Mesh, PartitionSpec
+from jax.sharding import PartitionSpec
 
 from flax.nnx import variables
 from flax.typing import (
@@ -159,57 +159,14 @@ def with_sharding_constraint(
   )
 
 
-# -------------------------------------
-# Partitioning Axis Metadata
-# -------------------------------------
-
-
-@tp.runtime_checkable
-class Partitioned(tp.Protocol):
-  get_value_hooks: tp.Callable[[variables.Variable[tp.Any]], tp.Any]
-  sharding: Sharding
-  mesh: tp.Optional[Mesh]
-
-
-def sharding_hook(
-  node: variables.Variable[tp.Any],
-  value: tp.Any,
-  /,
-) -> tp.Any:
-  if _global_mesh_defined() or (
-    isinstance(node, Partitioned) and node.mesh is not None
-  ):
-    spec = get_partition_spec(node).raw_value
-    return with_sharding_constraint(value, spec, mesh=node.mesh)
-  return value
-
-
 def with_partitioning(
   initializer: F,
   sharding: Sharding,
   mesh: tp.Optional[jax.sharding.Mesh] = None,
-  get_value_hooks: tp.Union[
-    variables.GetValueHook[A], tp.Sequence[variables.GetValueHook[A]]
-  ] = (),
-  create_value_hooks: tp.Union[
-    variables.CreateValueHook[A], tp.Sequence[variables.CreateValueHook[A]]
-  ] = (),
   **metadata: tp.Any,
 ) -> F:
-  if callable(get_value_hooks):
-    get_value_hooks = (get_value_hooks, sharding_hook)
-  else:
-    get_value_hooks = (*get_value_hooks, sharding_hook)
-
-  if callable(create_value_hooks):
-    create_value_hooks = (create_value_hooks, sharding_hook)
-  else:
-    create_value_hooks = (*create_value_hooks, sharding_hook)
-
   return variables.with_metadata(
     initializer,
-    get_value_hooks=get_value_hooks,
-    create_value_hooks=create_value_hooks,
     sharding=sharding,
     mesh=mesh,
     **metadata,
diff --git a/flax/nnx/training/optimizer.py b/flax/nnx/training/optimizer.py
index 33922bbc5d..281066ea42 100644
--- a/flax/nnx/training/optimizer.py
+++ b/flax/nnx/training/optimizer.py
@@ -13,23 +13,94 @@
 # limitations under the License.
 
 from __future__ import annotations
 
+import jax
 import jax.numpy as jnp
 import optax
 
 from flax import nnx
-from flax.nnx import filterlib, graph
+from flax.nnx import filterlib
+from flax.nnx import variables
 from flax.nnx.object import Object
-from flax.nnx.variables import Variable
+from flax.nnx.variables import Variable, VariableState
 
 
 # TODO: add tests and docstrings
 class OptState(Variable):
-  """Wrapper class for Optimizer Variables."""
+  """Holds any optimizer state."""
 
   pass
 
 
+class OptArray(OptState):
+  """Holds an array of optimizer state."""
+
+  pass
+
+
+class OptVariable(OptState):
+  """Holds Variable state."""
+
+  source_type: type[Variable]
+  pass
+
+
+def _wrap_optimizer_state(opt_state):
+  def wrap_optimizer_state_fn(x):
+    if isinstance(x, variables.VariableState):
+      new_state = x.copy()
+      new_state.source_type = x.type
+      new_state.type = OptVariable
+      return new_state.to_variable()
+    else:
+      return OptArray(x)
+
+  return jax.tree.map(
+    wrap_optimizer_state_fn,
+    opt_state,
+    is_leaf=lambda x: isinstance(x, variables.VariableState),
+  )
+
+
+def _opt_state_variables_to_state(opt_state):
+  def optimizer_variable_to_state_fn(x):
+    if isinstance(x, OptVariable):
+      state = x.to_state()
+      state.type = x.source_type
+      del state.source_type
+      return state
+    elif isinstance(x, OptArray):
+      return x.value
+    else:
+      raise TypeError(
+        f'Unexpected type when converting optimizer state: {type(x)}'
+      )
+
+  return jax.tree.map(optimizer_variable_to_state_fn, opt_state)
+
+
+def _update_opt_state(opt_state, updates):
+  def optimizer_update_variables(x, update):
+    if isinstance(x, OptVariable):
+      if not isinstance(update, VariableState):
+        raise TypeError(
+          f'Expected update to be VariableState, got {type(update)}'
+        )
+      x.raw_value = update.value
+    elif isinstance(x, OptArray):
+      if isinstance(update, VariableState):
+        raise TypeError(
+          f'Expected update to not be a VariableState, got {update}'
+        )
+      x.raw_value = update
+    else:
+      raise TypeError(
+        f'Unexpected type when updating optimizer state: {type(x)}'
+      )
+
+  return jax.tree.map(optimizer_update_variables, opt_state, updates)
+
+
 class Optimizer(Object):
   """Simple train state for the common case with a single Optax optimizer.
 
@@ -119,12 +190,9 @@ def __init__(
     self.step = OptState(jnp.array(0, dtype=jnp.uint32))
     self.model = model
     self.tx = tx
-    self.opt_state = OptState(tx.init(nnx.state(model, wrt)))
+    self.opt_state = _wrap_optimizer_state(tx.init(nnx.state(model, wrt)))
     self.wrt = wrt
 
-  def split(self, *filters: filterlib.Filter):
-    return graph.split(self, *filters)
-
   def update(self, grads):
     """Updates ``step``, ``params``, ``opt_state`` and ``**kwargs`` in return value.
     The ``grads`` must be derived from ``nnx.grad(..., wrt=self.wrt)``, where the
@@ -182,12 +250,13 @@ def update(self, grads):
     Args:
       grads: the gradients derived from ``nnx.grad``.
""" - state = nnx.state(self.model, self.wrt) + params = nnx.state(self.model, self.wrt) + opt_state = _opt_state_variables_to_state(self.opt_state) - updates, new_opt_state = self.tx.update(grads, self.opt_state.value, state) - new_params = optax.apply_updates(state, updates) + updates, new_opt_state = self.tx.update(grads, opt_state, params) + new_params = optax.apply_updates(params, updates) assert isinstance(new_params, nnx.State) self.step.value += 1 nnx.update(self.model, new_params) - self.opt_state.value = new_opt_state + _update_opt_state(self.opt_state, new_opt_state) diff --git a/flax/nnx/variables.py b/flax/nnx/variables.py index bd15e3d0b7..76805477f5 100644 --- a/flax/nnx/variables.py +++ b/flax/nnx/variables.py @@ -20,6 +20,8 @@ import typing as tp from typing import Any +import jax + from flax import nnx from flax.nnx import reprlib, tracers from flax.typing import Missing @@ -128,6 +130,7 @@ class Variable(tp.Generic[A], reprlib.Representable): def __init__( self, value: tp.Union[A, VariableMetadata[A]], + *, set_value_hooks: tp.Union[ SetValueHook[A], tp.Sequence[SetValueHook[A]] ] = (), @@ -281,7 +284,7 @@ def copy_from(self, other: Variable[A]) -> None: vars_dict.clear() vars_dict.update(other_vars, _trace_state=trace_state) - def copy_from_state(self, variable_state: VariableState[A]): + def update_from_state(self, variable_state: VariableState[A]): trace_state = self._trace_state variable_vars = vars(self) variable_vars.clear() @@ -815,6 +818,12 @@ def __init__( self.value = value vars(self).update(metadata) + if tp.TYPE_CHECKING: + + def __getattr__(self, name: str) -> None: ... + def __setattr__(self, name: str, value: Any) -> None: ... + def __delattr__(self, name: str) -> None: ... + def __nnx_repr__(self): yield reprlib.Object(type=type(self)) yield reprlib.Attr('type', self.type.__name__) @@ -852,6 +861,9 @@ def to_variable(self) -> Variable[A]: ) return variables + def copy(self: VariableState[A]) -> VariableState[A]: + return jax.tree.map(lambda x: x, self) + def get_metadata(self) -> dict[str, tp.Any]: metadata = vars(self).copy() del metadata['type'] diff --git a/tests/nnx/optimizer_test.py b/tests/nnx/optimizer_test.py index 1d11254114..68d43313d3 100644 --- a/tests/nnx/optimizer_test.py +++ b/tests/nnx/optimizer_test.py @@ -41,11 +41,45 @@ def test_split_merge(self, module_cls): x = jax.random.normal(jax.random.key(0), (1, 2)) model = module_cls(2, 4, rngs=nnx.Rngs(0)) tx = optax.adam(1e-3) - state = nnx.Optimizer(model, tx) - out = state.model(x) - graphdef, state = state.split() - state = nnx.merge(graphdef, state) - np.testing.assert_allclose(out, state.model(x)) + optimizer = nnx.Optimizer(model, tx) + out = optimizer.model(x) + graphdef, optimizer = nnx.split(optimizer) + optimizer = nnx.merge(graphdef, optimizer) + np.testing.assert_allclose(out, optimizer.model(x)) + + def test_update(self): + model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) + optimizer = nnx.Optimizer(model, optax.adamw(0.1)) + + def loss_fn(model): + params = nnx.state(model) + loss = sum(jnp.sum(x**2) for x in jax.tree.leaves(params)) + return loss + + grads = nnx.grad(loss_fn)(model) + optimizer.update(grads) + + def test_sharding_propagation(self): + model = nnx.Linear( + 2, + 3, + rngs=nnx.Rngs(0), + kernel_init=nnx.with_partitioning( + nnx.initializers.lecun_normal(), + sharding=('a', 'b'), + ), + use_bias=False, + ) + optimizer = nnx.Optimizer(model, optax.adamw(0.1)) + + state = nnx.state(optimizer) + partition_spec = nnx.get_partition_spec(state) + + 
+    self.assertEqual(state.opt_state[0].mu.kernel.sharding, ('a', 'b'))
+    self.assertEqual(
+      partition_spec.opt_state[0].mu.kernel.value,
+      jax.sharding.PartitionSpec('a', 'b'),
+    )
 
   @parameterized.product(
     module_cls=[nnx.Linear, Model],
@@ -75,7 +109,9 @@ def jax_jit_train_step(graphdef, state, x, y):
       state.update(grads)
       return state.split()
 
-    graphdef, state = jit_decorator(jax_jit_train_step)(*state.split(), x, y)
+    graphdef, state = jit_decorator(jax_jit_train_step)(
+      *nnx.split(state), x, y
+    )
     state = nnx.merge(graphdef, state)
 
     new_loss = loss_fn(*nnx.split(state.model), x, y)
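
For reference, a minimal usage sketch of the reworked ``nnx.Optimizer``, mirroring ``test_update`` and ``test_sharding_propagation`` above. The quadratic loss and the ``adamw(1e-3)`` learning rate are illustrative only; the point is that ``opt_state`` is now a pytree of ``OptVariable``/``OptArray`` leaves, so metadata such as sharding attached to model parameters remains visible through ``nnx.state(optimizer)`` and ``nnx.get_partition_spec``.

import jax
import jax.numpy as jnp
import optax
from flax import nnx

# Model whose kernel carries sharding metadata via nnx.with_partitioning.
model = nnx.Linear(
  2,
  3,
  rngs=nnx.Rngs(0),
  kernel_init=nnx.with_partitioning(
    nnx.initializers.lecun_normal(), sharding=('a', 'b')
  ),
  use_bias=False,
)
optimizer = nnx.Optimizer(model, optax.adamw(1e-3))

def loss_fn(model):
  # Illustrative quadratic loss over all parameters.
  params = nnx.state(model)
  return sum(jnp.sum(x**2) for x in jax.tree.leaves(params))

grads = nnx.grad(loss_fn)(model)
# optax sees plain values/VariableStates; the opt_state Variables are updated in place.
optimizer.update(grads)

# Because opt_state leaves are Variables, partition specs propagate to the
# optimizer state as well.
state = nnx.state(optimizer)
specs = nnx.get_partition_spec(state)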