Skip to content

Commit

Permalink
Update NNX merge docs in graph.py
Browse files Browse the repository at this point in the history
  • Loading branch information
8bitmp3 authored Nov 29, 2024
1 parent 6bc9858 commit 92e8d42
Showing 1 changed file with 20 additions and 12 deletions.
32 changes: 20 additions & 12 deletions flax/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1297,10 +1297,18 @@ def merge(
/,
*states: tp.Mapping[KeyT, tp.Any],
) -> A:
"""The inverse of :func:`split`.
``merge`` takes a :class:`GraphDef` and one or more :class:`State`'s and creates
a new node with the same structure as the original node.
"""The inverse of :func:`flax.nnx.split`.
``nnx.merge`` takes a :class:`flax.nnx.GraphDef` and one or more :class:`flax.nnx.State`'s
and creates a new node with the same structure as the original node.
Recall: :func:`flax.nnx.split` is used to represent a :class:`flax.nnx.Module`
by: 1) a static ``nnx.GraphDef`` that captures its Pythonic static information;
and 2) one or more :class:`flax.nnx.Variable` ``nnx.State``'(s) that capture
its ``jax.Array``'s in the form of JAX pytrees.
``nnx.merge`` is used in conjunction with ``nnx.split`` to switch seamlessly
between stateful and stateless representations of the graph.
Example usage::
Expand All @@ -1320,17 +1328,17 @@ def merge(
>>> assert isinstance(new_node.batch_norm, nnx.BatchNorm)
>>> assert isinstance(new_node.linear, nnx.Linear)
:func:`split` and :func:`merge` are primarily used to interact directly with JAX
transformations, see
`Functional API <https://flax.readthedocs.io/en/latest/nnx/nnx_basics.html#the-functional-api>`__
``nnx.split`` and ``nnx.merge`` are primarily used to interact directly with JAX
transformations (refer to
`Functional API <https://flax.readthedocs.io/en/latest/nnx_basics.html#the-flax-functional-api>`__
for more information.
Args:
graphdef: A :class:`GraphDef` object.
state: A :class:`State` object.
*states: Additional :class:`State` objects.
graphdef: A :class:`flax.nnx.GraphDef` object.
state: A :class:`flax.nnx.State` object.
*states: Additional :class:`flax.nnx.State` objects.
Returns:
The merged :class:`Module`.
The merged :class:`flax.nnx.Module`.
"""
state = State.merge(state, *states)
node = unflatten(graphdef, state)
Expand Down Expand Up @@ -1852,4 +1860,4 @@ def _unflatten_pytree(
type(None),
flatten=lambda x: ([], None),
unflatten=lambda _, __: None, # type: ignore
)
)

0 comments on commit 92e8d42

Please sign in to comment.