Limitations of forward mode AD #9328
-
My question might be more to do with general AD and JAX's implementation of it than with JAX itself, so I apologize in advance. Given the following function, how can one find its derivative with respect to the input array in forward-mode differentiation?

```python
def vdw_energy(x):
    r = (
        (x[0, 0] - x[1, 0]) ** 2
        + (x[0, 1] - x[1, 1]) ** 2
        + (x[0, 2] - x[1, 2]) ** 2
    )
    E = 4 * (1 / r) ** 12 - (1 / r) ** 6
    return E
```

I ask because, as per my understanding, to get the derivatives with respect to the individual coordinates above (each element of `x`) we need to "seed" the forward calculation with all elements being zero except the desired variable. But, as expected, in the above case that leads to an incorrect result, since it changes the actual distances between the particles. For example:

```python
import jax as jx
import jax.numpy as jnp

x = jnp.array([[10.0, 10.0, 10.0], [11.0, 10.0, 10.0]])
f_prime = jx.jacfwd(vdw_energy)(x)
# DeviceArray([[ 84.,   0.,   0.],
#              [-84.,   0.,   0.]], dtype=float32)

x = jnp.array([[10.0, 0.0, 0.0], [11.0, 10.0, 10.0]])  # "seeding" for the x[0, 0] component
f_prime = jx.jacfwd(vdw_energy)(x)
# DeviceArray([[-9.053339e-16, -9.053339e-15, -9.053339e-15],
#              [ 9.053339e-16,  9.053339e-15,  9.053339e-15]], dtype=float32)
```

How exactly is JAX arriving at the correct results in forward mode?

Background: I would like to get the same results using the Boost autodiff framework, which only supports forward mode and only returns aggregated gradients, but I am unsure how exactly to go about it. Therefore I would like to know how JAX does it in forward mode. I have also asked the same question on AI Stack Exchange, but there are no replies yet: https://ai.stackexchange.com/questions/34299/limitations-of-forward-mode-ad
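For reference, the reported value of 84 can be checked by hand: at this configuration r = 1, so dE/dr = -48 r^-13 + 6 r^-7 = -42 and dr/dx[0, 0] = 2 (x[0, 0] - x[1, 0]) = -2, giving dE/dx[0, 0] = 84. Below is a short sketch comparing this analytic gradient with `jacfwd`; the `analytic_grad` helper is added here purely for illustration.

```python
import jax
import jax.numpy as jnp

def vdw_energy(x):
    r = ((x[0, 0] - x[1, 0]) ** 2
         + (x[0, 1] - x[1, 1]) ** 2
         + (x[0, 2] - x[1, 2]) ** 2)
    return 4 * (1 / r) ** 12 - (1 / r) ** 6

def analytic_grad(x):
    # Chain rule by hand: E = 4 r^-12 - r^-6 with r = |x[0] - x[1]|^2.
    d = x[0] - x[1]
    r = jnp.dot(d, d)
    dE_dr = -48.0 * r ** -13 + 6.0 * r ** -7
    dr_dx0 = 2.0 * d                  # dr / dx[0, :]; dr / dx[1, :] is the negative
    return jnp.stack([dE_dr * dr_dx0, -dE_dr * dr_dx0])

x = jnp.array([[10.0, 10.0, 10.0], [11.0, 10.0, 10.0]])
print(jax.jacfwd(vdw_energy)(x))  # [[ 84.  0.  0.]  [-84.  0.  0.]]
print(analytic_grad(x))           # same values
```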
-
brief explanation:

In forward-mode AD, we define a Jacobian-vector-product function `jvp(f, x, v)` for each basic operator `f`, where `x` is the point at which to differentiate the function and `v` is the vector to multiply by. It satisfies `jvp(f, x, v) == J_f[x] v` for all `x` and `v`, where `J_f[x]` is the Jacobian matrix of `f` at the point `x`. For the function composition `h(x) = g(f(x))`, we can recursively obtain its Jacobian-vector-product function, namely `jvp(h, x, v) = jvp(g, f(x), jvp(f, x, v))`, since `J_h[x] v = J_g[f(x)] J_f[x] v`. Finally, we vectorize `jvp` with respect to `v` and get `jacfwd(f, x) = jvp_vectorized(f, x, I)`, where `I` is the identity matrix.

Edited: Actually, `jax.jvp` does compute the gradient and function value in a single pass (the…
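As an illustration of the explanation above (a minimal sketch, not a description of JAX internals), the snippet below reuses the `vdw_energy` function from the question. The point relevant to the original question is that the seed is the tangent `v` passed alongside the unchanged primal `x`, and `jacfwd` behaves like this `jvp` vectorized over the identity basis.

```python
import jax
import jax.numpy as jnp

def vdw_energy(x):
    r = ((x[0, 0] - x[1, 0]) ** 2
         + (x[0, 1] - x[1, 1]) ** 2
         + (x[0, 2] - x[1, 2]) ** 2)
    return 4 * (1 / r) ** 12 - (1 / r) ** 6

x = jnp.array([[10.0, 10.0, 10.0], [11.0, 10.0, 10.0]])

# One forward pass per coordinate: the primal x is left unchanged,
# only the tangent (the "seed") is a one-hot array selecting a coordinate.
# jax.jvp returns the function value and the directional derivative together.
v = jnp.zeros_like(x).at[0, 0].set(1.0)
value, dE_dx00 = jax.jvp(vdw_energy, (x,), (v,))
print(value, dE_dx00)        # 3.0  84.0

# jacfwd behaves like this jvp vectorized over all basis tangents.
basis = jnp.eye(x.size).reshape(-1, *x.shape)
jac = jax.vmap(lambda t: jax.jvp(vdw_energy, (x,), (t,))[1])(basis)
print(jac.reshape(x.shape))  # [[ 84.  0.  0.]  [-84.  0.  0.]]
```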
-
Just to confirm: if I understand it correctly, JAX forward mode does not use the conventional approach of pushing dual numbers through functions. Rather, it uses analytical expressions for the derivatives of elementary operations to compute the JVP directly. Am I correct?
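For comparison, here is a minimal hand-rolled dual-number sketch; the `Dual` class is purely illustrative and handles only the scalar operations needed below. Propagating (value, tangent) pairs with per-operation derivative rules computes exactly the Jacobian-vector product described above, so the "dual numbers" view and the "analytical JVP rule per primitive" view give the same result.

```python
import jax
import jax.numpy as jnp

class Dual:
    """Minimal dual number: a value together with a tangent (derivative seed)."""
    def __init__(self, val, tan=0.0):
        self.val, self.tan = val, tan
    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val + other.val, self.tan + other.tan)
    def __sub__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val - other.val, self.tan - other.tan)
    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val * other.val,
                    self.val * other.tan + self.tan * other.val)
    __rmul__ = __mul__
    def __pow__(self, n):             # integer exponents are enough here
        return Dual(self.val ** n, n * self.val ** (n - 1) * self.tan)
    def __rtruediv__(self, num):      # num / Dual, with num a plain number
        return Dual(num / self.val, -num * self.tan / self.val ** 2)

def vdw_energy(x):
    # x[i][j] indexing so this accepts both nested lists of Duals and a jnp array.
    r = ((x[0][0] - x[1][0]) ** 2
         + (x[0][1] - x[1][1]) ** 2
         + (x[0][2] - x[1][2]) ** 2)
    return 4 * (1 / r) ** 12 - (1 / r) ** 6

# Seed only the tangent of x[0][0]; the coordinate values themselves are unchanged.
x_dual = [[Dual(10.0, 1.0), Dual(10.0), Dual(10.0)],
          [Dual(11.0), Dual(10.0), Dual(10.0)]]
print(vdw_energy(x_dual).tan)             # 84.0

x = jnp.array([[10.0, 10.0, 10.0], [11.0, 10.0, 10.0]])
print(jax.jacfwd(vdw_energy)(x)[0, 0])    # 84.0
```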
-
Note: You can use …