diff --git a/flax/nnx/spmd.py b/flax/nnx/spmd.py index a7acbbc418..9b20d32381 100644 --- a/flax/nnx/spmd.py +++ b/flax/nnx/spmd.py @@ -44,7 +44,7 @@ def _add_axis(x: tp.Any): sharding.insert(index, axis_name) x.sharding = tuple(sharding) # type: ignore - x.add_axis(axis_name, index) + x.add_axis(index, axis_name) return x return jax.tree.map( @@ -61,7 +61,7 @@ def _remove_axis(x: tp.Any): sharding = list(x.sharding) assert sharding.pop(index) == axis_name x.sharding = tuple(sharding) - x.remove_axis(axis_name, index) + x.remove_axis(index, axis_name) return x return jax.tree.map( diff --git a/flax/nnx/variables.py b/flax/nnx/variables.py index 76805477f5..9417e45337 100644 --- a/flax/nnx/variables.py +++ b/flax/nnx/variables.py @@ -36,8 +36,8 @@ CreateValueHook = tp.Callable[['Variable[A]', A], A] AxisName = str AxisIndex = int -AddAxisHook = tp.Callable[[V, AxisName, AxisIndex], None] -RemoveAxisHook = tp.Callable[[V, AxisName, AxisIndex], None] +AddAxisHook = tp.Callable[[V, AxisIndex, AxisName | None], None] +RemoveAxisHook = tp.Callable[[V, AxisIndex, AxisName | None], None] VariableTypeCache: dict[str, tp.Type[Variable[tp.Any]]] = {} @@ -150,67 +150,43 @@ def __init__( **metadata: tp.Any, ): vars(self)['_trace_state'] = tracers.TraceState() - if set_value_hooks: - if callable(set_value_hooks): - set_value_hooks = (set_value_hooks,) - else: - set_value_hooks = tuple(set_value_hooks) + if callable(set_value_hooks): + set_value_hooks = (set_value_hooks,) else: - set_value_hooks = () - if get_value_hooks: - if callable(get_value_hooks): - get_value_hooks = (get_value_hooks,) - else: - get_value_hooks = tuple(get_value_hooks) + set_value_hooks = tuple(set_value_hooks) + + if callable(get_value_hooks): + get_value_hooks = (get_value_hooks,) else: - get_value_hooks = () + get_value_hooks = tuple(get_value_hooks) - if create_value_hooks: - if callable(create_value_hooks): - create_value_hooks = (create_value_hooks,) - else: - create_value_hooks = tuple(create_value_hooks) + if callable(create_value_hooks): + create_value_hooks = (create_value_hooks,) else: - create_value_hooks = () + create_value_hooks = tuple(create_value_hooks) - if add_axis_hooks: - if callable(add_axis_hooks): - add_axis_hooks = (add_axis_hooks,) - else: - add_axis_hooks = tuple(add_axis_hooks) + if callable(add_axis_hooks): + add_axis_hooks = (add_axis_hooks,) else: - add_axis_hooks = () + add_axis_hooks = tuple(add_axis_hooks) - if remove_axis_hooks: - if callable(remove_axis_hooks): - remove_axis_hooks = (remove_axis_hooks,) - else: - remove_axis_hooks = tuple(remove_axis_hooks) + if callable(remove_axis_hooks): + remove_axis_hooks = (remove_axis_hooks,) else: - remove_axis_hooks = () + remove_axis_hooks = tuple(remove_axis_hooks) if isinstance(value, VariableMetadata): value_metadata = dict(value.metadata) - if set_value_hooks and value.set_value_hooks: + if value.set_value_hooks: set_value_hooks = set_value_hooks + value.set_value_hooks - elif value.set_value_hooks: - set_value_hooks = value.set_value_hooks - if get_value_hooks and value.get_value_hooks: + if value.get_value_hooks: get_value_hooks = get_value_hooks + value.get_value_hooks - elif value.get_value_hooks: - get_value_hooks = value.get_value_hooks - if create_value_hooks and value.create_value_hooks: + if value.create_value_hooks: create_value_hooks = create_value_hooks + value.create_value_hooks - elif value.create_value_hooks: - create_value_hooks = value.create_value_hooks - if add_axis_hooks and value.add_axis_hooks: + if value.add_axis_hooks: add_axis_hooks = add_axis_hooks + value.add_axis_hooks - elif value.add_axis_hooks: - add_axis_hooks = value.add_axis_hooks - if remove_axis_hooks and value.remove_axis_hooks: + if value.remove_axis_hooks: remove_axis_hooks = remove_axis_hooks + value.remove_axis_hooks - elif value.remove_axis_hooks: - remove_axis_hooks = value.remove_axis_hooks metadata.update(value_metadata) value = tp.cast(A, value.raw_value) @@ -318,13 +294,13 @@ def create_value(self, value: A): value = hook(self, value) return value - def add_axis(self, axis_name: AxisName, axis_index: AxisIndex): + def add_axis(self, axis_index: AxisIndex, axis_name: AxisName | None): for hook in self.add_axis_hooks: - hook(self, axis_name, axis_index) + hook(self, axis_index, axis_name) - def remove_axis(self, axis_name: AxisName, axis_index: AxisIndex): + def remove_axis(self, axis_index: AxisIndex, axis_name: AxisName | None): for hook in self.remove_axis_hooks: - hook(self, axis_name, axis_index) + hook(self, axis_index, axis_name) def __eq__(self, other: object) -> bool: return type(self) is type(other) and vars(other) == vars(self) @@ -418,11 +394,11 @@ def on_set_value(self, value: A) -> A: ... def on_create_value(self, value: A) -> A: ... def on_add_axis( - self: V, axis_name: AxisName, axis_index: AxisIndex + self: V, axis_index: AxisIndex, axis_name: AxisName | None ) -> V: ... def on_remove_axis( - self: V, axis_name: AxisName, axis_index: AxisIndex + self: V, axis_index: AxisIndex, axis_name: AxisName | None ) -> V: ... def __jax_array__(self): @@ -870,17 +846,13 @@ def get_metadata(self) -> dict[str, tp.Any]: del metadata['value'] return metadata - def add_axis(self, axis_name: AxisName, axis_index: AxisIndex): - if not hasattr(self, 'add_axis_hooks'): - raise ValueError(f'No add_axis_hooks found for VariableState: {self}') + def add_axis(self, axis_index: AxisIndex, axis_name: AxisName | None): for hook in self.add_axis_hooks: - hook(self, axis_name, axis_index) + hook(self, axis_index, axis_name) - def remove_axis(self, axis_name: AxisName, axis_index: AxisIndex): - if not hasattr(self, 'remove_axis_hooks'): - raise ValueError(f'No remove_axis_hooks found for VariableState: {self}') + def remove_axis(self, axis_index: AxisIndex, axis_name: AxisName | None): for hook in self.remove_axis_hooks: - hook(self, axis_name, axis_index) + hook(self, axis_index, axis_name) def _variable_state_flatten(x: VariableState[tp.Any], *, with_keys: bool): diff --git a/tests/nnx/spmd_test.py b/tests/nnx/spmd_test.py index 15808e0800..6a202e8135 100644 --- a/tests/nnx/spmd_test.py +++ b/tests/nnx/spmd_test.py @@ -100,6 +100,64 @@ def __call__(self, x): assert state_spec.opt_state[0].mu['w'].value == PartitionSpec('row', 'col') assert state_spec.opt_state[0].nu['w'].value == PartitionSpec('row', 'col') + def test_add_remove_axis_in_transform(self): + test = self + kadds, kremoves, badds, bremoves = [], [], [], [] + class MLP(nnx.Module): + + @nnx.split_rngs(splits=5) + @nnx.vmap( + in_axes=(0, 0), + transform_metadata={nnx.PARTITION_NAME: 'layers'}, + ) + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear( + 3, + 3, + kernel_init=nnx.with_metadata( + nnx.initializers.lecun_normal(), sharding=('din', 'dout'), + add_axis_hooks=lambda _, idx, name: kadds.append((idx, name)), + remove_axis_hooks=lambda _, idx, name: kremoves.append((idx, name)), + ), + bias_init=nnx.with_metadata( + nnx.initializers.zeros_init(), # no sharding annotation here! + add_axis_hooks=lambda _, idx, name: badds.append((idx, name)), + remove_axis_hooks=lambda _, idx, name: bremoves.append((idx, name)), + ), + rngs=rngs, + ) + + @nnx.scan( + in_axes=(0, nnx.Carry), + transform_metadata={nnx.PARTITION_NAME: 'layers'} + ) + def __call__(self, x: jax.Array): + x = self.linear(x) + # test sharding layer axes is not present inside scan + test.assertEqual(self.linear.kernel.shape, (3, 3)) + test.assertEqual(self.linear.kernel.sharding, ('din', 'dout')) + # at least a remove_axis was already called to remove the layer axis + test.assertEqual(kremoves[-1], (0, 'layers')) + test.assertEqual(bremoves[-1], (0, 'layers')) + return x, None + + m = MLP(rngs=nnx.Rngs(0)) + self.assertEqual(m.linear.kernel.shape, (5, 3, 3)) + self.assertEqual(m.linear.kernel.sharding, ('layers', 'din', 'dout')) + self.assertEqual(m.linear.bias.shape, (5, 3)) + # One add_axis called to add the `nnx.vmap` dimension + self.assertEqual(kadds, [(0, 'layers')]) + self.assertEqual(kremoves, []) + self.assertEqual(badds, [(0, 'layers')]) + self.assertEqual(bremoves, []) + + # One remove_axis and one add_axis called when in and out of `nnx.scan` + y = m(jnp.ones((5, 3))) + self.assertEqual(kadds, [(0, 'layers'), (0, 'layers')]) + self.assertEqual(kremoves, [(0, 'layers')]) + self.assertEqual(badds, [(0, 'layers'), (0, 'layers')]) + self.assertEqual(bremoves, [(0, 'layers')]) + if __name__ == '__main__': absltest.main()