how to implement the vjp for the basic primitive #17529
Hi, experts! I am learning to implement auto-diff for basic primitives, for example,
Replies: 2 comments 2 replies
@jakevdp could you help with this?
Here is where the autodiff rules are defined for the `real_p` and `imag_p` primitives in JAX: https://github.com/google/jax/blob/292deef6fda9f639fcecd9883a1112825a1eb54f/jax/_src/lax/lax.py#L1890-L1896

They use `ad.deflinear2`, which is a utility for registering JVP and transpose rules for linear primitives: https://github.com/google/jax/blob/292deef6fda9f639fcecd9883a1112825a1eb54f/jax/_src/interpreters/ad.py#L517-L519

For more general information regarding autodiff in JAX, there are some good resources in the docs.

And if you're not working with primitives, but rather want to define custom JVP / VJP rules for a function, there is some information here: https://jax.readthedocs.io/en/latest/_autosummary/jax.custom_jvp.html

Hope that helps!
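To make the two points in the answer concrete, here is a minimal sketch that uses only public APIs (`jax.jvp`, `jax.vjp`, `jax.custom_jvp`) rather than the internal `ad.deflinear2` helper. The first part checks that a linear operation such as `jnp.real` has a JVP equal to the operation applied to the tangent, and that its VJP follows by transposition. The second part attaches a custom JVP rule to an ordinary function. The names `softplus`, `softplus_jvp`, `z`, and `t` are illustrative assumptions and do not come from the thread.

```python
import jax
import jax.numpy as jnp

# Part 1: jnp.real is linear in its input, so its JVP is simply the same
# operation applied to the tangent vector.
z = jnp.array([1.0 + 2.0j, 3.0 - 1.0j], dtype=jnp.complex64)
t = jnp.array([0.5 + 0.25j, -1.0 + 4.0j], dtype=jnp.complex64)

primal_out, tangent_out = jax.jvp(jnp.real, (z,), (t,))

# The VJP is obtained by transposing that linear map: a real cotangent
# is sent back to a complex cotangent with zero imaginary part.
out, real_vjp = jax.vjp(jnp.real, z)
(z_bar,) = real_vjp(jnp.ones_like(out))


# Part 2: a custom JVP rule for a plain function (not a primitive).
# softplus is a hypothetical example chosen for illustration.
@jax.custom_jvp
def softplus(x):
    return jnp.log1p(jnp.exp(x))


@softplus.defjvp
def softplus_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    # d/dx log(1 + exp(x)) = sigmoid(x); the rule is linear in x_dot,
    # which lets JAX derive reverse mode from it automatically.
    return softplus(x), jax.nn.sigmoid(x) * x_dot


# Reverse-mode differentiation works even though only a JVP was defined.
g = jax.grad(softplus)(0.0)
```

Because the JVP rule is linear in the tangent, JAX can transpose it to get the VJP, which is the same idea `ad.deflinear2` relies on for linear primitives.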