
How JAX implements JVP? Two passes of reverse mode? #10157

Answered by soraros
luweizheng asked this question in General

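JAX's forward mode is implemented with the dual-number approach (each primal value is carried together with a tangent), and reverse mode is built from forward mode plus transposition: the forward computation is linearized and the resulting linear function is transposed. So a JVP is not computed as two passes of reverse mode.
I wouldn't expect JAX to have larger tracing overhead, and compared with jitted TensorFlow there should be no extra numerical-computation overhead.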
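A toy sketch may help make this concrete. It is plain Python (it does not use JAX), and every name in it (`Dual`, `Lin`, `Var`, `Scale`, `Add`, `jvp`, `vjp`, `transpose`) is hypothetical rather than a JAX internal. Forward mode pairs each value with a tangent, dual-number style, so a JVP is a single forward pass. Running the same forward pass with symbolic tangents produces a linear expression in the input tangents (a linearization); walking that expression backward, i.e. applying its transpose, gives a VJP.

```python
import math


# --- Linear expressions over the input tangents (hypothetical toy classes) ---
class Lin:
    """Base class for a node in a linear expression of the input tangents."""

    def __add__(self, other):
        if isinstance(other, (int, float)) and other == 0:
            return self  # adding a zero tangent changes nothing
        return Add(self, other)

    __radd__ = __add__

    def __mul__(self, coeff):
        return Scale(coeff, self)  # tangents are only ever scaled by primal values

    __rmul__ = __mul__


class Var(Lin):
    def __init__(self, index):
        self.index = index  # which input tangent this node stands for


class Scale(Lin):
    def __init__(self, coeff, operand):
        self.coeff, self.operand = coeff, operand


class Add(Lin):
    def __init__(self, lhs, rhs):
        self.lhs, self.rhs = lhs, rhs


# --- Forward mode: dual numbers (primal, tangent) ---
class Dual:
    def __init__(self, primal, tangent):
        self.primal, self.tangent = primal, tangent

    def __add__(self, other):
        other = lift(other)
        return Dual(self.primal + other.primal, self.tangent + other.tangent)

    __radd__ = __add__

    def __mul__(self, other):
        other = lift(other)
        # product rule: d(a*b) = da*b + a*db
        return Dual(self.primal * other.primal,
                    self.tangent * other.primal + self.primal * other.tangent)

    __rmul__ = __mul__


def lift(x):
    """Treat a plain constant as a dual number with a zero tangent."""
    return x if isinstance(x, Dual) else Dual(x, 0.0)


def sin(x):
    return Dual(math.sin(x.primal), math.cos(x.primal) * x.tangent)


def jvp(f, primals, tangents):
    """Forward mode: one forward pass carrying numeric tangents."""
    out = f(*[Dual(p, t) for p, t in zip(primals, tangents)])
    return out.primal, out.tangent


def transpose(node, ct, grads):
    """Push the cotangent ct backward through a linear expression."""
    if isinstance(node, Var):
        grads[node.index] += ct
    elif isinstance(node, Scale):
        transpose(node.operand, node.coeff * ct, grads)  # x -> c*x transposes to ct -> c*ct
    elif isinstance(node, Add):
        transpose(node.lhs, ct, grads)  # (a, b) -> a + b transposes to ct -> (ct, ct)
        transpose(node.rhs, ct, grads)
    # anything else is a zero tangent and contributes nothing


def vjp(f, primals, ct=1.0):
    """Reverse mode: linearize with forward mode (symbolic tangents), then transpose."""
    out = f(*[Dual(p, Var(i)) for i, p in enumerate(primals)])
    grads = [0.0] * len(primals)
    transpose(out.tangent, ct, grads)
    return out.primal, grads


def f(x, y):
    return x * y + sin(x)


print(jvp(f, (2.0, 3.0), (1.0, 0.0)))  # (f(2, 3), df/dx = y + cos(x))
print(vjp(f, (2.0, 3.0)))              # (f(2, 3), [y + cos(x), x])
```

In JAX itself the public entry points are `jax.jvp`, `jax.linearize`, and `jax.vjp`, and the rules are defined per primitive: a JVP rule for each primitive and a transpose rule for each linear primitive. The recursive `transpose` above may visit a shared subexpression more than once, which still gives the right answer for a linear map but is not efficient; a real implementation processes each node once in reverse topological order.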

Replies: 2 comments, 3 replies

@mattjj

@soraros

@mattjj

Answer selected by luweizheng