Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[nnx] add FlatState #4410

Merged
merged 1 commit into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs_nnx/guides/filters_guide.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"id": "068208fc",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -280,7 +280,7 @@
" predicates = [nnx.filterlib.to_predicate(f) for f in filters]\n",
" flat_states: list[dict[KeyPath, Any]] = [{} for p in predicates]\n",
"\n",
" for path, value in state.flat_state().items():\n",
" for path, value in state.flat_state():\n",
" for i, predicate in enumerate(predicates):\n",
" if predicate(path, value):\n",
" flat_states[i][path] = value\n",
Expand Down
2 changes: 1 addition & 1 deletion docs_nnx/guides/filters_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def split(node, *filters):
predicates = [nnx.filterlib.to_predicate(f) for f in filters]
flat_states: list[dict[KeyPath, Any]] = [{} for p in predicates]
for path, value in state.flat_state().items():
for path, value in state.flat_state():
for i, predicate in enumerate(predicates):
if predicate(path, value):
flat_states[i][path] = value
Expand Down
2 changes: 1 addition & 1 deletion examples/gemma/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def map_key_fn(path: tuple[str, ...]) -> tuple[str | int, ...]:

mdl: M = nnx.eval_shape(module_factory)
graph_def, state = nnx.split(mdl)
state = state.flat_state()
state = dict(state.flat_state())
for path, val in flax.traverse_util.flatten_dict(variables).items():
mapped_path = map_key_fn(path)
if mapped_path not in state:
Expand Down
4 changes: 2 additions & 2 deletions examples/lm1b_nnx/models_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def transfer_params(
params_linen: dict[str, Any],
):
rules = dataclasses.asdict(config.axis_rules)
flat_params_nnx = params_nnx.flat_state()
flat_params_nnx = dict(params_nnx.flat_state())
flat_params_linen = nnx.traversals.flatten_mapping(params_linen, sep='/')

def apply_rules(names: tuple[str, ...]):
Expand Down Expand Up @@ -163,7 +163,7 @@ def transfer_cache(
cache_nnx: nnx.State,
cache_linen: dict[str, Any],
):
flat_cache_nnx = cache_nnx.flat_state()
flat_cache_nnx = dict(cache_nnx.flat_state())
flat_cache_linen = nnx.traversals.flatten_mapping(cache_linen, sep='/')

def copy_var(nnx_name: str, linen_name: str):
Expand Down
1 change: 1 addition & 0 deletions flax/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,4 @@
from .extract import to_tree as to_tree
from .extract import from_tree as from_tree
from .extract import NodeStates as NodeStates
from . import traversals as traversals
71 changes: 37 additions & 34 deletions flax/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
CallableProxy,
DelayedAccessor,
)
from flax.nnx.statelib import FlatState, State
from flax.nnx.statelib import State
from flax.nnx import variablelib
from flax.nnx.variablelib import Variable, VariableState
from flax.typing import Key, PathParts, is_key_like
Expand Down Expand Up @@ -110,15 +110,12 @@ class GraphNodeImpl(NodeImplBase[Node, Leaf, AuxData]):
pop_key: tp.Callable[[Node, Key], Leaf]
create_empty: tp.Callable[[AuxData], Node]
clear: tp.Callable[[Node], None]

def init(self, node: Node, items: tuple[tuple[Key, Leaf], ...]):
for key, value in items:
self.set_key(node, key, value)
init: tp.Callable[[Node, tp.Iterable[tuple[Key, Leaf]]], None]


@dataclasses.dataclass(frozen=True, slots=True)
class PytreeNodeImpl(NodeImplBase[Node, Leaf, AuxData]):
unflatten: tp.Callable[[tuple[tuple[Key, Leaf], ...], AuxData], Node]
unflatten: tp.Callable[[tp.Sequence[tuple[Key, Leaf]], AuxData], Node]


