diff --git a/flax/nnx/module.py b/flax/nnx/module.py index 795bb9a088..826c4b1b81 100644 --- a/flax/nnx/module.py +++ b/flax/nnx/module.py @@ -322,12 +322,13 @@ def set_attributes( ) def train(self, **attributes): - """Sets the Module to training mode. + """Sets the :class:`flax.nnx.Module` to training mode. - ``train`` uses ``set_attributes`` to recursively set attributes ``deterministic=False`` - and ``use_running_average=False`` of all nested Modules that have these attributes. - Its primarily used to control the runtime behavior of the ``Dropout`` and ``BatchNorm`` - Modules. + ``nnx.Module.train`` uses :func:`flax.nnx.Module.set_attributes`` to recursively set + attributes ``deterministic=False`` and ``use_running_average=False`` of all nested + ``nnx.Module``'s that have these attributes. It is primarily used to control the + runtime behavior of the :class:`flax.nnx.Dropout` and :class:`flax.nnx.BatchNorm` + ``nnx.Module``'s. Example:: @@ -348,7 +349,7 @@ def train(self, **attributes): (False, False) Args: - **attributes: additional attributes passed to ``set_attributes``. + **attributes: Additional attributes passed to ``set_attributes``. """ return self.set_attributes( deterministic=False,