How does JAX implement JVP? Two passes of reverse mode? #10157
-
Hi, I am new to JAX. Since TensorFlow is mainly built for deep learning, it uses reverse-mode automatic differentiation; to get forward-mode AD in TF, you have to run two passes of reverse mode. I want to know how JAX implements forward-mode AD (JVP). There is a document, Autodidax: JAX core from scratch, but I don't get the key ideas from it. If JAX uses a different approach to forward mode, what is its computational overhead compared with TF's two passes of reverse mode? And if JAX uses the same approach as TF, then maybe the differences between JAX and TF come down to NumPy-like APIs, the JVP API, asynchronous dispatch, and other features? Could anyone explain this a little? Thanks!
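To make the comparison concrete, here is a rough sketch of the two-passes-of-reverse-mode trick I mean, written with JAX's own `jax.vjp` (the helper name `jvp_via_double_vjp` is just illustrative, not a real API): taking the VJP of a VJP transposes the linear map back, recovering the JVP.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x

def jvp_via_double_vjp(f, x, v):
    # First reverse pass: the VJP of f at x is the linear map c -> J(x)^T c.
    y, pullback = jax.vjp(f, x)

    def vjp_fun(c):
        (x_bar,) = pullback(c)
        return x_bar

    # Second reverse pass: the VJP of this linear map transposes it back to
    # t -> J(x) t, i.e. the JVP. Since vjp_fun is linear, the linearization
    # point (zeros here) does not matter.
    _, pullback2 = jax.vjp(vjp_fun, jnp.zeros_like(y))
    (tangent_out,) = pullback2(v)
    return y, tangent_out

x, v = jnp.array(2.0), jnp.array(1.0)
print(jvp_via_double_vjp(f, x, v)[1])   # should match...
print(jax.jvp(f, (x,), (v,))[1])        # ...JAX's native single-pass JVP
```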
-
JAX's forward mode is implemented with the dual-number approach, and reverse mode as forward mode followed by a transpose.
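For intuition, here is a minimal dual-number sketch of forward mode in plain Python. This is an illustration of the idea, not JAX's actual implementation, though JAX's forward-mode tracer likewise carries a (primal, tangent) pair through every primitive:

```python
import math

class Dual:
    """A value paired with its derivative (tangent) along one input direction."""
    def __init__(self, primal, tangent):
        self.primal = primal
        self.tangent = tangent

    def __add__(self, other):
        # Sum rule: (u + v)' = u' + v'
        return Dual(self.primal + other.primal, self.tangent + other.tangent)

    def __mul__(self, other):
        # Product rule: (u * v)' = u'v + uv'
        return Dual(self.primal * other.primal,
                    self.tangent * other.primal + self.primal * other.tangent)

def sin(x):
    # Chain rule for sin: d(sin u) = cos(u) du
    return Dual(math.sin(x.primal), math.cos(x.primal) * x.tangent)

def f(x):
    return sin(x) * x

# Seed the input with tangent 1.0: one forward pass yields both f(x) and f'(x).
out = f(Dual(2.0, 1.0))
print(out.primal, out.tangent)  # sin(2)*2 and cos(2)*2 + sin(2)
```

Because the tangent is propagated alongside the primal, the derivative comes out of a single evaluation of f, with no second pass.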
-
A brief explanation: #9328 (comment)
The "recursion" |
> JAX's forward mode is implemented with the dual-number approach, and reverse mode as forward mode followed by a transpose.

I wouldn't expect JAX to have a larger tracing overhead, and there should be no numerical computing overhead compared with jitted TF.
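The "forward mode plus transpose" picture can be sketched with JAX's public APIs: `jax.linearize` produces the JVP as a linear function of the tangent, and `jax.linear_transpose` transposes it into a VJP.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x

x = jnp.array(2.0)

# Forward mode: linearize f at x, yielding f(x) and the linear JVP map t -> J t.
y, f_jvp = jax.linearize(f, x)

# Transposing the linear JVP yields the VJP map c -> J^T c, which is exactly
# what reverse mode computes.
f_vjp = jax.linear_transpose(f_jvp, x)

(x_bar,) = f_vjp(jnp.array(1.0))
print(x_bar)            # should match...
print(jax.grad(f)(x))   # ...the gradient from jax.grad
```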