Skip to content

Commit

Permalink
Update NNX unflatten docs in graph.py
Browse files Browse the repository at this point in the history
  • Loading branch information
8bitmp3 committed Dec 15, 2024
1 parent 6bc9858 commit 3c0aa12
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions flax/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,20 +474,20 @@ def unflatten(
index_ref: dict[Index, tp.Any] | None = None,
index_ref_cache: dict[Index, tp.Any] | None = None,
) -> Node:
"""Unflattens a graphdef into a node with the given state.
"""Unflattens a :class:`flax.nnx.GraphDef` into a node with the given state.
Args:
graphdef: A GraphDef instance.
state: A State instance.
graphdef: A :class:`flax.nnx.GraphDef` instance.
state: A :class:`flax.nnx.State` instance.
index_ref: A mapping from indexes to nodes references found during the graph
traversal, defaults to None. If not provided, a new empty dictionary is
created. This argument can be used to unflatten a sequence of (graphdef, state)
traversal. Defaults to ``None``. If it's not provided, a new empty dictionary is
created. This argument can be used to unflatten a sequence of ``(graphdef, state)``
pairs that share the same index space.
index_ref_cache: A mapping from indexes to existing nodes that can be reused.
When an reference is reused, ``GraphNodeImpl.clear`` is called to leave the
object in an empty state and then filled by the unflatten process, as a result
When an reference is reused, :func:`flax.nnx.GraphNodeImpl.clear` is called to leave the
object in an empty state and then filled by the ``nnx.unflatten`` process. As a result,
existing graph nodes are mutated to have the new content/topology
specified by the graphdef.
specified by the ``graphdef`` (an ``nnx.GraphDef`` instance).
"""
if isinstance(state, State):
state = state.raw_mapping # type: ignore
Expand Down Expand Up @@ -1852,4 +1852,4 @@ def _unflatten_pytree(
type(None),
flatten=lambda x: ([], None),
unflatten=lambda _, __: None, # type: ignore
)
)

0 comments on commit 3c0aa12

Please sign in to comment.