Hi, I would like to implement a custom derivative for selected arguments of a function. My example code is below:
Here,
An error message appears as follows:
What causes this error, and is there a way to resolve it?
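The original code and error message are not reproduced here, so the following is an assumption about the setup rather than the poster's code. A common way to attach a custom derivative while differentiating only some arguments is `jax.custom_jvp` with `nondiff_argnums`. In this minimal sketch, `power_sum(n, x)` is a hypothetical function whose integer exponent `n` is non-differentiable. The non-differentiable argument is passed to the JVP rule first and is not wrapped in a tuple. `primals` and `tangents` arrive as tuples even when only one argument is differentiable.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical example: `n` (argument 0) is excluded from differentiation.
@partial(jax.custom_jvp, nondiff_argnums=(0,))
def power_sum(n, x):
    return jnp.sum(x ** n)


@power_sum.defjvp
def power_sum_jvp(n, primals, tangents):
    # Non-differentiable args come first; primals/tangents are tuples
    # holding one entry per differentiable argument, so unpack them.
    (x,), (x_dot,) = primals, tangents
    y = power_sum(n, x)
    # d/dx sum(x**n) = n * x**(n - 1), contracted with the tangent x_dot.
    y_dot = jnp.sum(n * x ** (n - 1) * x_dot)
    return y, y_dot


# Gradient with respect to `x` only (argnums=1).
grad_x = jax.grad(power_sum, argnums=1)(3, jnp.array([1.0, 2.0]))
print(grad_x)
```

For `n = 3` the gradient is `3 * x**2`, which evaluates to 3 and 12 for the input above.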
Answered by f0uriest, Nov 29, 2023
I think it's because JAX passes the primals and tangents to the JVP function as tuples of arrays, so you need to unpack them.
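As an illustration of the unpacking the answer describes, here is a minimal sketch using a hypothetical two-argument function `f(x, y) = sin(x) * y` (not the poster's code). The rule registered with `defjvp` receives `primals` and `tangents` as tuples with one entry per argument. Passing those tuples into array arithmetic without unpacking them is a likely source of the kind of error asked about.

```python
import jax
import jax.numpy as jnp


# Hypothetical function of two differentiable arguments.
@jax.custom_jvp
def f(x, y):
    return jnp.sin(x) * y


@f.defjvp
def f_jvp(primals, tangents):
    # primals == (x, y) and tangents == (x_dot, y_dot): unpack the tuples
    # before doing arithmetic on their contents.
    x, y = primals
    x_dot, y_dot = tangents
    primal_out = f(x, y)
    # Product rule: d(sin(x) * y) = cos(x) * y * dx + sin(x) * dy
    tangent_out = jnp.cos(x) * y * x_dot + jnp.sin(x) * y_dot
    return primal_out, tangent_out


# Gradients with respect to both arguments at (x, y) = (0.0, 2.0).
dx, dy = jax.grad(f, argnums=(0, 1))(0.0, 2.0)
print(dx, dy)
```

At `(0.0, 2.0)` the partial derivatives are `cos(0) * 2 = 2` and `sin(0) = 0`. To give only one argument a custom derivative, you can write the special rule for that argument's term and use the ordinary expression for the others.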
Answer selected by exenGT