Skip to content

Commit

Permalink
Merge pull request #4170 from google:nnx-variable-proxy
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 671172549
  • Loading branch information
Flax Authors committed Sep 5, 2024
2 parents bc594fb + 6d7ad15 commit aded9ac
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 110 deletions.
2 changes: 0 additions & 2 deletions flax/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,9 @@
from .nnx.transforms.transforms import eval_shape as eval_shape
from .nnx.transforms.transforms import cond as cond
from .nnx.transforms.iteration import StateAxes as StateAxes
from .nnx.variables import EMPTY as EMPTY
from .nnx.variables import A as A
from .nnx.variables import BatchStat as BatchStat
from .nnx.variables import Cache as Cache
from .nnx.variables import Empty as Empty
from .nnx.variables import Intermediate as Intermediate
from .nnx.variables import Variable as Variable
from .nnx.variables import VariableState as VariableState
Expand Down
187 changes: 86 additions & 101 deletions flax/nnx/nnx/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

from flax import nnx
from flax.nnx.nnx import reprlib, tracers
from flax.typing import Missing
import jax.tree_util as jtu

A = tp.TypeVar('A')
Expand All @@ -53,33 +54,6 @@
VariableTypeCache: dict[str, tp.Type[Variable[tp.Any]]] = {}


class Empty:
def __repr__(self):
return 'Empty'

def __eq__(self, other):
return isinstance(other, Empty)

def __hash__(self):
return hash(Empty)


jtu.register_pytree_node(
Empty,
lambda empty: ((), None),
lambda _0, _1: EMPTY,
)

EMPTY: Empty = Empty()


class _Missing:
pass


MISSING = _Missing()


@dataclasses.dataclass
class VariableMetadata(tp.Generic[A]):
raw_value: A
Expand Down Expand Up @@ -156,6 +130,7 @@ class Variable(tp.Generic[A], reprlib.Representable):
}
})
"""

raw_value: A
set_value_hooks: tuple[SetValueHook[A], ...]
get_value_hooks: tuple[GetValueHook[A], ...]
Expand Down Expand Up @@ -251,32 +226,33 @@ def __init__(
metadata.update(value_metadata)
value = tp.cast(A, value.raw_value)

if hasattr(self, 'on_get_value'):
self.raw_value = value

if 'on_get_value' in vars(type(self)):
on_get_value = getattr(type(self), 'on_get_value')
if on_get_value not in get_value_hooks:
get_value_hooks = (on_get_value, *get_value_hooks)

if hasattr(self, 'on_set_value'):
if 'on_set_value' in vars(type(self)):
on_set_value = getattr(type(self), 'on_set_value')
if on_set_value not in set_value_hooks:
set_value_hooks = (on_set_value, *set_value_hooks)

if hasattr(self, 'on_create_value'):
if 'on_create_value' in vars(type(self)):
on_create_value = getattr(type(self), 'on_create_value')
if on_create_value not in create_value_hooks:
create_value_hooks = (on_create_value, *create_value_hooks)

if hasattr(self, 'on_add_axis'):
if 'on_add_axis' in vars(type(self)):
on_add_axis = getattr(type(self), 'on_add_axis')
if on_add_axis not in add_axis_hooks:
add_axis_hooks = (on_add_axis, *add_axis_hooks)

if hasattr(self, 'on_remove_axis'):
if 'on_remove_axis' in vars(type(self)):
on_remove_axis = getattr(type(self), 'on_remove_axis')
if on_remove_axis not in remove_axis_hooks:
remove_axis_hooks = (on_remove_axis, *remove_axis_hooks)

self.raw_value = value
self.get_value_hooks = get_value_hooks
self.set_value_hooks = set_value_hooks
self.create_value_hooks = create_value_hooks
Expand All @@ -287,10 +263,7 @@ def __init__(
# run create_value hooks
self.raw_value = self.create_value(self.raw_value)

if tp.TYPE_CHECKING:

def __getattr__(self, name: str) -> tp.Any: ...
else:
if not tp.TYPE_CHECKING:

def __setattr__(self, name: str, value: Any) -> None:
return self._setattr(name, value)
Expand Down Expand Up @@ -373,8 +346,8 @@ def replace(self, value: B, **kwargs) -> Variable[B]: ...
@tp.overload
def replace(self, **kwargs) -> Variable[A]: ...

def replace(self, value: tp.Any = MISSING, **kwargs) -> Variable[tp.Any]:
if value is not MISSING:
def replace(self, value: tp.Any = Missing, **kwargs) -> Variable[tp.Any]:
if value is not Missing:
kwargs['raw_value'] = value

# rename `value` to `raw_value`
Expand Down Expand Up @@ -431,6 +404,7 @@ def __nnx_repr__(self):

def __treescope_repr__(self, path, subtree_renderer):
import treescope # type: ignore[import-not-found,import-untyped]

children = {}
for name, value in vars(self).items():
if name == 'raw_value':
Expand All @@ -439,46 +413,57 @@ def __treescope_repr__(self, path, subtree_renderer):
continue
children[name] = value
return treescope.repr_lib.render_object_constructor(
object_type=type(self),
attributes=children,
path=path,
subtree_renderer=subtree_renderer,
object_type=type(self),
attributes=children,
path=path,
subtree_renderer=subtree_renderer,
)

# hooks API
if tp.TYPE_CHECKING:

def on_get_value(self, value: A) -> A:
raise NotImplementedError
def on_get_value(self, value: A) -> A: ...

def on_set_value(self, value: A) -> A:
raise NotImplementedError
def on_set_value(self, value: A) -> A: ...

def on_create_value(self, value: A) -> A:
raise NotImplementedError
def on_create_value(self, value: A) -> A: ...

def on_add_axis(self: V, axis_name: AxisName, axis_index: AxisIndex) -> V:
raise NotImplementedError
def on_add_axis(
self: V, axis_name: AxisName, axis_index: AxisIndex
) -> V: ...

def on_remove_axis(
self: V, axis_name: AxisName, axis_index: AxisIndex
) -> V:
raise NotImplementedError

# overloads
@property
def shape(self) -> tuple[int, ...]:
return self.value.shape # type: ignore

@property
def dtype(self) -> Any:
return self.value.dtype # type: ignore
) -> V: ...

def __jax_array__(self):
return self.value

# --------------------------------------------
# proxy methods
# --------------------------------------------
# NOTE: we dont override __setattr__ to avoid cases where
# you need to set an attribute on the variable instance
def __getattr__(self, name: str) -> tp.Any:
return getattr(self.value, name)

def __getitem__(self, key) -> tp.Any:
return self.value.__getitem__(key) # type: ignore
return self.value[key] # type: ignore

def __setitem__(self, key, value) -> None:
self.value[key] = value # type: ignore

def __call__(self, *args, **kwargs) -> tp.Any:
return self.value(*args, **kwargs) # type: ignore

def __len__(self) -> int:
return len(self.value) # type: ignore

def __iter__(self) -> tp.Iterator:
return iter(self.value) # type: ignore

def __contains__(self, item) -> bool:
return item in self.value # type: ignore

def __add__(self, other) -> A:
return self.value.__add__(other) # type: ignore
Expand Down Expand Up @@ -725,7 +710,8 @@ class Param(Variable[A]):
value=(2, 3)
)
})
"""
"""

pass


Expand Down Expand Up @@ -760,41 +746,43 @@ class BatchStat(Variable[A]):
)
})
"""

