Skip to content

Commit

Permalink
Update NNX flatten docs in graph.py
Browse files Browse the repository at this point in the history
  • Loading branch information
8bitmp3 committed Dec 15, 2024
1 parent 5d896bc commit 5790703
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions flax/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def is_node_leaf(x: tp.Any) -> tpe.TypeGuard[NodeLeaf]:


class RefMap(tp.MutableMapping[A, B], reprlib.MappingReprMixin[A, B]):
"""A mapping that uses object id as the hash for the keys."""
"""A mapping that uses object ``id`` as the hash for the keys."""

def __init__(
self, mapping: tp.Mapping[A, B] | tp.Iterable[tuple[A, B]] = (), /
Expand Down Expand Up @@ -391,11 +391,12 @@ def _apply(
def flatten(
node: Node, /, ref_index: RefMap[tp.Any, Index] | None = None
) -> tuple[GraphDef[Node], GraphState]:
"""Flattens a graph node into a (graphdef, state) pair.
"""Flattens a graph node into a ``(GraphDef, State)`` pair.
(:class:`flax.nnx.GraphDef`, :class:`flax.nnx.State`)
Args:
x: A graph node.
ref_index: A mapping from nodes to indexes, defaults to None. If not provided, a new
ref_index: A mapping from nodes to indexes. Defaults to ``None``. If it's not provided, a new
empty dictionary is created. This argument can be used to flatten a sequence of graph
nodes that share references.
"""
Expand Down Expand Up @@ -1856,4 +1857,4 @@ def _unflatten_pytree(
type(None),
flatten=lambda x: ([], None),
unflatten=lambda _, __: None, # type: ignore
)
)

0 comments on commit 5790703

Please sign in to comment.