custom_jvp for selected arguments causes error for grad #18720

Answered by f0uriest
exenGT asked this question in General

I think it's because JAX passes the primals and the tangents to the JVP rule as tuples of arrays (one entry per differentiable argument), so you need to unpack them inside the rule before using them.
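A minimal sketch of what the unpacking can look like. The code from the original question is not reproduced here, so the function `f`, its two arguments, and the rule name `f_jvp` are hypothetical stand-ins:

```python
import jax
import jax.numpy as jnp

# Hypothetical example function; not the code from the original question.
@jax.custom_jvp
def f(x, y):
    return jnp.sin(x) * y

@f.defjvp
def f_jvp(primals, tangents):
    # Both arguments arrive as tuples, one entry per input of f,
    # so unpack them before doing any arithmetic.
    x, y = primals
    x_dot, y_dot = tangents
    primal_out = f(x, y)
    # The tangent output must be linear in x_dot and y_dot so that
    # jax.grad can transpose the rule.
    tangent_out = jnp.cos(x) * y * x_dot + jnp.sin(x) * y_dot
    return primal_out, tangent_out

# Gradient with respect to the first argument; analytically cos(x) * y.
g = jax.grad(f)(1.0, 2.0)
```

Using `primals` or `tangents` directly as if they were arrays (without unpacking) is the kind of mistake that tends to surface only once `jax.grad` or `jax.jvp` actually invokes the rule.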

Answer selected by exenGT