-
Notifications
You must be signed in to change notification settings - Fork 60
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Remove bound on JAX/Diffrax versions #266
Conversation
9692111
to
a6880e6
Compare
It seems the complex -> real casting warnings only ever occur in differentiated functions, which means this is probably some weird behaviour with JAX transformations (i.e. these warnings don't get raised if the function is not differentiated). Need to try to find a minimal example to understand what's going on. Edit: Okay it turns out that this warning gets raised when computing the gradient of a function that involves a tensordot of a real array with a complex array. Will see what the JAX developers say about this: it seems like this should either be raised when not computing the gradient (for consistency), or should not be raised during gradient computation. Asked here: jax-ml/jax#18133 |
As it currently stands:
|
Co-authored-by: Kento Ueda <[email protected]>
Summary
Closes #190
As it is unclear when the bug causing an error in the perturbation module will get fixed in JAX, I just tried returning to the issue and have figured out a simple workaround to bypass the bug. I realistically should have tried this sooner but I didn't realize the issue would hang for this long.
The workaround fixes the perturbation module, but there are now many errors/warnings coming up in many different tests. This is to be expected - there have been over 10 minor releases of JAX since we put a bound on the version. I'll need to make my way through each folder and figure out how to fix it.
Details and comments
I'm working through each submodule to fix errors/warnings in the tests (strikethrough indicates no warnings/errors being raised when running tests, and no comment or strikethrough means the module hasn't been checked yet).
arrayliasbackendperturbationpulsesignalst_span
andt_eval
with JAXtest.dynamics.solvers.test_jax_odeint.TestJaxOdeint.test_transformations_t_eval_arg_overlap
. This is the result of a different bug in JAX (see this discussion). I need to review the status of what is supposed to be possible with this merging, but I don't think this is used heavily and only occurs ift_span[-1] == t_eval[-1]
. We could potentially just leave this as a known issue for now as this could be solved on the JAX side before the next release. Alternatively we could make it so thatjax_odeint
only takes int_span
, and the user should supply the full range of time values as if directly callingodeint
. This will mess with the interface a bit but this is a technicality that has caused many hours to be lost.A deprecation warning for automatic conversion of a size 1 array into a scalar is deprecated (the linealpha[0] = y0.conj().T @ projection
appears in the lanczos solver).