Skip to content

Commit

Permalink
Add flax.nnx.value_and_grad docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
8bitmp3 committed Nov 29, 2024
1 parent d31f290 commit 7e167af
Showing 1 changed file with 57 additions and 0 deletions.
57 changes: 57 additions & 0 deletions flax/nnx/transforms/autodiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,63 @@ def value_and_grad(
tp.Callable[..., tp.Any]
| tp.Callable[[tp.Callable[..., tp.Any]], tp.Callable[..., tp.Any]]
):
"""A reference-aware version of `jax.value_and_grad <https://jax.readthedocs.io/en/latest/_autosummary/jax.value_and_grad.html#jax.value_and_grad>`_
that can handle `flax.nnx.Modules <https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module>`_
/ graph nodes as arguments.
``value_and_grad`` creates a function (``f``) that evaluates both ``f`` and the gradient of ``f``.
Learn more in `Flax NNX vs JAX Transformations <https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html>`_.
Example::
>>> from flax import nnx
>>> import jax
>>> import jax.numpy as jnp
...
>>> m = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> x = jnp.ones((1, 2))
>>> y = jnp.ones((1, 3))
...
>>> loss_fn = lambda m, x, y: jnp.mean((m(x) - y) ** 2)
>>> value_and_grad_fn = nnx.value_and_grad(loss_fn)
...
>>> values, grads = value_and_grad_fn(m, x, y)
...
>>> print(f"{jax.tree.map(jnp.shape, grads)}\n{values}")
State({
'bias': VariableState(
type=Param,
value=(3,)
),
'kernel': VariableState(
type=Param,
value=(2, 3)
)
})
1.648836612701416
Args:
f: A function to be differentiated. Its arguments at positions specified by
``argnums`` should be arrays, scalars, or standard Python containers. It
should return a scalar (which includes arrays with shape ``()`` but not
arrays with shape ``(1,)`` etc.)
argnums: Optional, integer or sequence of integers. Specifies which
positional argument(s) to differentiate with respect to (default 0).
has_aux: Optional, bool. Indicates whether ``f`` returns a pair where the
first element is considered the output of the mathematical function to be
differentiated and the second element is auxiliary data. Default ``False``.
holomorphic: Optional, bool. Indicates whether ``f`` is promised to be
holomorphic. If ``True``, inputs and outputs must be complex. Default ``False``.
allow_int: Optional, bool. Whether to allow differentiating with
respect to integer valued inputs. The gradient of an integer input will
have a trivial vector-space dtype (``float0``). Default ``False``.
Returns:
A function with the same arguments as ``f`` that evaluates both ``f``
and the gradient of ``f`` and returns them as a pair (a two-element
tuple).
"""
if f is Missing:
return functools.partial(
value_and_grad,
Expand Down

0 comments on commit 7e167af

Please sign in to comment.