
Bounded versions of JAX and diffrax due to JAX bug #190

Closed
DanPuzzuoli opened this issue Feb 17, 2023 · 3 comments · Fixed by #266

Comments

@DanPuzzuoli
Collaborator

DanPuzzuoli commented Feb 17, 2023

The change in #189 should be removed when the corresponding JAX bug discussed here is fixed.

Edit:

@DanPuzzuoli
Collaborator Author

As of JAX 0.4.7, the change in #189 can no longer be used to work around the issue with JAX. For now, we may need to set 0.4.6 as an upper bound on the JAX version.

@DanPuzzuoli DanPuzzuoli changed the title Revert #189 when JAX bug fixed Revert temporary changes when JAX bug fixed Mar 30, 2023
@DanPuzzuoli
Collaborator Author

Following up on the previous comment: #210 now sets 0.4.6 as the latest version of JAX that works with Dynamics. For now this is just to get the CI tests working; hopefully the bug will be resolved shortly and these restrictions can be removed.

@DanPuzzuoli DanPuzzuoli changed the title Revert temporary changes when JAX bug fixed Bounded versions of JAX and diffrax due to JAX bug Jun 1, 2023
@DanPuzzuoli
Collaborator Author

DanPuzzuoli commented Jun 1, 2023

As this issue has continued to evolve, I'll make this comment a summary of what's going on, to be edited as new issues arise.

JAX 0.4.4 introduced a new way of tracing functions. Unfortunately, this new tracing has a bug that breaks the perturbation module, but until JAX 0.4.7 it was possible to revert to the old behaviour. As the bug is still not fixed, it is currently necessary to cap the JAX version at 0.4.6. The latest version of diffrax also requires the latest version of JAX, so diffrax and equinox must be bounded to compatible versions as well. Hopefully this bug gets fixed soon and we can go back to using the latest versions of all of these packages.
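The JAX bound above corresponds to the pip requirement specifier `jax<=0.4.6`. As a minimal sketch, a runtime guard that checks an installed version against that bound might look like the following. The helper names `parse_release` and `jax_version_supported` are hypothetical and not part of Dynamics, and the version parsing is a simplification rather than a full PEP 440 implementation. The key design choice is comparing release numbers as integer tuples: a plain string comparison would wrongly treat `"0.4.10"` as older than `"0.4.6"`.

```python
import re

# Upper bound taken from this issue: newest JAX release known to work (0.4.6).
MAX_SUPPORTED_JAX = (0, 4, 6)


def parse_release(version: str) -> tuple:
    """Hypothetical helper: return the numeric release segment as ints.

    e.g. '0.4.10' -> (0, 4, 10). Suffixes such as 'rc1' or '.dev...' are
    ignored; this is a simplification, not a full PEP 440 parser.
    """
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def jax_version_supported(version: str) -> bool:
    """Hypothetical helper: True if `version` satisfies jax<=0.4.6."""
    release = parse_release(version)
    # Pad short versions so '0.4' compares as (0, 4, 0) against the bound.
    release = release + (0,) * (len(MAX_SUPPORTED_JAX) - len(release))
    # Tuple comparison is element-wise on integers, so 10 > 6 as expected.
    return release <= MAX_SUPPORTED_JAX


if __name__ == "__main__":
    for v in ["0.4.4", "0.4.6", "0.4.7", "0.4.10"]:
        print(v, jax_version_supported(v))
```

With this check, `0.4.4` and `0.4.6` pass while `0.4.7` and `0.4.10` are rejected, matching the bound described in the summary.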

List of issues/PRs:

Note: the issue in the perturbation module arises due to an unusual combination of JAX elements outlined here. It would be worth checking whether this problem still occurs if we substitute a diffrax solver for jax_odeint. My guess is it won't make a difference, but it's worth a shot.
