diff --git a/flax/nnx/statelib.py b/flax/nnx/statelib.py index 42a2604042..0a968e860a 100644 --- a/flax/nnx/statelib.py +++ b/flax/nnx/statelib.py @@ -372,10 +372,10 @@ def filter( def merge( state: tp.Mapping[K, V], /, *states: tp.Mapping[K, V] ) -> State[K, V]: - """The inverse of :meth:`split() `. + """The inverse of :func:`flax.nnx.State.state.split`. - ``merge`` takes one or more ``State``'s and creates - a new ``State``. + ``nnx.State.state.merge`` takes one or more :class:`flax.nnx.State`'s + and creates a new ``nnx.State``. Example usage:: @@ -398,10 +398,10 @@ def merge( >>> assert (model.linear.bias.value == jnp.array([1, 1, 1])).all() Args: - state: A ``State`` object. - *states: Additional ``State`` objects. + state: A :class:`flax.nnx.State` object. + *states: Additional ``nnx.State`` objects. Returns: - The merged ``State``. + The merged ``nnx.State``. """ if not states: if isinstance(state, State): @@ -492,4 +492,4 @@ def create_path_filters(state: State): if isinstance(value, (variablelib.Variable, variablelib.VariableState)): value = value.value value_paths.setdefault(value, set()).add(path) - return {filterlib.PathIn(*value_paths[value]): value for value in value_paths} \ No newline at end of file + return {filterlib.PathIn(*value_paths[value]): value for value in value_paths}