
how to implement the vjp for the basic primitive #17529

Closed. Answered by jakevdp
mmmeee1111 asked this question in General

Here is where the autodiff rules are defined for the real_p and imag_p primitives in JAX: https://github.com/google/jax/blob/292deef6fda9f639fcecd9883a1112825a1eb54f/jax/_src/lax/lax.py#L1890-L1896
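As a hedged illustration, the same two rules can be reproduced at user level with the public `jax.custom_vjp` API instead of JAX's internal primitive-registration utilities. `my_real` and `my_imag` are hypothetical names; their backward functions are written to agree with what `jax.vjp` returns for `jnp.real` and `jnp.imag` (Re maps a cotangent `t` to `t + 0j`, Im maps it to `-1j * t`). Because both maps are linear, no residuals need to be saved in the forward pass.

```python
import jax
import jax.numpy as jnp

# Hypothetical user-level stand-in for lax.real with a hand-written VJP.
@jax.custom_vjp
def my_real(z):
    return jnp.real(z)

def _my_real_fwd(z):
    # Re is linear, so the backward pass needs no residuals.
    return jnp.real(z), None

def _my_real_bwd(_, t):
    # Transpose of Re: the real cotangent t maps back to t + 0j.
    return (jax.lax.complex(t, jnp.zeros_like(t)),)

my_real.defvjp(_my_real_fwd, _my_real_bwd)

# Hypothetical user-level stand-in for lax.imag with a hand-written VJP.
@jax.custom_vjp
def my_imag(z):
    return jnp.imag(z)

def _my_imag_fwd(z):
    # Im is linear too, so again nothing is saved for the backward pass.
    return jnp.imag(z), None

def _my_imag_bwd(_, t):
    # Transpose of Im: the real cotangent t maps back to 0 - t*1j.
    return (jax.lax.complex(jnp.zeros_like(t), -t),)

my_imag.defvjp(_my_imag_fwd, _my_imag_bwd)

z = jnp.asarray(1.0 + 2.0j, dtype=jnp.complex64)
t = jnp.float32(0.7)
# The hand-written pullbacks can be compared against JAX's built-in ones.
ours = jax.vjp(my_imag, z)[1](t)[0]
builtin = jax.vjp(jnp.imag, z)[1](t)[0]
```

With these definitions, `jax.grad(my_imag)(z)` should give `-1j`, the same convention JAX applies to `jnp.imag`.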

The rules for real_p and imag_p use ad.deflinear2, which is a utility for registering JVP and transpose rules for linear primitives: https://github.com/google/jax/blob/292deef6fda9f639fcecd9883a1112825a1eb54f/jax/_src/interpreters/ad.py#L517-L519
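A minimal sketch of the idea behind a helper like `ad.deflinear2`, using only public `jax.numpy`/`jax.lax` calls: for a linear primitive the JVP is just the primitive applied to the tangent, so the only rule you really have to supply is the transpose, and that transpose is the VJP. `make_linear_rules` is a hypothetical helper for illustration, not JAX's internal implementation. The final check assumes the non-conjugating pairing `Re(a * b)` between complex cotangents and tangents, which is consistent with the real/imag rules above.

```python
import jax
import jax.numpy as jnp

def make_linear_rules(f, transpose_rule):
    """Hypothetical helper mirroring the idea behind ad.deflinear2.

    For a linear f, the tangent output is f applied to the tangent input,
    and the VJP is given entirely by the transpose rule.
    """
    def jvp(primal, tangent):
        return f(primal), f(tangent)

    def vjp(primal):
        return f(primal), transpose_rule

    return jvp, vjp

# Transpose rules chosen to match JAX's behaviour for jnp.real / jnp.imag.
real_jvp, real_vjp = make_linear_rules(
    jnp.real, lambda t: jax.lax.complex(t, jnp.zeros_like(t)))
imag_jvp, imag_vjp = make_linear_rules(
    jnp.imag, lambda t: jax.lax.complex(jnp.zeros_like(t), -t))

z = jnp.asarray(1.0 + 2.0j, dtype=jnp.complex64)  # primal point
v = jnp.asarray(0.5 - 3.0j, dtype=jnp.complex64)  # tangent
t = jnp.float32(0.7)                              # output cotangent

_, imag_pullback = imag_vjp(z)
# Defining property of a transpose: <t, f(v)> == <f^T(t), v>.
lhs = t * jnp.imag(v)
rhs = jnp.real(imag_pullback(t) * v)
```

Since `imag` maps complex inputs to real outputs, the transpose has to produce a complex cotangent from a real one, which is why the rules build it with `jax.lax.complex`.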

For more general information regarding autodiff in JAX, some good resources from the docs are

And if you'…

Replies: 2 comments 2 replies

Answer selected by mmmeee1111
2 participants