Unexpected `vmap` error due to commit c36e1f7
#25289
Comments
Thanks for the clear report!
@marcocuturi could you minimize this? There's a lot going on in this code that we aren't familiar with, and it's much harder for us to minimize unfamiliar code than for you to minimize your own code. Think of the time for us to debug this as exponential in the length of the repro you give us.
Are any of lineax's tests failing?
No, all of them are passing with the latest version of JAX/lineax.
Thanks a lot @mattjj for taking a look! Here's a simpler example crafted by @michalk8 and @Algue-Rythme demonstrating the problem, which indeed arises from
Yeah, I've also been seeing widespread failures in Diffrax's test suite, due to what looks like a totally different vmap failure. I've spent most of today digging through this and haven't identified a root cause yet. It might take a while to update the JAX ecosystem to be compatible with this version of JAX.
See:
- jax-ml/jax#25289
- patrick-kidger/diffrax#532

The problem was that batching has now become a dynamic trace, and our batching rules were not set up to handle the case that every batch axis is `not_mapped`.
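To make the "every batch axis is `not_mapped`" case concrete, here is a minimal sketch of a batching rule in the general shape of JAX primitive batching rules: it receives the batched arguments plus one batch dimension per argument, and returns the output together with its batch dimension. Everything here is hypothetical: `NOT_MAPPED`, `add_batch_rule_old` and `add_batch_rule` are illustrative names, NumPy stands in for `jax.numpy`, and we assume (as we understand JAX internals) that "not mapped" is represented by `None`. This is not Equinox's actual fix.

```python
import numpy as np

# Hypothetical sentinel for "this argument has no batch axis"
# (assumed to mirror JAX's internal not_mapped, which we believe is None).
NOT_MAPPED = None


def _move_to_front(x, d, axis_size):
    """Put the batch axis of x at position 0, broadcasting if x is unmapped."""
    if d is NOT_MAPPED:
        return np.broadcast_to(x, (axis_size,) + np.shape(x))
    return np.moveaxis(x, d, 0)


def add_batch_rule_old(batched_args, batch_dims):
    # Implicitly assumes at least one argument is mapped: if none is,
    # next() finds nothing and raises StopIteration.
    axis_size = next(
        np.shape(x)[d]
        for x, d in zip(batched_args, batch_dims)
        if d is not NOT_MAPPED
    )
    xs = [_move_to_front(x, d, axis_size) for x, d in zip(batched_args, batch_dims)]
    return np.add(*xs), 0


def add_batch_rule(batched_args, batch_dims):
    # Under a dynamic batch trace the rule may be called even when no
    # argument is batched: apply the unbatched operation and say so.
    if all(d is NOT_MAPPED for d in batch_dims):
        return np.add(*batched_args), NOT_MAPPED
    return add_batch_rule_old(batched_args, batch_dims)
```

With two unbatched inputs, `add_batch_rule_old` raises `StopIteration`, while `add_batch_rule` simply returns the plain sum with `NOT_MAPPED` as its output batch dimension.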
Okay, I think I've identified the root cause: with the latest changes, the batch interpreter has become a dynamic trace, i.e. it calls batching rules in cases where it previously wouldn't. As a result, many arrays had their non-batch dimensions turned into batch dimensions! With that problem identified, it was relatively simple to update a couple of batching rules in Equinox to handle this new calling case appropriately. @marcocuturi @michalk8 can you try patrick-kidger/equinox#907 on your full example and on your tests? If it passes, I'll do a new release of Equinox that is compatible with the latest JAX.
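One way non-batch dimensions can end up treated as batch dimensions is a rule that, written under the old assumption, maps an unmapped input to a default batch axis. In this hypothetical sketch (NumPy standing in for `jax.numpy`, `None` assumed as the not-mapped sentinel, and `row_sum` an illustrative primitive that sums over the last axis), `row_sum_rule_old` defaults to axis 0, so an unbatched `(3, 4)` input comes back labelled as a batch of three scalars. The corrected `row_sum_rule` reports the output as unmapped instead.

```python
import numpy as np

NOT_MAPPED = None  # hypothetical sentinel for "no batch axis"


def row_sum(x):
    """Hypothetical primitive: sum over the last axis."""
    return np.sum(x, axis=-1)


def row_sum_rule_old(batched_args, batch_dims):
    (x,), (d,) = batched_args, batch_dims
    # Bug: an unmapped input silently gets axis 0 as its "batch" axis,
    # so a genuine data dimension is reported as a batch dimension.
    d = 0 if d is NOT_MAPPED else d
    return row_sum(np.moveaxis(x, d, 0)), 0


def row_sum_rule(batched_args, batch_dims):
    (x,), (d,) = batched_args, batch_dims
    if d is NOT_MAPPED:
        # Nothing is batched: compute directly and mark the output unmapped.
        return row_sum(x), NOT_MAPPED
    # Move the batch axis to the front; the summed last axis is still data.
    return row_sum(np.moveaxis(x, d, 0)), 0
```

Both rules compute the same numbers here; the difference is the returned batch dimension, which tells the surrounding transformation how to interpret the output's leading axis.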
Nice find, @patrick-kidger! That's right: actually, everything is a dynamic tracer now. Rules are no longer called automatically only when an input depends on batched data, though rules themselves can still choose to behave differently based on that dependence. I believe this gives rules strictly more power and expressiveness.
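The change in calling convention can be illustrated with a toy dispatcher. This is a hypothetical model, not JAX's actual trace machinery: `bind`, `neg_impl`, `neg_rule` and the `always_call_rule` flag are illustrative names, and `None` is assumed as the not-mapped sentinel. With `always_call_rule=False` it mimics the old behaviour, where the rule runs only if some input is batched; with `always_call_rule=True` it mimics a dynamic trace, where the rule always runs and decides for itself what to do with unbatched inputs.

```python
import numpy as np

NOT_MAPPED = None  # hypothetical sentinel for "no batch axis"


def bind(impl, rule, args, dims, always_call_rule):
    """Toy model of applying a primitive under a batch trace."""
    if not always_call_rule and all(d is NOT_MAPPED for d in dims):
        # Old behaviour: no batched inputs, so the primitive is applied
        # directly and the batching rule is never consulted.
        return impl(*args), NOT_MAPPED
    # Dynamic-trace behaviour: the rule is always consulted.
    return rule(args, dims)


calls = []  # records the batch dims each time the rule is invoked


def neg_impl(x):
    return np.negative(x)


def neg_rule(args, dims):
    calls.append(dims)
    (x,), (d,) = args, dims
    # The rule itself chooses how to act on (the absence of) batching.
    if d is NOT_MAPPED:
        return neg_impl(x), NOT_MAPPED
    # Elementwise op: the output keeps the input's batch axis.
    return neg_impl(x), d
```

Under the old model the rule is skipped for unbatched inputs, so a rule that never handled that case was never exercised; under the new model the same rule is called with every dimension set to `NOT_MAPPED` and must cope with it.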
Hey @mattjj, can we consolidate this kind of knowledge into an updated version of Autodidax?
Works great, thanks!
Description
Hi,
@michalk8 and I noticed a bug in our tests here that occurs in the latest JAX version. After doing git-bisect, we found the bad commit to be: c36e1f7.
Here's the traceback with `JAX_TRACEBACK_FILTERING=off`:

It might be that the bug comes from transformations created by lineax, as the test doesn't raise this error when using the CG solver from JAX (the test still fails there, but only because of precision, not because of the above `ValueError`).

Code to reproduce:
System info (python version, jaxlib version, accelerator, etc.)