Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove bound on JAX/Diffrax versions #266

Merged
merged 16 commits into from
Oct 30, 2023

Conversation

DanPuzzuoli
Copy link
Collaborator

@DanPuzzuoli DanPuzzuoli commented Oct 12, 2023

Summary

Closes #190

As it is unclear when the bug causing an error in the perturbation module will get fixed in JAX, I just tried returning to the issue and have figured out a simple workaround to bypass the bug. I realistically should have tried this sooner but I didn't realize the issue would hang for this long.

The workaround fixes the perturbation module, but there are now many errors/warnings coming up in many different tests. This is to be expected - there have been over 10 minor releases of JAX since we put a bound on the version. I'll need to make my way through each folder and figure out how to fix it.

Details and comments

I'm working through each submodule to fix errors/warnings in the tests (strikethrough indicates no warnings/errors being raised when running tests, and no comment or strikethrough means the module hasn't been checked yet).

  • arraylias
  • backend
  • models
    • No failures/errors, but a warning about casting complex -> real pops up in many places, e.g. test_generator_model. Currently unable to track down what's raising this.
  • perturbation
  • pulse
  • signals
  • solvers
    • An error is occurring in a test involving merging t_span and t_eval with JAX test.dynamics.solvers.test_jax_odeint.TestJaxOdeint.test_transformations_t_eval_arg_overlap. This is the result of a different bug in JAX (see this discussion). I need to review the status of what is supposed to be possible with this merging, but I don't think this is used heavily and only occurs if t_span[-1] == t_eval[-1]. We could potentially just leave this as a known issue for now as this could be solved on the JAX side before the next release. Alternatively we could make it so that jax_odeint only takes in t_span, and the user should supply the full range of time values as if directly calling odeint. This will mess with the interface a bit but this is a technicality that has caused many hours to be lost.
      • Update: the underlying issue in JAX has been fixed and I've verified that it fixes the problem. Can hold out for the next JAX release (which seems to be every few weeks).
    • This module also has many complex -> real casting warnings.
    • A deprecation warning for automatic conversion of a size 1 array into a scalar is deprecated (the line alpha[0] = y0.conj().T @ projection appears in the lanczos solver).

@DanPuzzuoli
Copy link
Collaborator Author

DanPuzzuoli commented Oct 16, 2023

It seems the complex -> real casting warnings only ever occur in differentiated functions, which means this is probably some weird behaviour with JAX transformations (i.e. these warnings don't get raised if the function is not differentiated). Need to try to find a minimal example to understand what's going on.

Edit: Okay it turns out that this warning gets raised when computing the gradient of a function that involves a tensordot of a real array with a complex array. Will see what the JAX developers say about this: it seems like this should either be raised when not computing the gradient (for consistency), or should not be raised during gradient computation.

Asked here: jax-ml/jax#18133

@DanPuzzuoli DanPuzzuoli marked this pull request as ready for review October 20, 2023 21:04
@DanPuzzuoli
Copy link
Collaborator Author

As it currently stands:

  • The workaround is implemented in dyson_magnus.py in the function _setup_dyson_rhs_jax.
  • The import warning about JAX version restrictions has been removed, and the bounds in version bounds in files like setup.py and the test environment setups have been removed.
  • A test has been added to test_dyson_magnus.py that contains a minimal reproduction of the error with the original version of _setup_dyson_rhs_jax. The test is setup to pass if an exception is thrown when executing this code. Once this test starts failing, we will know that the original version of _setup_dyson_rhs_jax should work again, and at that point we can revert it. (The documentation in the test states this.)
  • I've added a release note to this PR with an upgrade note stating that users should now use the latest version of JAX. Unfortunately, the complex casting warning issue hasn't been resolved yet, and for now I've added a "known issue" note in relation to this warning.

@DanPuzzuoli DanPuzzuoli requested a review from to24toro October 27, 2023 16:01
@DanPuzzuoli DanPuzzuoli merged commit 594481c into qiskit-community:main Oct 30, 2023
15 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Bounded versions of JAX and diffrax due to JAX bug
2 participants