NodeImpl = tp.Union[
Expand All @@ -137,6 +134,7 @@ def register_graph_node_type(
pop_key: tp.Callable[[Node, Key], Leaf],
create_empty: tp.Callable[[AuxData], Node],
clear: tp.Callable[[Node], None],
init: tp.Callable[[Node, tp.Iterable[tuple[Key, Leaf]]], None],
):
if type in GRAPH_REGISTRY:
raise ValueError(f'Node type {type} is already registered.')
Expand All @@ -148,12 +146,13 @@ def register_graph_node_type(
pop_key=pop_key,
create_empty=create_empty,
clear=clear,
init=init,
)

def register_pytree_node_type(
type: type,
flatten: tp.Callable[[Node], tuple[tp.Sequence[tuple[Key, Leaf]], AuxData]],
unflatten: tp.Callable[[tuple[tuple[Key, Leaf], ...], AuxData], Node],
unflatten: tp.Callable[[tp.Sequence[tuple[Key, Leaf]], AuxData], Node],
):
if type in PYTREE_REGISTRY:
raise ValueError(f'Node type {type} is already registered.')
Expand Down Expand Up @@ -202,8 +201,8 @@ def get_node_impl_for_type(x: type[Node]) -> NodeImpl[Node, tp.Any, tp.Any]:


class HashableMapping(tp.Mapping[HA, HB], tp.Hashable):
def __init__(self, mapping: tp.Mapping[HA, HB] | tp.Iterable[tuple[HA, HB]]):
self._mapping = dict(mapping)
def __init__(self, mapping: tp.Mapping[HA, HB], copy: bool = True):
self._mapping = dict(mapping) if copy else mapping

def __contains__(self, key: object) -> bool:
return key in self._mapping
Expand Down Expand Up @@ -401,15 +400,15 @@ def flatten(
"""
if ref_index is None:
ref_index = RefMap()
flat_state: dict[PathParts, StateLeaf] = {}
flat_state: list[tuple[PathParts, StateLeaf]] = []
graphdef = _graph_flatten((), ref_index, flat_state, node)
return graphdef, GraphState.from_flat_path(flat_state)


def _graph_flatten(
path: PathParts,
ref_index: RefMap[tp.Any, Index],
flat_state: dict[PathParts, StateLeaf],
flat_state: list[tuple[PathParts, StateLeaf]],
node: Node,
) -> NodeDef[Node] | NodeRef:
if not is_node(node):
Expand Down Expand Up @@ -441,10 +440,10 @@ def _graph_flatten(
LeafAttribute(key, NodeRef(type(value), ref_index[value]))
)
else:
flat_state[(*path, key)] = value.to_state()
flat_state.append(((*path, key), value.to_state()))
variable_index = ref_index[value] = len(ref_index)
variabledef = VariableDef(
type(value), variable_index, HashableMapping(value.get_metadata())
type(value), variable_index, HashableMapping(value._var_metadata)
)
attributes.append(LeafAttribute(key, variabledef))
else:
Expand Down Expand Up @@ -528,7 +527,7 @@ def _graph_unflatten(
node_impl = get_node_impl_for_type(nodedef.type)

def _get_children():
children: dict[Key, NodeLeaf | Node] = {}
children: list[tuple[Key, NodeLeaf | Node]] = []
state_keys: set = set(state.keys())

# for every key in attributes there are 6 possible cases:
Expand All @@ -539,28 +538,29 @@ def _get_children():
if key not in state:
# if the key is not present, create an empty value for it
if type(attribute) is StaticAttribute:
children[key] = attribute.value
children.append((key, attribute.value))
elif type(attribute) is SubGraphAttribute:
# if the key is a subgraph we create an empty node
subgraphdef = attribute.value
assert not isinstance(subgraphdef, VariableDef)
if isinstance(subgraphdef, NodeRef):
# subgraph exists, take it from the cache
children[key] = index_ref[subgraphdef.index]
children.append((key, index_ref[subgraphdef.index]))
else:
# create a node from an empty state, reasoning:
# * it's a node with no state
# * it's a node with state but only through references to already
# created nodes
substate = {}
children[key] = _graph_unflatten(
subnode = _graph_unflatten(
subgraphdef, substate, index_ref, index_ref_cache
)
children.append((key, subnode))
elif type(attribute) is LeafAttribute:
variabledef = attribute.value
if variabledef.index in index_ref:
# variable exists, take it from the cache
children[key] = index_ref[variabledef.index]
children.append((key, index_ref[variabledef.index]))
else:
# key for a variable is missing, raise an error
raise ValueError(
Expand All @@ -587,19 +587,20 @@ def _get_children():
subgraphdef = attribute.value

if isinstance(subgraphdef, NodeRef):
children[key] = index_ref[subgraphdef.index]
children.append((key, index_ref[subgraphdef.index]))
else:
children[key] = _graph_unflatten(
subnode = _graph_unflatten(
subgraphdef, value, index_ref, index_ref_cache
)
children.append((key, subnode))

elif type(attribute) is LeafAttribute:
variabledef = attribute.value

if variabledef.index in index_ref:
# add an existing variable
assert isinstance(variabledef, NodeRef)
children[key] = index_ref[variabledef.index]
children.append((key, index_ref[variabledef.index]))
else:
# it's an unseen variable, create a new one
assert isinstance(variabledef, VariableDef)
Expand All @@ -626,7 +627,7 @@ def _get_children():
variable = variabledef.type.from_metadata(
value, variabledef.metadata
)
children[key] = variable
children.append((key, variable))
index_ref[variabledef.index] = variable
else:
raise RuntimeError(f'Unknown key: {key!r}, this is a bug.')
Expand All @@ -651,13 +652,11 @@ def _get_children():
else:
node = node_impl.create_empty(nodedef.metadata)
index_ref[nodedef.index] = node
children = _get_children()
node_impl.init(node, tuple(children.items()))
node_impl.init(node, _get_children())
else:
# if the node type does not support the creation of an empty object it means
# that it cannot reference itself, so we can create its children first
children = _get_children()
node = node_impl.unflatten(tuple(children.items()), nodedef.metadata)
node = node_impl.unflatten(_get_children(), nodedef.metadata)

return node

Expand All @@ -669,7 +668,9 @@ def graph_pop(
id_to_index: dict[int, Index] = {}
path_parts: PathParts = ()
predicates = tuple(filterlib.to_predicate(filter) for filter in filters)
flat_states: tuple[FlatState[StateLeaf], ...] = tuple({} for _ in predicates)
flat_states: tuple[dict[PathParts, StateLeaf], ...] = tuple(
{} for _ in predicates
)
_graph_pop(node, id_to_index, path_parts, flat_states, predicates)
return tuple(
GraphState.from_flat_path(flat_state) for flat_state in flat_states
Expand All @@ -680,7 +681,7 @@ def _graph_pop(
node: tp.Any,
id_to_index: dict[int, Index],
path_parts: PathParts,
flat_states: tuple[FlatState[StateLeaf], ...],
flat_states: tuple[dict[PathParts, StateLeaf], ...],
predicates: tuple[filterlib.Predicate, ...],
) -> None:
if not is_node(node):
Expand Down Expand Up @@ -816,7 +817,7 @@ def split(
if ctx.index_ref is not None and isinstance(graphdef, NodeDef):
index_to_index = compose_mapping(ctx.index_ref, self.ref_index)
graphdef = dataclasses.replace(
graphdef, index_mapping=HashableMapping(index_to_index)
graphdef, index_mapping=HashableMapping(index_to_index, copy=False)
)

return graphdef, *states
Expand Down Expand Up @@ -1006,7 +1007,7 @@ def split(
if self.index_ref is not None and isinstance(graphdef, NodeDef):
index_to_index = compose_mapping(self.index_ref, ref_index)
graphdef = dataclasses.replace(
graphdef, index_mapping=HashableMapping(index_to_index)
graphdef, index_mapping=HashableMapping(index_to_index, copy=False)
)

self.flatten_end(ref_index)
Expand Down Expand Up @@ -1570,7 +1571,9 @@ def pop(
id_to_index: dict[int, Index] = {}
path_parts: PathParts = ()
predicates = tuple(filterlib.to_predicate(filter) for filter in filters)
flat_states: tuple[FlatState[StateLeaf], ...] = tuple({} for _ in predicates)
flat_states: tuple[dict[PathParts, StateLeaf], ...] = tuple(
{} for _ in predicates
)
_graph_pop(
node=node,
id_to_index=id_to_index,
Expand Down Expand Up @@ -1786,8 +1789,8 @@ def is_pytree_node(x: tp.Any) -> bool:
# known non-pytree types
elif isinstance(x, Variable):
return False
# knon pytree types
elif isinstance(x, (VariableState, State)):
# known pytree types
elif type(x) is VariableState or type(x) is State:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not just a performance optimization, right? You no longer allow subclasses of VariableState or State.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct. These should be treated more like Rust enums.

return True
else:
return not jax.tree_util.all_leaves((x,))
Expand Down Expand Up @@ -1829,7 +1832,7 @@ def _unflatten_pytree(
PYTREE_NODE_IMPL = PytreeNodeImpl(
type=GenericPytree,
flatten=_flatten_pytree,
unflatten=_unflatten_pytree,
unflatten=_unflatten_pytree, # type: ignore
)

# common pytrees
Expand Down
21 changes: 11 additions & 10 deletions flax/nnx/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
)
from flax.nnx import graph
from flax.nnx.variablelib import Variable, VariableState
from flax.typing import Key
from flax import errors

G = tp.TypeVar('G', bound='Object')
Expand Down Expand Up @@ -109,10 +108,11 @@ def __init_subclass__(cls) -> None:
graph.register_graph_node_type(
type=cls,
flatten=cls._graph_node_flatten,
set_key=cls._graph_node_set_key,
pop_key=cls._graph_node_pop_key,
set_key=cls._graph_node_set_key, # type: ignore
pop_key=cls._graph_node_pop_key, # type: ignore
create_empty=cls._graph_node_create_empty,
clear=cls._graph_node_clear,
init=cls._graph_node_init, # type: ignore
)

if not tp.TYPE_CHECKING:
Expand Down Expand Up @@ -189,14 +189,12 @@ def __treescope_repr__(self, path, subtree_renderer):

# Graph Definition
def _graph_node_flatten(self):
nodes = sorted(
(key, value)
for key, value in vars(self).items()
if key != '_object__state'
)
nodes = vars(self).copy()
del nodes['_object__state']
nodes = sorted(nodes.items())
return nodes, (type(self), self._object__state._initializing)

def _graph_node_set_key(self, key: Key, value: tp.Any):
def _graph_node_set_key(self, key: str, value: tp.Any):
if not isinstance(key, str):
raise KeyError(f'Invalid key: {key!r}')
elif (
Expand All @@ -208,7 +206,7 @@ def _graph_node_set_key(self, key: Key, value: tp.Any):
else:
setattr(self, key, value)

def _graph_node_pop_key(self, key: Key):
def _graph_node_pop_key(self, key: str):
if not isinstance(key, str):
raise KeyError(f'Invalid key: {key!r}')
return vars(self).pop(key)
Expand All @@ -225,3 +223,6 @@ def _graph_node_clear(self):
module_vars = vars(self)
module_vars.clear()
module_vars['_object__state'] = module_state

def _graph_node_init(self, attributes: tp.Iterable[tuple[str, tp.Any]]):
  """Bulk-assigns ``(name, value)`` pairs as attributes of this object.

  The pairs are written straight into the instance ``__dict__`` with a single
  ``dict.update`` call, so ``__setattr__`` is not invoked for them.

  Args:
    attributes: iterable of ``(attribute_name, value)`` pairs to store.
  """
  instance_attrs = vars(self)
  instance_attrs.update(attributes)
Loading
Loading