pass


class Cache(Variable[A]):
"""Autoregressive cache in :class:`MultiHeadAttention`::
>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> layer = nnx.MultiHeadAttention(
... num_heads=2,
... in_features=3,
... qkv_features=6,
... out_features=6,
... decode=True,
... rngs=nnx.Rngs(0),
... )
>>> layer.init_cache((1, 3))
>>> jax.tree.map(jnp.shape, nnx.state(layer, nnx.Cache))
State({
'cache_index': VariableState(
type=Cache,
value=()
),
'cached_key': VariableState(
type=Cache,
value=(1, 2, 3)
),
'cached_value': VariableState(
type=Cache,
value=(1, 2, 3)
)
})
>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> layer = nnx.MultiHeadAttention(
... num_heads=2,
... in_features=3,
... qkv_features=6,
... out_features=6,
... decode=True,
... rngs=nnx.Rngs(0),
... )
>>> layer.init_cache((1, 3))
>>> jax.tree.map(jnp.shape, nnx.state(layer, nnx.Cache))
State({
'cache_index': VariableState(
type=Cache,
value=()
),
'cached_key': VariableState(
type=Cache,
value=(1, 2, 3)
),
'cached_value': VariableState(
type=Cache,
value=(1, 2, 3)
)
})
"""

pass


Expand Down Expand Up @@ -826,6 +814,7 @@ class Intermediate(Variable[A]):
)
})
"""

pass


Expand All @@ -840,10 +829,6 @@ def __init__(
self.value = value
vars(self).update(metadata)

if tp.TYPE_CHECKING:

def __getattr__(self, name: str) -> tp.Any: ...

def __nnx_repr__(self):
yield reprlib.Object(type=type(self))
yield reprlib.Attr('type', self.type.__name__)
Expand All @@ -855,16 +840,17 @@ def __nnx_repr__(self):

def __treescope_repr__(self, path, subtree_renderer):
import treescope # type: ignore[import-not-found,import-untyped]

children = {'type': self.type}
for name, value in vars(self).items():
if name == 'type' or name.endswith('_hooks'):
continue
children[name] = value
return treescope.repr_lib.render_object_constructor(
object_type=type(self),
attributes=children,
path=path,
subtree_renderer=subtree_renderer,
object_type=type(self),
attributes=children,
path=path,
subtree_renderer=subtree_renderer,
)

def replace(self, value: B) -> VariableState[B]:
Expand Down Expand Up @@ -997,4 +983,3 @@ def wrapper(*args):
)

return wrapper # type: ignore

Loading

0 comments on commit aded9ac

Please sign in to comment.