[nnx] improve Optimizer metadata propagation
cgarciae committed Sep 8, 2024
1 parent 71b964b commit ffe04bd
Showing 6 changed files with 139 additions and 65 deletions.
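In short: instead of storing the raw optax state inside a single OptState Variable, nnx.Optimizer now wraps every leaf of the optax state individually — leaves that mirror model Variables become OptVariable (remembering the original Variable type), plain arrays become OptArray — so Variable metadata such as sharding annotations is carried into opt_state. The sketch below is illustrative only and mirrors the test_sharding_propagation test added in this commit; the axis names 'a' and 'b' are arbitrary.

import jax
import optax
from flax import nnx

model = nnx.Linear(
    2, 3,
    use_bias=False,
    kernel_init=nnx.with_partitioning(
        nnx.initializers.lecun_normal(), sharding=('a', 'b')
    ),
    rngs=nnx.Rngs(0),
)
optimizer = nnx.Optimizer(model, optax.adamw(0.1))

# The kernel's sharding metadata now appears on the matching optimizer-state
# leaves, so a PartitionSpec can be derived for them as well.
state = nnx.state(optimizer)
assert state.opt_state[0].mu.kernel.sharding == ('a', 'b')
spec = nnx.get_partition_spec(state)
assert spec.opt_state[0].mu.kernel.value == jax.sharding.PartitionSpec('a', 'b')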
4 changes: 2 additions & 2 deletions flax/nnx/graph.py
@@ -594,7 +594,7 @@ def _get_children():
raise ValueError(
f'Expected a Variable type for {key!r}, but got {type(variable)}.'
)
variable.copy_from_state(value)
variable.update_from_state(value)
else: # if it doesn't, create a new variable
assert isinstance(value, VariableState)
variable = value.to_variable()
@@ -729,7 +729,7 @@ def _graph_update_dynamic(node: tp.Any, state: tp.Mapping[Key, tp.Any]):
f'Trying to update a non-Variable attribute {key!r} with a Variable: '
f'{value!r}'
)
current_value.copy_from_state(value)
current_value.update_from_state(value)
elif is_state_leaf(value):
# case 4: state field is being updated
if isinstance(node_impl, PytreeNodeImpl):
2 changes: 1 addition & 1 deletion flax/nnx/object.py
@@ -204,7 +204,7 @@ def _graph_node_set_key(self, key: Key, value: tp.Any):
and isinstance(variable := getattr(self, key), Variable)
and isinstance(value, VariableState)
):
variable.copy_from_state(value)
variable.update_from_state(value)
else:
setattr(self, key, value)

45 changes: 1 addition & 44 deletions flax/nnx/spmd.py
@@ -17,7 +17,7 @@

import jax
from jax.interpreters import pxla
from jax.sharding import Mesh, PartitionSpec
from jax.sharding import PartitionSpec

from flax.nnx import variables
from flax.typing import (
@@ -159,57 +159,14 @@ def with_sharding_constraint(
)


# -------------------------------------
# Partitioning Axis Metadata
# -------------------------------------


@tp.runtime_checkable
class Partitioned(tp.Protocol):
get_value_hooks: tp.Callable[[variables.Variable[tp.Any]], tp.Any]
sharding: Sharding
mesh: tp.Optional[Mesh]


def sharding_hook(
node: variables.Variable[tp.Any],
value: tp.Any,
/,
) -> tp.Any:
if _global_mesh_defined() or (
isinstance(node, Partitioned) and node.mesh is not None
):
spec = get_partition_spec(node).raw_value
return with_sharding_constraint(value, spec, mesh=node.mesh)
return value


def with_partitioning(
initializer: F,
sharding: Sharding,
mesh: tp.Optional[jax.sharding.Mesh] = None,
get_value_hooks: tp.Union[
variables.GetValueHook[A], tp.Sequence[variables.GetValueHook[A]]
] = (),
create_value_hooks: tp.Union[
variables.CreateValueHook[A], tp.Sequence[variables.CreateValueHook[A]]
] = (),
**metadata: tp.Any,
) -> F:
if callable(get_value_hooks):
get_value_hooks = (get_value_hooks, sharding_hook)
else:
get_value_hooks = (*get_value_hooks, sharding_hook)

if callable(create_value_hooks):
create_value_hooks = (create_value_hooks, sharding_hook)
else:
create_value_hooks = (*create_value_hooks, sharding_hook)

return variables.with_metadata(
initializer,
get_value_hooks=get_value_hooks,
create_value_hooks=create_value_hooks,
sharding=sharding,
mesh=mesh,
**metadata,
91 changes: 80 additions & 11 deletions flax/nnx/training/optimizer.py
@@ -13,23 +13,94 @@
# limitations under the License.
from __future__ import annotations

import jax
import jax.numpy as jnp
import optax

from flax import nnx
from flax.nnx import filterlib, graph
from flax.nnx import filterlib
from flax.nnx import variables
from flax.nnx.object import Object
from flax.nnx.variables import Variable
from flax.nnx.variables import Variable, VariableState

# TODO: add tests and docstrings


class OptState(Variable):
"""Wrapper class for Optimizer Variables."""
"""Holds any optimizer state."""

pass


class OptArray(OptState):
"""Holds an array of optimizer state."""

pass


class OptVariable(OptState):
"""Holds Variable state."""

source_type: type[Variable]
pass


def _wrap_optimizer_state(opt_state):
def wrap_optimizer_state_fn(x):
if isinstance(x, variables.VariableState):
new_state = x.copy()
new_state.source_type = x.type
new_state.type = OptVariable
return new_state.to_variable()
else:
return OptArray(x)

return jax.tree.map(
wrap_optimizer_state_fn,
opt_state,
is_leaf=lambda x: isinstance(x, variables.VariableState),
)


def _opt_state_variables_to_state(opt_state):
def optimizer_variable_to_state_fn(x):
if isinstance(x, OptVariable):
state = x.to_state()
state.type = x.source_type
del state.source_type
return state
elif isinstance(x, OptArray):
return x.value
else:
raise TypeError(
f'Unexpected type when converting optimizer state: {type(x)}'
)

return jax.tree.map(optimizer_variable_to_state_fn, opt_state)


def _update_opt_state(opt_state, updates):
def optimizer_update_variables(x, update):
if isinstance(x, OptVariable):
if not isinstance(update, VariableState):
raise TypeError(
f'Expected update to be VariableState, got {type(update)}'
)
x.raw_value = update.value
elif isinstance(x, OptArray):
if isinstance(update, VariableState):
raise TypeError(
f'Expected update not to be a VariableState, got {update}'
)
x.raw_value = update
else:
raise TypeError(
f'Unexpected type when updating optimizer state: {type(x)}'
)

return jax.tree.map(optimizer_update_variables, opt_state, updates)


class Optimizer(Object):
"""Simple train state for the common case with a single Optax optimizer.
@@ -119,12 +190,9 @@ def __init__(
self.step = OptState(jnp.array(0, dtype=jnp.uint32))
self.model = model
self.tx = tx
self.opt_state = OptState(tx.init(nnx.state(model, wrt)))
self.opt_state = _wrap_optimizer_state(tx.init(nnx.state(model, wrt)))
self.wrt = wrt

def split(self, *filters: filterlib.Filter):
return graph.split(self, *filters)

def update(self, grads):
"""Updates ``step``, ``params``, ``opt_state`` and ``**kwargs`` in return value.
The ``grads`` must be derived from ``nnx.grad(..., wrt=self.wrt)``, where the
@@ -182,12 +250,13 @@ def update(self, grads):
Args:
grads: the gradients derived from ``nnx.grad``.
"""
state = nnx.state(self.model, self.wrt)
params = nnx.state(self.model, self.wrt)
opt_state = _opt_state_variables_to_state(self.opt_state)

updates, new_opt_state = self.tx.update(grads, self.opt_state.value, state)
new_params = optax.apply_updates(state, updates)
updates, new_opt_state = self.tx.update(grads, opt_state, params)
new_params = optax.apply_updates(params, updates)
assert isinstance(new_params, nnx.State)

self.step.value += 1
nnx.update(self.model, new_params)
self.opt_state.value = new_opt_state
_update_opt_state(self.opt_state, new_opt_state)
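For orientation, a rough sketch of the round trip these helpers implement (a sketch only — _wrap_optimizer_state and friends are private to flax/nnx/training/optimizer.py, and the exact optax state structure depends on the transformation): tx.init runs on the model's nnx.State, so its output mixes VariableState leaves (mirroring the params) with plain arrays such as Adam's count; _wrap_optimizer_state turns the former into OptVariable (stashing the original Variable type in source_type) and the latter into OptArray. On each update, _opt_state_variables_to_state converts back to a plain optax-compatible pytree, and _update_opt_state writes the results back into the stored Variables in place.

# Sketch, assuming we are inside flax/nnx/training/optimizer.py:
import optax
from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
tx = optax.adam(1e-3)

raw = tx.init(nnx.state(model, nnx.Param))        # VariableState + plain array leaves
wrapped = _wrap_optimizer_state(raw)              # OptVariable / OptArray leaves
plain = _opt_state_variables_to_state(wrapped)    # optax-compatible again
# ...after tx.update(...) produces `new_plain`, the stored state is refreshed with:
# _update_opt_state(wrapped, new_plain)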
14 changes: 13 additions & 1 deletion flax/nnx/variables.py
@@ -20,6 +20,8 @@
import typing as tp
from typing import Any

import jax

from flax import nnx
from flax.nnx import reprlib, tracers
from flax.typing import Missing
@@ -128,6 +130,7 @@ class Variable(tp.Generic[A], reprlib.Representable):
def __init__(
self,
value: tp.Union[A, VariableMetadata[A]],
*,
set_value_hooks: tp.Union[
SetValueHook[A], tp.Sequence[SetValueHook[A]]
] = (),
@@ -281,7 +284,7 @@ def copy_from(self, other: Variable[A]) -> None:
vars_dict.clear()
vars_dict.update(other_vars, _trace_state=trace_state)

def copy_from_state(self, variable_state: VariableState[A]):
def update_from_state(self, variable_state: VariableState[A]):
trace_state = self._trace_state
variable_vars = vars(self)
variable_vars.clear()
@@ -815,6 +818,12 @@ def __init__(
self.value = value
vars(self).update(metadata)

if tp.TYPE_CHECKING:

def __getattr__(self, name: str) -> None: ...
def __setattr__(self, name: str, value: Any) -> None: ...
def __delattr__(self, name: str) -> None: ...

def __nnx_repr__(self):
yield reprlib.Object(type=type(self))
yield reprlib.Attr('type', self.type.__name__)
@@ -852,6 +861,9 @@ def to_variable(self) -> Variable[A]:
)
return variables

def copy(self: VariableState[A]) -> VariableState[A]:
return jax.tree.map(lambda x: x, self)
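The identity jax.tree.map above gives a fresh VariableState container (metadata included) that still shares the same leaf arrays, which is what lets _wrap_optimizer_state retag a state's type without touching the original. A small illustration (hedged — nnx.BatchStat is used here purely as an arbitrary other Variable type):

from flax import nnx

vs = nnx.state(nnx.Linear(2, 3, rngs=nnx.Rngs(0)))['kernel']
vs2 = vs.copy()
vs2.type = nnx.BatchStat        # retag the copy, as _wrap_optimizer_state does
assert vs.type is nnx.Param     # the original VariableState is unchanged
assert vs2.value is vs.value    # the underlying array is shared, not copied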

def get_metadata(self) -> dict[str, tp.Any]:
metadata = vars(self).copy()
del metadata['type']
48 changes: 42 additions & 6 deletions tests/nnx/optimizer_test.py
@@ -41,11 +41,45 @@ def test_split_merge(self, module_cls):
x = jax.random.normal(jax.random.key(0), (1, 2))
model = module_cls(2, 4, rngs=nnx.Rngs(0))
tx = optax.adam(1e-3)
state = nnx.Optimizer(model, tx)
out = state.model(x)
graphdef, state = state.split()
state = nnx.merge(graphdef, state)
np.testing.assert_allclose(out, state.model(x))
optimizer = nnx.Optimizer(model, tx)
out = optimizer.model(x)
graphdef, optimizer = nnx.split(optimizer)
optimizer = nnx.merge(graphdef, optimizer)
np.testing.assert_allclose(out, optimizer.model(x))

def test_update(self):
model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adamw(0.1))

def loss_fn(model):
params = nnx.state(model)
loss = sum(jnp.sum(x**2) for x in jax.tree.leaves(params))
return loss

grads = nnx.grad(loss_fn)(model)
optimizer.update(grads)

def test_sharding_propagation(self):
model = nnx.Linear(
2,
3,
rngs=nnx.Rngs(0),
kernel_init=nnx.with_partitioning(
nnx.initializers.lecun_normal(),
sharding=('a', 'b'),
),
use_bias=False,
)
optimizer = nnx.Optimizer(model, optax.adamw(0.1))

state = nnx.state(optimizer)
partition_spec = nnx.get_partition_spec(state)

self.assertEqual(state.opt_state[0].mu.kernel.sharding, ('a', 'b'))
self.assertEqual(
partition_spec.opt_state[0].mu.kernel.value,
jax.sharding.PartitionSpec('a', 'b'),
)

@parameterized.product(
module_cls=[nnx.Linear, Model],
@@ -75,7 +109,9 @@ def jax_jit_train_step(graphdef, state, x, y):
state.update(grads)
return state.split()

graphdef, state = jit_decorator(jax_jit_train_step)(*state.split(), x, y)
graphdef, state = jit_decorator(jax_jit_train_step)(
*nnx.split(state), x, y
)
state = nnx.merge(graphdef, state)
new_loss = loss_fn(*nnx.split(state.model), x, y)

