diff --git a/flax/nnx/statelib.py b/flax/nnx/statelib.py index 42a2604042..fcded54936 100644 --- a/flax/nnx/statelib.py +++ b/flax/nnx/statelib.py @@ -264,10 +264,10 @@ def split( def split( # type: ignore[misc] self, first: filterlib.Filter, /, *filters: filterlib.Filter ) -> tp.Union[State[K, V], tuple[State[K, V], ...]]: - """Split a ``State`` into one or more ``State``'s. The - user must pass at least one ``Filter`` (i.e. :class:`Variable`), - and the filters must be exhaustive (i.e. they must cover all - :class:`Variable` types in the ``State``). + """Splits a :class:`flax.nnx.State` into one or more ``nnx.State``'s. + You must pass at least one NNX ``Filter`` (``flax.nnx.filterlib``) + (i.e. :class:`flax.nnx.Variable`), and the ``Filter``'s must be exhaustive + (i.e. they must cover all ``nnx.Variable`` types in the ``nnx.State``). Example usage:: @@ -285,10 +285,11 @@ def split( # type: ignore[misc] >>> param, batch_stats = state.split(nnx.Param, nnx.BatchStat) Arguments: - first: The first filter - *filters: The optional, additional filters to group the state into mutually exclusive substates. + first: The first NNX ``Filter``. + *filters: The optional, additional NNX ``Filter``'s to group the + :class:`flax.nnx.State` into mutually exclusive sub-``State``'s. Returns: - One or more ``States`` equal to the number of filters passed. + One or more ``nnx.State``'s equal to the number of NNX ``Filter``'s passed. """ filters = (first, *filters) *states_, rest = _split_state(self.flat_state(), *filters) @@ -492,4 +493,4 @@ def create_path_filters(state: State): if isinstance(value, (variablelib.Variable, variablelib.VariableState)): value = value.value value_paths.setdefault(value, set()).add(path) - return {filterlib.PathIn(*value_paths[value]): value for value in value_paths} \ No newline at end of file + return {filterlib.PathIn(*value_paths[value]): value for value in value_paths}