-
I'm currently trying to reproduce a simpler version of a bug I'm facing in my code: essentially, a jax.lax.cond is killing the gradients in my Jacobian. Before I go mad, are there any gotchas I might be missing? I have a class that implements a method something like this:
This is used in a higher-level error function, which I then take the Jacobian of. I noticed zero gradients in parts of my Jacobian where an older reference implementation produced nonzero ones. I also noticed that if I removed the conditional, those gradients came back and matched the reference implementation. The following worked (in my debugging case I knew the conditional evaluated to false):
The following did not work:
Nor did the following:
It seems that merely the presence of the conditional is enough to zero out those gradients. As I mentioned, I'm still trying to reproduce this outside my code base (where things seem to behave a little better), but before I go mad, are there any gotchas I ought to be looking out for? Thanks!
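The original method and its branches weren't included in the thread, so the following is only a minimal sketch with hypothetical functions (`true_branch`, `false_branch`, `residual`). It checks the property the question hinges on: jax.jacfwd and jax.jacrev differentiate through whichever branch lax.cond actually takes, so the Jacobian should match differentiating that branch directly.

```python
import jax
import jax.numpy as jnp
from jax import lax

def true_branch(p):
    # Hypothetical branch taken when p[0] > 0.
    return jnp.stack([p[0] * p[1], p[1] ** 2])

def false_branch(p):
    # Hypothetical branch taken otherwise; same output shape and dtype.
    return jnp.stack([p[0] + p[1], 3.0 * p[1]])

def residual(p):
    # Both branches receive the traced operand p, so gradients flow through it.
    return lax.cond(p[0] > 0.0, true_branch, false_branch, p)

params = jnp.array([-1.0, 2.0])              # p[0] < 0: false_branch is taken
J_fwd = jax.jacfwd(residual)(params)         # forward-mode Jacobian through cond
J_rev = jax.jacrev(residual)(params)         # reverse-mode Jacobian through cond
J_branch = jax.jacfwd(false_branch)(params)  # same computation with the cond removed
```

All three Jacobians should equal [[1, 1], [0, 3]]. If a real Jacobian differs with and without the conditional, one thing worth checking is whether the branches compute only from their operands and other traced values, or from attributes set outside the differentiated function. As far as JAX is concerned, the latter are constants and contribute zero gradient.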
-
What do the false branches return? This might be a manifestation of the issue described in the JAX FAQ entry "Gradients contain NaN where using where". In any case, if you're able to put together a minimal reproduction, it would be easier to help.
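The FAQ entry concerns jnp.where, which always evaluates both branches. The jax.lax.cond documentation notes that a cond vmapped over a batched predicate is converted to a select, so both branches run in that case too. Below is a minimal sketch of the failure and the usual "double where" workaround, using a hypothetical log-based function:

```python
import jax
import jax.numpy as jnp

def naive(x):
    # Both branches are computed. At x == 0 the log branch receives a zero
    # cotangent, and log's derivative rule turns that into 0 / 0 = nan.
    return jnp.where(x > 0.0, jnp.log(x), 0.0)

def safe(x):
    # "Double where": give log a harmless input wherever its result is unused.
    safe_x = jnp.where(x > 0.0, x, 1.0)
    return jnp.where(x > 0.0, jnp.log(safe_x), 0.0)

g_naive = jax.grad(naive)(0.0)  # nan
g_safe = jax.grad(safe)(0.0)    # 0.0
```

This pattern produces NaN rather than zeros, so it is a diagnostic to rule out rather than a guaranteed match for the symptom in the question.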
-
I definitely have lots of impure stuff, but I've registered my classes as pytrees. Flatten/unflatten aren't being called, though :/

On Tue, 29 Aug 2023, 22:10, Jake Vanderplas wrote:
Another guess: does your class do any in-place mutation (i.e. methods that have side effects of the form self.attr = ...)? If your functions are impure, I could imagine that causing this kind of issue in autodiff.
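One possible reason tree_flatten isn't being called: JAX only flattens objects that are passed into (or returned from) a transformed function. An object reached through a closure or through self is never flattened. Its array attributes are baked in as constants, and JAX cannot differentiate with respect to them. Below is a minimal sketch with a hypothetical `Params` class that counts its flatten calls:

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class Params:
    flatten_calls = 0  # class-level counter, only for this demonstration

    def __init__(self, w):
        self.w = w

    def tree_flatten(self):
        type(self).flatten_calls += 1
        return (self.w,), None  # children, aux_data

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)

p = Params(jnp.array(2.0))

# Closed over: p is not an input, so it is never flattened and p.w is a constant.
before = Params.flatten_calls
jax.jit(lambda x: p.w * x)(3.0)
closed_over_calls = Params.flatten_calls - before

# Passed as an argument: JAX flattens p and traces through its leaves.
before = Params.flatten_calls
g = jax.grad(lambda q, x: q.w * x)(p, 3.0)  # gradient is itself a Params
as_arg_calls = Params.flatten_calls - before
```

Here closed_over_calls stays at zero, while the jax.grad call flattens p and returns a gradient whose w field is 3.0.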
-
Just an update: while what I'm doing may not be totally 'kosher' as per @jakevdp's replies above, I think it's worth showing some of the techniques I used to migrate an OO codebase to JAX. In the end I created a superclass for my classes to inherit, following a pattern that allows my OO objects to be flattened correctly. This works with transient dependencies and is closely related to a previous proposal for pydags, here: #7919. Some notes on the approach:
Overall I'm actually pretty happy with where I've ended up. I wouldn't start from this place, but adapting a code base to it has left me with something that's still quite neat, without any really nasty hacks. The flatten/unflatten logic is still contained within each subclass and separates concerns nicely. Copilot is now able to write all my flatten/unflatten and getrefs/setrefs functions. I would be interested in seeing something like this brought into JAX itself to make migrating to JAX a little easier...
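The superclass itself isn't shown in the thread, so the following is only a hypothetical sketch of the general idea. It assumes a base class (here called `PytreeBase`, an illustrative name) that registers every subclass with jax.tree_util.register_pytree_node, while each subclass keeps its own flatten/unflatten logic. It does not attempt the getrefs/setrefs part or the #7919 proposal.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node

class PytreeBase:
    """Hypothetical base class: every subclass is registered as a pytree."""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Runs once per subclass, after its body (and methods) are defined.
        register_pytree_node(cls, lambda obj: obj.tree_flatten(), cls.tree_unflatten)

    def tree_flatten(self):
        raise NotImplementedError

    @classmethod
    def tree_unflatten(cls, aux, children):
        raise NotImplementedError

class Layer(PytreeBase):
    def __init__(self, w, name):
        self.w = w        # array leaf: traced and differentiated
        self.name = name  # static metadata: stored in aux_data (must be hashable)

    def tree_flatten(self):
        return (self.w,), (self.name,)

    @classmethod
    def tree_unflatten(cls, aux, children):
        (name,) = aux
        (w,) = children
        return cls(w, name)

class Model(PytreeBase):
    def __init__(self, layers):
        self.layers = layers  # list of Layer objects, flattened recursively

    def tree_flatten(self):
        return (self.layers,), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)

    def __call__(self, x):
        for layer in self.layers:
            x = jnp.tanh(layer.w * x)
        return x

model = Model([Layer(jnp.array(0.5), "a"), Layer(jnp.array(-1.0), "b")])
grads = jax.grad(lambda m, x: m(x))(model, 1.0)  # a Model with the same structure
```

Because Model.tree_flatten returns its list of Layer objects as a child, JAX recurses into them. Nested objects are therefore handled without Model knowing how each Layer is flattened.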
Yes, definitely. If you're updating attributes of self but not returning self from your function, then your function is impure and I would not expect JAX transformations to work with your code.
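Below is a minimal sketch of the difference, using a hypothetical `Scaler` class. If a closed-over object is mutated after jax.jit has traced a function, the change is silently ignored. If the object is passed in and the updated object is returned, every change stays visible to JAX.

```python
import jax
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class Scaler:
    def __init__(self, scale):
        self.scale = scale

    def apply(self, x):
        return self.scale * x

    def with_scale(self, scale):
        # Pure update: return a new object instead of assigning self.scale.
        return Scaler(scale)

    def tree_flatten(self):
        return (self.scale,), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)

# Impure pattern: the object is closed over and mutated between calls.
s = Scaler(2.0)
f = jax.jit(lambda x: s.apply(x))
first = f(1.0)    # traced with scale == 2.0 baked in as a constant
s.scale = 3.0     # in-place mutation that JAX never sees
second = f(1.0)   # cached trace is reused: still 2.0, not 3.0

# Pure pattern: pass the object in and return the updated object out.
@jax.jit
def step(obj, x):
    new_obj = obj.with_scale(obj.scale * 2.0)
    return new_obj, new_obj.apply(x)

s2, y = step(Scaler(3.0), 1.0)  # s2.scale == 6.0, y == 6.0
```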