diff --git a/flax/nnx/reprlib.py b/flax/nnx/reprlib.py
index eaa58a051c..6ed7660cdf 100644
--- a/flax/nnx/reprlib.py
+++ b/flax/nnx/reprlib.py
@@ -16,7 +16,6 @@
 import dataclasses
 import threading
 import typing as tp
-from abc import ABC, abstractmethod
 
 A = tp.TypeVar('A')
 B = tp.TypeVar('B')
@@ -48,10 +47,9 @@ class Attr:
   end: str = ''
 
 
-class Representable(ABC):
+class Representable:
   __slots__ = ()
 
-  @abstractmethod
   def __nnx_repr__(self) -> tp.Iterator[tp.Union[Object, Attr]]:
     raise NotImplementedError
 
diff --git a/flax/nnx/variablelib.py b/flax/nnx/variablelib.py
index 91d6c861d9..7af20cdb73 100644
--- a/flax/nnx/variablelib.py
+++ b/flax/nnx/variablelib.py
@@ -120,136 +120,86 @@ class Variable(tp.Generic[A], reprlib.Representable):
   """
 
   raw_value: A
-  set_value_hooks: tuple[SetValueHook[A], ...]
-  get_value_hooks: tuple[GetValueHook[A], ...]
-  create_value_hooks: tuple[CreateValueHook[A], ...]
-  add_axis_hooks: tuple[AddAxisHook[Variable[A]], ...]
-  remove_axis_hooks: tuple[RemoveAxisHook[Variable[A]], ...]
   _trace_state: tracers.TraceState
+  _var_metadata: dict[str, tp.Any]
 
   def __init__(
     self,
     value: tp.Union[A, VariableMetadata[A]],
-    *,
-    set_value_hooks: tp.Union[
-      SetValueHook[A], tp.Sequence[SetValueHook[A]]
-    ] = (),
-    get_value_hooks: tp.Union[
-      GetValueHook[A], tp.Sequence[GetValueHook[A]]
-    ] = (),
-    create_value_hooks: tp.Union[
-      CreateValueHook[A], tp.Sequence[CreateValueHook[A]]
-    ] = (),
-    add_axis_hooks: tp.Union[
-      AddAxisHook[Variable[A]], tp.Sequence[AddAxisHook[Variable[A]]]
-    ] = (),
-    remove_axis_hooks: tp.Union[
-      RemoveAxisHook[Variable[A]],
-      tp.Sequence[RemoveAxisHook[Variable[A]]],
-    ] = (),
     **metadata: tp.Any,
   ):
-    vars(self)['_trace_state'] = tracers.TraceState()
-    if callable(set_value_hooks):
-      set_value_hooks = (set_value_hooks,)
-    else:
-      set_value_hooks = tuple(set_value_hooks)
-
-    if callable(get_value_hooks):
-      get_value_hooks = (get_value_hooks,)
-    else:
-      get_value_hooks = tuple(get_value_hooks)
-
-    if callable(create_value_hooks):
-      create_value_hooks = (create_value_hooks,)
-    else:
-      create_value_hooks = tuple(create_value_hooks)
-
-    if callable(add_axis_hooks):
-      add_axis_hooks = (add_axis_hooks,)
-    else:
-      add_axis_hooks = tuple(add_axis_hooks)
-
-    if callable(remove_axis_hooks):
-      remove_axis_hooks = (remove_axis_hooks,)
-    else:
-      remove_axis_hooks = tuple(remove_axis_hooks)
+    type_vars = vars(type(self))
+    vars_self = vars(self)
+    vars_self['_trace_state'] = tracers.TraceState()
 
     if isinstance(value, VariableMetadata):
-      value_metadata = dict(value.metadata)
-      if value.set_value_hooks:
-        set_value_hooks = set_value_hooks + value.set_value_hooks
-      if value.get_value_hooks:
-        get_value_hooks = get_value_hooks + value.get_value_hooks
-      if value.create_value_hooks:
-        create_value_hooks = create_value_hooks + value.create_value_hooks
-      if value.add_axis_hooks:
-        add_axis_hooks = add_axis_hooks + value.add_axis_hooks
-      if value.remove_axis_hooks:
-        remove_axis_hooks = remove_axis_hooks + value.remove_axis_hooks
-
-      metadata.update(value_metadata)
+      metadata.update(value.metadata)
       value = tp.cast(A, value.raw_value)
 
-    self.raw_value = value
+    object.__setattr__(self, 'raw_value', value)
 
-    if 'on_get_value' in vars(type(self)):
-      on_get_value = getattr(type(self), 'on_get_value')
-      if on_get_value not in get_value_hooks:
-        get_value_hooks = (on_get_value, *get_value_hooks)
+    if 'on_get_value' in type_vars and 'on_get_value' not in metadata:
+      metadata['on_get_value'] = getattr(type(self), 'on_get_value')
 
-    if 'on_set_value' in vars(type(self)):
-      on_set_value = getattr(type(self), 'on_set_value')
-      if on_set_value not in set_value_hooks:
-        set_value_hooks = (on_set_value, *set_value_hooks)
+    if 'on_set_value' in type_vars and 'on_set_value' not in metadata:
+      metadata['on_set_value'] = getattr(type(self), 'on_set_value')
 
-    if 'on_create_value' in vars(type(self)):
-      on_create_value = getattr(type(self), 'on_create_value')
-      if on_create_value not in create_value_hooks:
-        create_value_hooks = (on_create_value, *create_value_hooks)
+    if 'on_create_value' in type_vars and 'on_create_value' not in metadata:
+      metadata['on_create_value'] = getattr(type(self), 'on_create_value')
 
-    if 'on_add_axis' in vars(type(self)):
-      on_add_axis = getattr(type(self), 'on_add_axis')
-      if on_add_axis not in add_axis_hooks:
-        add_axis_hooks = (on_add_axis, *add_axis_hooks)
+    if 'on_add_axis' in type_vars and 'on_add_axis' not in metadata:
+      metadata['on_add_axis'] = getattr(type(self), 'on_add_axis')
 
-    if 'on_remove_axis' in vars(type(self)):
-      on_remove_axis = getattr(type(self), 'on_remove_axis')
-      if on_remove_axis not in remove_axis_hooks:
-        remove_axis_hooks = (on_remove_axis, *remove_axis_hooks)
-
-    self.get_value_hooks = get_value_hooks
-    self.set_value_hooks = set_value_hooks
-    self.create_value_hooks = create_value_hooks
-    self.add_axis_hooks = add_axis_hooks
-    self.remove_axis_hooks = remove_axis_hooks
-    vars(self).update(metadata)
+    if 'on_remove_axis' in type_vars and 'on_remove_axis' not in metadata:
+      metadata['on_remove_axis'] = getattr(type(self), 'on_remove_axis')
+    vars_self['_var_metadata'] = metadata
 
     # run create_value hooks
-    self.raw_value = self.create_value(self.raw_value)
+    vars_self['raw_value'] = self.create_value(self.raw_value)
+
+  def __getattr__(self, name: str) -> tp.Any:
+    if name in vars(self)['_var_metadata']:
+      return self._var_metadata[name]
+    return getattr(self.value, name)
 
-  if not tp.TYPE_CHECKING:
+  def __setattr__(self, name: str, value: tp.Any):
+    if not self._trace_state.is_valid():
+      raise errors.TraceContextError(
+        f'Cannot mutate {type(self).__name__} from a different trace level'
+      )
 
-    def __setattr__(self, name: str, value: Any) -> None:
-      return self._setattr(name, value)
+    if (
+      name == 'value'
+      or name == 'raw_value'
+      or name == '_var_metadata'
+      or name == '_trace_state'
+    ):
+      object.__setattr__(self, name, value)
+    else:
+      self._var_metadata[name] = value
 
-  def _setattr(self, name: str, value: tp.Any):
+  def __delattr__(self, name: str):
     if not self._trace_state.is_valid():
       raise errors.TraceContextError(
         f'Cannot mutate {type(self).__name__} from a different trace level'
       )
 
-    object.__setattr__(self, name, value)
+    if (
+      name == 'value'
+      or name == 'raw_value'
+      or name == '_var_metadata'
+      or name == '_trace_state'
+    ):
+      object.__delattr__(self, name)
+    else:
+      del self._var_metadata[name]
 
   @classmethod
   def state(cls, value: A, **metadata) -> VariableState[A]:
     return cls(value, **metadata).to_state()
 
   def get_metadata(self):
-    metadata = vars(self).copy()
-    del metadata['raw_value']
-    del metadata['_trace_state']
-    return metadata
+    return self._var_metadata
 
   def copy_from(self, other: Variable[A]) -> None:
     if type(self) is not type(other):
@@ -259,29 +209,20 @@ def copy_from(self, other: Variable[A]) -> None:
       )
     if self is other:
       return
-    trace_state = self._trace_state
-    vars_dict = vars(self)
-    other_vars = vars(other).copy()
-    del other_vars['_trace_state']
-    vars_dict.clear()
-    vars_dict.update(other_vars, _trace_state=trace_state)
+    self.raw_value = other.raw_value
+    self._var_metadata.clear()
+    self._var_metadata.update(other.get_metadata())
 
   def update_from_state(self, variable_state: VariableState[A]):
-    trace_state = self._trace_state
-    variable_vars = vars(self)
-    variable_vars.clear()
-    variable_vars.update(
-      variable_state.get_metadata(),
-      raw_value=variable_state.value,
-      _trace_state=trace_state,
-    )
+    vars_self = vars(self)
+    vars_self['raw_value'] = variable_state.value
+    vars_self['_var_metadata'] = variable_state.get_metadata().copy()
 
   @property
   def value(self) -> A:
     value = self.raw_value
-    if self.get_value_hooks:
-      for hook in self.get_value_hooks:
-        value = hook(self, value)
+    if 'on_get_value' in self._var_metadata:
+      value = self._var_metadata['on_get_value'](self, value)
     return value
 
   @value.setter
@@ -290,23 +231,22 @@ def value(self, value: A):
       raise ValueError(
         'Cannot set value to a Variable, ' 'use `copy_from` method instead'
       )
-    if self.set_value_hooks:
-      for hook in self.set_value_hooks:
-        value = hook(self, value)
-    self.raw_value = value
+    if 'on_set_value' in self._var_metadata:
+      value = self._var_metadata['on_set_value'](self, value)
+    vars(self)['raw_value'] = value
 
   def create_value(self, value: A):
-    for hook in self.create_value_hooks:
-      value = hook(self, value)
+    if 'on_create_value' in self._var_metadata:
+      value = self._var_metadata['on_create_value'](self, value)
     return value
 
   def add_axis(self, axis_index: AxisIndex, axis_name: AxisName | None):
-    for hook in self.add_axis_hooks:
-      hook(self, axis_index, axis_name)
+    if 'on_add_axis' in self._var_metadata:
+      self._var_metadata['on_add_axis'](self, axis_index, axis_name)
 
   def remove_axis(self, axis_index: AxisIndex, axis_name: AxisName | None):
-    for hook in self.remove_axis_hooks:
-      hook(self, axis_index, axis_name)
+    if 'on_remove_axis' in self._var_metadata:
+      self._var_metadata['on_remove_axis'](self, axis_index, axis_name)
 
   def __eq__(self, other: object) -> bool:
     return type(self) is type(other) and vars(other) == vars(self)
@@ -344,26 +284,27 @@ def replace(self, value: tp.Any = Missing, **kwargs) -> Variable[tp.Any]:
         return value
 
     # get and update attributes
-    attributes = vars(self).copy()
-    attributes.update(**kwargs)
     # return new instance with updated attributes
     obj = object.__new__(type(self))
-    vars(obj).update(attributes)
+    object.__setattr__(obj, '_trace_state', self._trace_state)
+    object.__setattr__(obj, 'raw_value', kwargs.pop('raw_value', self.raw_value))
+    object.__setattr__(obj, '_var_metadata', self.get_metadata().copy())
+    obj._var_metadata.update(kwargs)
     return obj
 
   @classmethod
   def from_metadata(cls, value: A, attributes: tp.Mapping[str, tp.Any]):
     obj = object.__new__(cls)
-    vars(obj).update(
-      attributes, raw_value=value, _trace_state=tracers.TraceState()
-    )
+    object.__setattr__(obj, '_trace_state', tracers.TraceState())
+    object.__setattr__(obj, 'raw_value', value)
+    object.__setattr__(obj, '_var_metadata', dict(attributes))
     return obj
 
   def copy(self: Variable[A]) -> Variable[A]:
     obj = object.__new__(type(self))
-    attributes = vars(self).copy()
-    attributes['_trace_state'] = tracers.TraceState()
-    vars(obj).update(attributes)
+    object.__setattr__(obj, '_trace_state', tracers.TraceState())
+    object.__setattr__(obj, 'raw_value', self.raw_value)
+    object.__setattr__(obj, '_var_metadata', self.get_metadata().copy())
     return obj
 
   def to_state(self: Variable[A]) -> VariableState[A]:
@@ -372,23 +313,14 @@ def to_state(self: Variable[A]) -> VariableState[A]:
 
   def __nnx_repr__(self):
     yield reprlib.Object(type=type(self))
-    for name, value in vars(self).items():
-      if name == 'raw_value':
-        name = 'value'
-      if name.endswith('_hooks') or name == '_trace_state':
-        continue
+    yield reprlib.Attr('value', self.raw_value)
+    for name, value in self._var_metadata.items():
       yield reprlib.Attr(name, repr(value))
 
   def __treescope_repr__(self, path, subtree_renderer):
     import treescope  # type: ignore[import-not-found,import-untyped]
 
-    children = {}
-    for name, value in vars(self).items():
-      if name == 'raw_value':
-        name = 'value'
-      if name.endswith('_hooks') or name == '_trace_state':
-        continue
-      children[name] = value
+    children = {'value': self.raw_value, **self._var_metadata}
     return treescope.repr_lib.render_object_constructor(
       object_type=type(self),
       attributes=children,
@@ -426,10 +358,6 @@ def __setstate__(self, state):
   # --------------------------------------------
   # proxy methods
   # --------------------------------------------
-  # NOTE: we dont override __setattr__ to avoid cases where
-  # you need to set an attribute on the variable instance
-  def __getattr__(self, name: str) -> tp.Any:
-    return getattr(self.value, name)
 
   def __getitem__(self, key) -> tp.Any:
     return self.value[key]  # type: ignore
@@ -803,39 +731,51 @@ class Intermediate(Variable[A]):
 
 
 class VariableState(tp.Generic[A], reprlib.Representable):
+  __slots__ = ('type', 'value', '_var_metadata')
+  type: type[Variable[A]]
+  value: A
+  _var_metadata: dict[str, tp.Any]
+
   def __init__(
     self,
-    type: type[Variable[tp.Any]],
+    type: type[Variable[A]],  # type: ignore [valid-type]
     value: A,
     **metadata,
   ):
-    self.type = type
-    self.value = value
-    vars(self).update(metadata)
-
-  if tp.TYPE_CHECKING:
+    object.__setattr__(self, 'type', type)
+    object.__setattr__(self, 'value', value)
+    object.__setattr__(self, '_var_metadata', metadata)
+
+  def __getattr__(self, name: str) -> tp.Any:
+    var_metadata = object.__getattribute__(self, '_var_metadata')
+    if name not in var_metadata:
+      raise AttributeError(f"'VariableState' object has no attribute '{name}'")
+    return var_metadata[name]
+
+  def __setattr__(self, name: str, value: Any) -> None:
+    if name == 'type' or name == 'value' or name == '_var_metadata':
+      object.__setattr__(self, name, value)
+    else:
+      self._var_metadata[name] = value
 
-    def __getattr__(self, name: str) -> None: ...
-    def __setattr__(self, name: str, value: Any) -> None: ...
-    def __delattr__(self, name: str) -> None: ...
+  def __delattr__(self, name: str) -> None:
+    if name == 'type' or name == 'value' or name == '_var_metadata':
+      object.__delattr__(self, name)
+    else:
+      del self._var_metadata[name]
 
   def __nnx_repr__(self):
     yield reprlib.Object(type=type(self))
     yield reprlib.Attr('type', self.type.__name__)
+    yield reprlib.Attr('value', self.value)
 
-    for name, value in vars(self).items():
-      if name == 'type' or name.endswith('_hooks'):
-        continue
+    for name, value in self._var_metadata.items():
       yield reprlib.Attr(name, repr(value))
 
   def __treescope_repr__(self, path, subtree_renderer):
     import treescope  # type: ignore[import-not-found,import-untyped]
 
-    children = {'type': self.type}
-    for name, value in vars(self).items():
-      if name == 'type' or name.endswith('_hooks'):
-        continue
-      children[name] = value
+    children = {'type': self.type, 'value': self.value, **self._var_metadata}
     return treescope.repr_lib.render_object_constructor(
       object_type=type(self),
       attributes=children,
@@ -849,29 +789,25 @@ def replace(self, value: B) -> VariableState[B]:
   def to_variable(self) -> Variable[A]:
     # we use object.__new__ to avoid calling __init__ and bypass the
     # __init__ logic which should not be called twice
-    metadata = self.get_metadata()
-    variables = object.__new__(self.type)
-    vars(variables).update(
-      metadata, raw_value=self.value, _trace_state=tracers.TraceState()
-    )
-    return variables
+    variable = object.__new__(self.type)
+    object.__setattr__(variable, '_trace_state', tracers.TraceState())
+    object.__setattr__(variable, 'raw_value', self.value)
+    object.__setattr__(variable, '_var_metadata', self.get_metadata().copy())
+    return variable
 
   def copy(self: VariableState[A]) -> VariableState[A]:
     return jax.tree.map(lambda x: x, self)
 
   def get_metadata(self) -> dict[str, tp.Any]:
-    metadata = vars(self).copy()
-    del metadata['type']
-    del metadata['value']
-    return metadata
+    return self._var_metadata
 
   def add_axis(self, axis_index: AxisIndex, axis_name: AxisName | None):
-    for hook in self.add_axis_hooks:
-      hook(self, axis_index, axis_name)
+    if 'on_add_axis' in self._var_metadata:
+      self._var_metadata['on_add_axis'](self, axis_index, axis_name)
 
   def remove_axis(self, axis_index: AxisIndex, axis_name: AxisName | None):
-    for hook in self.remove_axis_hooks:
-      hook(self, axis_index, axis_name)
+    if 'on_remove_axis' in self._var_metadata:
+      self._var_metadata['on_remove_axis'](self, axis_index, axis_name)
 
 
 def _variable_state_flatten(x: VariableState[tp.Any], *, with_keys: bool):
diff --git a/tests/nnx/containers_test.py b/tests/nnx/containers_test.py
index 97785e7658..92345abc66 100644
--- a/tests/nnx/containers_test.py
+++ b/tests/nnx/containers_test.py
@@ -21,15 +21,15 @@ class TestContainers(absltest.TestCase):
   def test_unbox(self):
     x = nnx.Param(
       1,
-      get_value_hooks=[lambda c, x: x + 1, lambda c, x: x * 2],  # type: ignore
+      on_get_value=lambda c, x: x + 3,  # type: ignore
     )
 
     assert x.value == 4
 
-  def test_box(self):
+  def test_on_set_value(self):
     x: nnx.Param[int] = nnx.Param(
       1,  # type: ignore
-      set_value_hooks=[lambda c, x: x + 1, lambda c, x: x * 2],  # type: ignore
+      on_set_value=lambda c, x: x + 7,  # type: ignore
     )
     x.value = 5
 
@@ -38,9 +38,7 @@ def test_box(self):
   def test_module_unbox(self):
     class Foo(nnx.Module):
       def __init__(self) -> None:
-        self.x = nnx.Param(
-          1, get_value_hooks=[lambda c, x: x + 1, lambda c, x: x * 2]
-        )
+        self.x = nnx.Param(1, on_get_value=lambda c, x: x + 3)
 
     module = Foo()
 
@@ -51,7 +49,8 @@ def test_module_box(self):
     class Foo(nnx.Module):
       def __init__(self) -> None:
         self.x = nnx.Param(
-          1, set_value_hooks=[lambda c, x: x + 1, lambda c, x: x * 2]
+          1,
+          on_set_value=lambda c, x: x + 7,  # type: ignore
         )
 
     module = Foo()
diff --git a/tests/nnx/module_test.py b/tests/nnx/module_test.py
index 498ce3defe..ce65186dd2 100644
--- a/tests/nnx/module_test.py
+++ b/tests/nnx/module_test.py
@@ -40,7 +40,7 @@ def __setitem__(self, idx, value):
 
 class Dict(nnx.Module):
   def __init__(self, *args, **kwargs):
-    self.items = dict(*args, **kwargs)
+    vars(self)['items'] = dict(*args, **kwargs)
 
   def __getitem__(self, key):
     return vars(self)['items'][key]
@@ -48,6 +48,12 @@ def __getitem__(self, key):
   def __setitem__(self, key, value):
     vars(self)['items'][key] = value
 
+  def __setattr__(self, key, value):
+    if key == 'items':
+      object.__setattr__(self, key, value)
+    else:
+      vars(self)['items'][key] = value
+
   def __getattr__(self, key):
     attrs = vars(self)
     if 'items' not in attrs:
diff --git a/tests/nnx/spmd_test.py b/tests/nnx/spmd_test.py
index 828ee56816..2372fbad6f 100644
--- a/tests/nnx/spmd_test.py
+++ b/tests/nnx/spmd_test.py
@@ -112,19 +112,20 @@ class MLP(nnx.Module):
       )
       def __init__(self, rngs: nnx.Rngs):
         self.linear = nnx.Linear(
-            3,
-            3,
-            kernel_init=nnx.with_metadata(
-                nnx.initializers.lecun_normal(), sharding=('din', 'dout'),
-                add_axis_hooks=lambda _, idx, name: kadds.append((idx, name)),
-                remove_axis_hooks=lambda _, idx, name: kremoves.append((idx, name)),
-            ),
-            bias_init=nnx.with_metadata(
-                nnx.initializers.zeros_init(),  # no sharding annotation here!
-                add_axis_hooks=lambda _, idx, name: badds.append((idx, name)),
-                remove_axis_hooks=lambda _, idx, name: bremoves.append((idx, name)),
-            ),
-            rngs=rngs,
+          3,
+          3,
+          kernel_init=nnx.with_metadata(
+            nnx.initializers.lecun_normal(),
+            sharding=('din', 'dout'),
+            on_add_axis=lambda _, idx, name: kadds.append((idx, name)),
+            on_remove_axis=lambda _, idx, name: kremoves.append((idx, name)),
+          ),
+          bias_init=nnx.with_metadata(
+            nnx.initializers.zeros_init(),  # no sharding annotation here!
+            on_add_axis=lambda _, idx, name: badds.append((idx, name)),
+            on_remove_axis=lambda _, idx, name: bremoves.append((idx, name)),
+          ),
+          rngs=rngs,
         )
 
       @nnx.scan(
diff --git a/uv.lock b/uv.lock
index 0d68a86b72..a30155113e 100644
--- a/uv.lock
+++ b/uv.lock
@@ -773,7 +773,7 @@ wheels = [
 
 [[package]]
 name = "flax"
-version = "0.10.1"
+version = "0.10.2"
 source = { editable = "." }
 dependencies = [
     { name = "jax" },