diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py index 2339f5c168..2e6baa2268 100644 --- a/flax/nnx/graph.py +++ b/flax/nnx/graph.py @@ -64,7 +64,7 @@ def is_node_leaf(x: tp.Any) -> tpe.TypeGuard[NodeLeaf]: class RefMap(tp.MutableMapping[A, B], reprlib.MappingReprMixin[A, B]): - """A mapping that uses object id as the hash for the keys.""" + """A mapping that uses object ``id`` as the hash for the keys.""" def __init__( self, mapping: tp.Mapping[A, B] | tp.Iterable[tuple[A, B]] = (), / @@ -391,11 +391,12 @@ def _apply( def flatten( node: Node, /, ref_index: RefMap[tp.Any, Index] | None = None ) -> tuple[GraphDef[Node], GraphState]: - """Flattens a graph node into a (graphdef, state) pair. + """Flattens a graph node into a ``(GraphDef, State)`` pair. + (:class:`flax.nnx.GraphDef`, :class:`flax.nnx.State`) Args: x: A graph node. - ref_index: A mapping from nodes to indexes, defaults to None. If not provided, a new + ref_index: A mapping from nodes to indexes. Defaults to ``None``. If it's not provided, a new empty dictionary is created. This argument can be used to flatten a sequence of graph nodes that share references. """ @@ -1856,4 +1857,4 @@ def _unflatten_pytree( type(None), flatten=lambda x: ([], None), unflatten=lambda _, __: None, # type: ignore -) \ No newline at end of file +)