Skip to content

Commit

Permalink
Update NNX Module train docs in module.py
Browse files Browse the repository at this point in the history
  • Loading branch information
8bitmp3 committed Dec 17, 2024
1 parent 6bc9858 commit d6a87ce
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions flax/nnx/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,12 +322,13 @@ def set_attributes(
)

def train(self, **attributes):
"""Sets the Module to training mode.
"""Sets the :class:`flax.nnx.Module` to training mode.
``train`` uses ``set_attributes`` to recursively set attributes ``deterministic=False``
and ``use_running_average=False`` of all nested Modules that have these attributes.
Its primarily used to control the runtime behavior of the ``Dropout`` and ``BatchNorm``
Modules.
``nnx.Module.train`` uses :func:`flax.nnx.Module.set_attributes`` to recursively set
attributes ``deterministic=False`` and ``use_running_average=False`` of all nested
``nnx.Module``'s that have these attributes. It is primarily used to control the
runtime behavior of the :class:`flax.nnx.Dropout` and :class:`flax.nnx.BatchNorm`
``nnx.Module``'s.
Example::
Expand All @@ -348,7 +349,7 @@ def train(self, **attributes):
(False, False)
Args:
**attributes: additional attributes passed to ``set_attributes``.
**attributes: Additional attributes passed to ``set_attributes``.
"""
return self.set_attributes(
deterministic=False,
Expand Down

0 comments on commit d6a87ce

Please sign in to comment.