diff --git a/docs_nnx/guides/filters_guide.ipynb b/docs_nnx/guides/filters_guide.ipynb index ed37ad8731..a4dfabea97 100644 --- a/docs_nnx/guides/filters_guide.ipynb +++ b/docs_nnx/guides/filters_guide.ipynb @@ -248,7 +248,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "id": "068208fc", "metadata": {}, "outputs": [ @@ -280,7 +280,7 @@ " predicates = [nnx.filterlib.to_predicate(f) for f in filters]\n", " flat_states: list[dict[KeyPath, Any]] = [{} for p in predicates]\n", "\n", - " for path, value in state.flat_state().items():\n", + " for path, value in state.flat_state():\n", " for i, predicate in enumerate(predicates):\n", " if predicate(path, value):\n", " flat_states[i][path] = value\n", diff --git a/docs_nnx/guides/filters_guide.md b/docs_nnx/guides/filters_guide.md index 97ff439ce2..dcd414d76a 100644 --- a/docs_nnx/guides/filters_guide.md +++ b/docs_nnx/guides/filters_guide.md @@ -145,7 +145,7 @@ def split(node, *filters): predicates = [nnx.filterlib.to_predicate(f) for f in filters] flat_states: list[dict[KeyPath, Any]] = [{} for p in predicates] - for path, value in state.flat_state().items(): + for path, value in state.flat_state(): for i, predicate in enumerate(predicates): if predicate(path, value): flat_states[i][path] = value diff --git a/examples/gemma/helpers.py b/examples/gemma/helpers.py index 1797f01271..b9c4195f4a 100644 --- a/examples/gemma/helpers.py +++ b/examples/gemma/helpers.py @@ -62,7 +62,7 @@ def map_key_fn(path: tuple[str, ...]) -> tuple[str | int, ...]: mdl: M = nnx.eval_shape(module_factory) graph_def, state = nnx.split(mdl) - state = state.flat_state() + state = dict(state.flat_state()) for path, val in flax.traverse_util.flatten_dict(variables).items(): mapped_path = map_key_fn(path) if mapped_path not in state: diff --git a/examples/lm1b_nnx/models_test.py b/examples/lm1b_nnx/models_test.py index dd1727c480..d2d0ce03d4 100644 --- a/examples/lm1b_nnx/models_test.py +++ b/examples/lm1b_nnx/models_test.py @@ -79,7 +79,7 @@ def transfer_params( params_linen: dict[str, Any], ): rules = dataclasses.asdict(config.axis_rules) - flat_params_nnx = params_nnx.flat_state() + flat_params_nnx = dict(params_nnx.flat_state()) flat_params_linen = nnx.traversals.flatten_mapping(params_linen, sep='/') def apply_rules(names: tuple[str, ...]): @@ -163,7 +163,7 @@ def transfer_cache( cache_nnx: nnx.State, cache_linen: dict[str, Any], ): - flat_cache_nnx = cache_nnx.flat_state() + flat_cache_nnx = dict(cache_nnx.flat_state()) flat_cache_linen = nnx.traversals.flatten_mapping(cache_linen, sep='/') def copy_var(nnx_name: str, linen_name: str): diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py index 6a27b090f5..fcb15f0608 100644 --- a/flax/nnx/__init__.py +++ b/flax/nnx/__init__.py @@ -167,3 +167,4 @@ from .extract import to_tree as to_tree from .extract import from_tree as from_tree from .extract import NodeStates as NodeStates +from . import traversals as traversals \ No newline at end of file diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py index 2339f5c168..be04b279c8 100644 --- a/flax/nnx/graph.py +++ b/flax/nnx/graph.py @@ -30,7 +30,7 @@ CallableProxy, DelayedAccessor, ) -from flax.nnx.statelib import FlatState, State +from flax.nnx.statelib import State from flax.nnx import variablelib from flax.nnx.variablelib import Variable, VariableState from flax.typing import Key, PathParts, is_key_like @@ -110,15 +110,16 @@ class GraphNodeImpl(NodeImplBase[Node, Leaf, AuxData]): pop_key: tp.Callable[[Node, Key], Leaf] create_empty: tp.Callable[[AuxData], Node] clear: tp.Callable[[Node], None] + init: tp.Callable[[Node, tp.Iterable[tuple[Key, Leaf]]], None] - def init(self, node: Node, items: tuple[tuple[Key, Leaf], ...]): - for key, value in items: - self.set_key(node, key, value) + # def init(self, node: Node, items: tp.Iterable[tuple[Key, Leaf]]): + # for key, value in items: + # self.set_key(node, key, value) @dataclasses.dataclass(frozen=True, slots=True) class PytreeNodeImpl(NodeImplBase[Node, Leaf, AuxData]): - unflatten: tp.Callable[[tuple[tuple[Key, Leaf], ...], AuxData], Node] + unflatten: tp.Callable[[tp.Sequence[tuple[Key, Leaf]], AuxData], Node] NodeImpl = tp.Union[ @@ -137,6 +138,7 @@ def register_graph_node_type( pop_key: tp.Callable[[Node, Key], Leaf], create_empty: tp.Callable[[AuxData], Node], clear: tp.Callable[[Node], None], + init: tp.Callable[[Node, tp.Iterable[tuple[Key, Leaf]]], None], ): if type in GRAPH_REGISTRY: raise ValueError(f'Node type {type} is already registered.') @@ -148,12 +150,13 @@ def register_graph_node_type( pop_key=pop_key, create_empty=create_empty, clear=clear, + init=init, ) def register_pytree_node_type( type: type, flatten: tp.Callable[[Node], tuple[tp.Sequence[tuple[Key, Leaf]], AuxData]], - unflatten: tp.Callable[[tuple[tuple[Key, Leaf], ...], AuxData], Node], + unflatten: tp.Callable[[tp.Sequence[tuple[Key, Leaf]], AuxData], Node], ): if type in PYTREE_REGISTRY: raise ValueError(f'Node type {type} is already registered.') @@ -202,8 +205,8 @@ def get_node_impl_for_type(x: type[Node]) -> NodeImpl[Node, tp.Any, tp.Any]: class HashableMapping(tp.Mapping[HA, HB], tp.Hashable): - def __init__(self, mapping: tp.Mapping[HA, HB] | tp.Iterable[tuple[HA, HB]]): - self._mapping = dict(mapping) + def __init__(self, mapping: tp.Mapping[HA, HB], no_copy: bool = False): + self._mapping = mapping if no_copy else dict(mapping) def __contains__(self, key: object) -> bool: return key in self._mapping @@ -401,7 +404,7 @@ def flatten( """ if ref_index is None: ref_index = RefMap() - flat_state: dict[PathParts, StateLeaf] = {} + flat_state: list[tuple[PathParts, StateLeaf]] = [] graphdef = _graph_flatten((), ref_index, flat_state, node) return graphdef, GraphState.from_flat_path(flat_state) @@ -409,7 +412,7 @@ def flatten( def _graph_flatten( path: PathParts, ref_index: RefMap[tp.Any, Index], - flat_state: dict[PathParts, StateLeaf], + flat_state: list[tuple[PathParts, StateLeaf]], node: Node, ) -> NodeDef[Node] | NodeRef: if not is_node(node): @@ -441,10 +444,10 @@ def _graph_flatten( LeafAttribute(key, NodeRef(type(value), ref_index[value])) ) else: - flat_state[(*path, key)] = value.to_state() + flat_state.append(((*path, key), value.to_state())) variable_index = ref_index[value] = len(ref_index) variabledef = VariableDef( - type(value), variable_index, HashableMapping(value.get_metadata()) + type(value), variable_index, HashableMapping(value._var_metadata) ) attributes.append(LeafAttribute(key, variabledef)) else: @@ -528,7 +531,7 @@ def _graph_unflatten( node_impl = get_node_impl_for_type(nodedef.type) def _get_children(): - children: dict[Key, NodeLeaf | Node] = {} + children: list[tuple[Key, NodeLeaf | Node]] = [] state_keys: set = set(state.keys()) # for every key in attributes there are 6 possible cases: @@ -539,28 +542,29 @@ def _get_children(): if key not in state: # if key is not present create an empty types if type(attribute) is StaticAttribute: - children[key] = attribute.value + children.append((key, attribute.value)) elif type(attribute) is SubGraphAttribute: # if the key is a subgraph we create an empty node subgraphdef = attribute.value assert not isinstance(subgraphdef, VariableDef) if isinstance(subgraphdef, NodeRef): # subgraph exists, take it from the cache - children[key] = index_ref[subgraphdef.index] + children.append((key, index_ref[subgraphdef.index])) else: # create a node from an empty state, reasoning: # * its a node with no state # * its a node with state but only through references of already # created nodes substate = {} - children[key] = _graph_unflatten( + subnode = _graph_unflatten( subgraphdef, substate, index_ref, index_ref_cache ) + children.append((key, subnode)) elif type(attribute) is LeafAttribute: variabledef = attribute.value if variabledef.index in index_ref: # variable exists, take it from the cache - children[key] = index_ref[variabledef.index] + children.append((key, index_ref[variabledef.index])) else: # key for a variable is missing, raise an error raise ValueError( @@ -587,11 +591,12 @@ def _get_children(): subgraphdef = attribute.value if isinstance(subgraphdef, NodeRef): - children[key] = index_ref[subgraphdef.index] + children.append((key, index_ref[subgraphdef.index])) else: - children[key] = _graph_unflatten( + subnode = _graph_unflatten( subgraphdef, value, index_ref, index_ref_cache ) + children.append((key, subnode)) elif type(attribute) is LeafAttribute: variabledef = attribute.value @@ -599,7 +604,7 @@ def _get_children(): if variabledef.index in index_ref: # add an existing variable assert isinstance(variabledef, NodeRef) - children[key] = index_ref[variabledef.index] + children.append((key, index_ref[variabledef.index])) else: # its a unseen variable, create a new one assert isinstance(variabledef, VariableDef) @@ -626,7 +631,7 @@ def _get_children(): variable = variabledef.type.from_metadata( value, variabledef.metadata ) - children[key] = variable + children.append((key, variable)) index_ref[variabledef.index] = variable else: raise RuntimeError(f'Unknown key: {key!r}, this is a bug.') @@ -651,13 +656,11 @@ def _get_children(): else: node = node_impl.create_empty(nodedef.metadata) index_ref[nodedef.index] = node - children = _get_children() - node_impl.init(node, tuple(children.items())) + node_impl.init(node, _get_children()) else: # if the node type does not support the creation of an empty object it means # that it cannot reference itself, so we can create its children first - children = _get_children() - node = node_impl.unflatten(tuple(children.items()), nodedef.metadata) + node = node_impl.unflatten(_get_children(), nodedef.metadata) return node @@ -669,7 +672,9 @@ def graph_pop( id_to_index: dict[int, Index] = {} path_parts: PathParts = () predicates = tuple(filterlib.to_predicate(filter) for filter in filters) - flat_states: tuple[FlatState[StateLeaf], ...] = tuple({} for _ in predicates) + flat_states: tuple[dict[PathParts, StateLeaf], ...] = tuple( + {} for _ in predicates + ) _graph_pop(node, id_to_index, path_parts, flat_states, predicates) return tuple( GraphState.from_flat_path(flat_state) for flat_state in flat_states @@ -680,7 +685,7 @@ def _graph_pop( node: tp.Any, id_to_index: dict[int, Index], path_parts: PathParts, - flat_states: tuple[FlatState[StateLeaf], ...], + flat_states: tuple[dict[PathParts, StateLeaf], ...], predicates: tuple[filterlib.Predicate, ...], ) -> None: if not is_node(node): @@ -816,7 +821,7 @@ def split( if ctx.index_ref is not None and isinstance(graphdef, NodeDef): index_to_index = compose_mapping(ctx.index_ref, self.ref_index) graphdef = dataclasses.replace( - graphdef, index_mapping=HashableMapping(index_to_index) + graphdef, index_mapping=HashableMapping(index_to_index, no_copy=True) ) return graphdef, *states @@ -1006,7 +1011,7 @@ def split( if self.index_ref is not None and isinstance(graphdef, NodeDef): index_to_index = compose_mapping(self.index_ref, ref_index) graphdef = dataclasses.replace( - graphdef, index_mapping=HashableMapping(index_to_index) + graphdef, index_mapping=HashableMapping(index_to_index, no_copy=True) ) self.flatten_end(ref_index) @@ -1570,7 +1575,9 @@ def pop( id_to_index: dict[int, Index] = {} path_parts: PathParts = () predicates = tuple(filterlib.to_predicate(filter) for filter in filters) - flat_states: tuple[FlatState[StateLeaf], ...] = tuple({} for _ in predicates) + flat_states: tuple[dict[PathParts, StateLeaf], ...] = tuple( + {} for _ in predicates + ) _graph_pop( node=node, id_to_index=id_to_index, @@ -1787,7 +1794,7 @@ def is_pytree_node(x: tp.Any) -> bool: elif isinstance(x, Variable): return False # knon pytree types - elif isinstance(x, (VariableState, State)): + elif type(x) is VariableState or type(x) is State: return True else: return not jax.tree_util.all_leaves((x,)) @@ -1829,7 +1836,7 @@ def _unflatten_pytree( PYTREE_NODE_IMPL = PytreeNodeImpl( type=GenericPytree, flatten=_flatten_pytree, - unflatten=_unflatten_pytree, + unflatten=_unflatten_pytree, # type: ignore ) # common pytrees diff --git a/flax/nnx/object.py b/flax/nnx/object.py index c63506fc48..afa41cdb7b 100644 --- a/flax/nnx/object.py +++ b/flax/nnx/object.py @@ -30,7 +30,6 @@ ) from flax.nnx import graph from flax.nnx.variablelib import Variable, VariableState -from flax.typing import Key from flax import errors G = tp.TypeVar('G', bound='Object') @@ -109,10 +108,11 @@ def __init_subclass__(cls) -> None: graph.register_graph_node_type( type=cls, flatten=cls._graph_node_flatten, - set_key=cls._graph_node_set_key, - pop_key=cls._graph_node_pop_key, + set_key=cls._graph_node_set_key, # type: ignore + pop_key=cls._graph_node_pop_key, # type: ignore create_empty=cls._graph_node_create_empty, clear=cls._graph_node_clear, + init=cls._graph_node_init, # type: ignore ) if not tp.TYPE_CHECKING: @@ -189,14 +189,12 @@ def __treescope_repr__(self, path, subtree_renderer): # Graph Definition def _graph_node_flatten(self): - nodes = sorted( - (key, value) - for key, value in vars(self).items() - if key != '_object__state' - ) + nodes = vars(self).copy() + del nodes['_object__state'] + nodes = sorted(nodes.items()) return nodes, (type(self), self._object__state._initializing) - def _graph_node_set_key(self, key: Key, value: tp.Any): + def _graph_node_set_key(self, key: str, value: tp.Any): if not isinstance(key, str): raise KeyError(f'Invalid key: {key!r}') elif ( @@ -208,7 +206,7 @@ def _graph_node_set_key(self, key: Key, value: tp.Any): else: setattr(self, key, value) - def _graph_node_pop_key(self, key: Key): + def _graph_node_pop_key(self, key: str): if not isinstance(key, str): raise KeyError(f'Invalid key: {key!r}') return vars(self).pop(key) @@ -225,3 +223,6 @@ def _graph_node_clear(self): module_vars = vars(self) module_vars.clear() module_vars['_object__state'] = module_state + + def _graph_node_init(self, attributes: tp.Iterable[tuple[str, tp.Any]]): + vars(self).update(attributes) \ No newline at end of file diff --git a/flax/nnx/statelib.py b/flax/nnx/statelib.py index df299ea54d..ba58b4cd0b 100644 --- a/flax/nnx/statelib.py +++ b/flax/nnx/statelib.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# pytype: skip-file from __future__ import annotations from collections.abc import MutableMapping @@ -28,7 +29,6 @@ K = tp.TypeVar('K', bound=tp.Hashable) V = tp.TypeVar('V') -FlatState = dict[PathParts, V] ExtractValueFn = tp.Callable[[tp.Any], tp.Any] SetValueFn = tp.Callable[[V, tp.Any], V] @@ -54,6 +54,55 @@ def __treescope_repr__(self, path, subtree_renderer): # Render as the dictionary itself at the same path. return subtree_renderer(children, path=path) +class FlatState(tp.Sequence[tuple[PathParts, V]], reprlib.PrettySequence): + _keys: tuple[PathParts, ...] + _values: list[V] + + def __init__(self, items: tp.Iterable[tuple[PathParts, V]]): + keys, values = [], [] + for key, value in items: + keys.append(key) + values.append(value) + self._keys = tuple(keys) + self._values = values + + @tp.overload + def __getitem__(self, index: int) -> tuple[PathParts, V]: ... + @tp.overload + def __getitem__(self, index: slice) -> FlatState[V]: ... + def __getitem__( + self, index: int | slice + ) -> tuple[PathParts, V] | FlatState[V]: + if isinstance(index, int): + return self._keys[index], self._values[index] + return FlatState(zip(self._keys[index], self._values[index])) + + def __len__(self) -> int: + return len(self._keys) + + def __iter__(self) -> tp.Iterator[tuple[PathParts, V]]: + return iter(zip(self._keys, self._values)) + + +def _flat_state_pytree_flatten(x: FlatState[V]): + return x._values, x._keys + + +def _flat_state_pytree_unflatten( + keys: tuple[PathParts, ...], values: list[V] +) -> FlatState[V]: + flat_state = object.__new__(FlatState) + flat_state._keys = keys + flat_state._values = values + return flat_state + + +jax.tree_util.register_pytree_node( + FlatState, + _flat_state_pytree_flatten, + _flat_state_pytree_unflatten, +) + class State(MutableMapping[K, V], reprlib.Representable): """A pytree-like structure that contains a ``Mapping`` from hashable and @@ -148,12 +197,14 @@ def __treescope_repr__(self, path, subtree_renderer): def map(self, f: tp.Callable[[tuple, V], V]) -> State[K, V]: flat_state = self.flat_state() - for path, variable_state in flat_state.items(): - flat_state[path] = f(path, variable_state) - return State.from_flat_path(flat_state) + result = [] + for path, variable_state in flat_state: + variable_state = f(path, variable_state) + result.append((path, variable_state)) + return State.from_flat_path(result) def flat_state(self) -> FlatState[V]: - return traversals.flatten_mapping(self._mapping) + return FlatState(traversals.flatten_to_sequence(self._mapping)) @classmethod def from_flat_path( @@ -172,7 +223,7 @@ def to_pure_dict(self, # Works for nnx.Variable and nnx.VariableState if extract_fn is None: extract_fn = lambda x: x.value if hasattr(x, 'value') else x - flat_values = {k: extract_fn(x) for k, x in self.flat_state().items()} + flat_values = {k: extract_fn(x) for k, x in self.flat_state()} return traversals.unflatten_mapping(flat_values) def replace_by_pure_dict(self, @@ -186,7 +237,7 @@ def try_convert_int(x): # Works for nnx.Variable and nnx.VariableState if replace_fn is None: replace_fn = lambda x, v: x.replace(v) if hasattr(x, 'replace') else v - current_flat = self.flat_state() + current_flat = dict(self.flat_state()) for kp, v in traversals.flatten_mapping(pure_dict).items(): kp = tuple(map(try_convert_int, kp)) if kp not in current_flat: @@ -241,7 +292,7 @@ def split( # type: ignore[misc] One or more ``States`` equal to the number of filters passed. """ filters = (first, *filters) - *states_, rest = _split_state(self, *filters) + *states_, rest = _split_state(self.flat_state(), *filters) if rest: raise ValueError( @@ -254,7 +305,7 @@ def split( # type: ignore[misc] states = states_[0] else: states = tuple(states_) - return states # type: ignore[bad-return-type] + return states # type: ignore @tp.overload def filter( @@ -306,7 +357,7 @@ def filter( Returns: One or more ``States`` equal to the number of filters passed. """ - *states_, _rest = _split_state(self, first, *filters) + *states_, _rest = _split_state(self.flat_state(), first, *filters) assert len(states_) == len(filters) + 1 @@ -316,7 +367,7 @@ def filter( else: states = tuple(states_) - return states # type: ignore[bad-return-type] + return states # type: ignore @staticmethod def merge( @@ -360,7 +411,7 @@ def merge( states = (state, *states) - new_state: FlatState[V] = {} + new_state: dict[PathParts, V] = {} for state in states: new_state.update(traversals.flatten_mapping(state)) # type: ignore[attribute-error] # pytype is wrong here @@ -376,8 +427,8 @@ def __sub__(self, other: State[K, V]) -> State[K, V]: if not other: return self - self_flat = self.flat_state() - other_flat = other.flat_state() + self_flat = dict(self.flat_state()) + other_flat = dict(other.flat_state()) diff = {k: v for k, v in self_flat.items() if k not in other_flat} return State.from_flat_path(diff) @@ -404,9 +455,9 @@ def _state_unflatten( def _split_state( - state: State[K, V], + flat_state: FlatState[V], *filters: filterlib.Filter, -) -> tuple[State[K, V], ...]: +) -> tuple[State[PathParts, V], ...]: for i, filter_ in enumerate(filters): if filter_ in (..., True) and i != len(filters) - 1: remaining_filters = filters[i + 1 :] @@ -417,22 +468,20 @@ def _split_state( ) predicates = tuple(map(filterlib.to_predicate, filters)) - flat_state = state.flat_state() - # we have n + 1 states, where n is the number of predicates # the last state is for values that don't match any predicate - flat_states: tuple[FlatState[V], ...] = tuple( - {} for _ in range(len(predicates) + 1) + flat_states: tuple[list[tuple[PathParts, V]], ...] = tuple( + [] for _ in range(len(predicates) + 1) ) - for path, value in flat_state.items(): + for path, value in flat_state: for i, predicate in enumerate(predicates): if predicate(path, value): - flat_states[i][path] = value # type: ignore[index] # mypy is wrong here? + flat_states[i].append((path, value)) # type: ignore[index] # mypy is wrong here? break else: # if we didn't break, set leaf to last state - flat_states[-1][path] = value # type: ignore[index] # mypy is wrong here? + flat_states[-1].append((path, value)) # type: ignore[index] # mypy is wrong here? return tuple(State.from_flat_path(flat_state) for flat_state in flat_states) @@ -440,7 +489,7 @@ def _split_state( def create_path_filters(state: State): flat_state = state.flat_state() value_paths: dict[tp.Any, set[PathParts]] = {} - for path, value in flat_state.items(): + for path, value in flat_state: if isinstance(value, (variablelib.Variable, variablelib.VariableState)): value = value.value value_paths.setdefault(value, set()).add(path) diff --git a/flax/nnx/traversals.py b/flax/nnx/traversals.py index 4d9c80603c..8c8996df9d 100644 --- a/flax/nnx/traversals.py +++ b/flax/nnx/traversals.py @@ -18,6 +18,7 @@ from collections.abc import Callable, Mapping from typing import Any, overload +from collections.abc import Iterable from flax import struct @@ -118,6 +119,55 @@ def _flatten(xs: Any, prefix: tuple[Any, ...]) -> dict[Any, Any]: return _flatten(xs, ()) +def flatten_to_sequence( + xs: Mapping[Any, Any], + /, + *, + is_leaf: None | IsLeafCallable = None, +) -> list[tuple[Any, Any]]: + """Flatten a nested mapping. + + The nested keys are flattened to a tuple. See ``unflatten_mapping`` on how to + restore the nested mapping. + + Example:: + + >>> from flax import nnx + >>> xs = {'foo': 1, 'bar': {'a': 2, 'b': {}}} + >>> flat_xs = nnx.traversals.flatten_mapping(xs) + >>> flat_xs + {('foo',): 1, ('bar', 'a'): 2} + + Note that empty mappings are ignored and will not be restored by + ``unflatten_mapping``. + + Args: + xs: a nested mapping + keep_empty_nodes: replaces empty mappings with + ``traverse_util.empty_node``. + is_leaf: an optional function that takes the next nested mapping and nested + keys and returns True if the nested mapping is a leaf (i.e., should not be + flattened further). + sep: if specified, then the keys of the returned mapping will be + ``sep``-joined strings (if ``None``, then keys will be tuples). + Returns: + The flattened mapping. + """ + assert isinstance( + xs, Mapping + ), f'expected Mapping; got {type(xs).__qualname__}' + result = [] + + def _flatten(xs: Any, prefix: tuple[Any, ...]): + if not isinstance(xs, Mapping) or (is_leaf and is_leaf(prefix, xs)): + result.append((prefix, xs)) + else: + for key, value in xs.items(): + _flatten(value, (*prefix, key)) + + _flatten(xs, ()) + return result + @overload def unflatten_mapping(xs: Mapping[tuple[Any, ...], Any], @@ -163,9 +213,15 @@ def unflatten_mapping(xs: Any, Returns: The nested mapping. """ - assert isinstance(xs, Mapping), f'expected Mapping; got {type(xs).__qualname__}' + if isinstance(xs, Mapping): + xs = xs.items() + + if not isinstance(xs, Iterable): + raise TypeError( + f'expected Mapping or Iterable; got {type(xs).__qualname__}' + ) result: dict[Any, Any] = {} - for path, value in xs.items(): + for path, value in xs: if sep is not None: path = path.split(sep) if value is empty_node: diff --git a/flax/nnx/variablelib.py b/flax/nnx/variablelib.py index 7af20cdb73..4752a9b7bd 100644 --- a/flax/nnx/variablelib.py +++ b/flax/nnx/variablelib.py @@ -216,7 +216,7 @@ def copy_from(self, other: Variable[A]) -> None: def update_from_state(self, variable_state: VariableState[A]): vars_self = vars(self) vars_self['raw_value'] = variable_state.value - vars_self['_var_metadata'] = variable_state.get_metadata().copy() + vars_self['_var_metadata'] = variable_state._var_metadata.copy() @property def value(self) -> A: @@ -308,8 +308,7 @@ def copy(self: Variable[A]) -> Variable[A]: return obj def to_state(self: Variable[A]) -> VariableState[A]: - metadata = self.get_metadata() - return VariableState(type(self), self.raw_value, **metadata) + return VariableState(type(self), self.raw_value, **self._var_metadata) def __nnx_repr__(self): yield reprlib.Object(type=type(self))