Indexing numpy array with DeviceArray: index interpreted as tuple #620
Comments
This is related to #553, though in that case the user was only calling To fix this, we'll either need to get
This is a great bug. I asked our friendly neighborhood NumPy expert @shoyer about it, and after we poked around the NumPy code a bit, he noticed something awesome: this example would have worked if your index array had more than 32 elements in it. Indeed:

    import numpy as onp
    import jax.numpy as np
    x = onp.arange(35).reshape((5,7))  # changed
    np_idx = onp.array([1,2,3] * 11)  # changed
    jax_idx = np.array([1,2,3] * 11)  # changed
    assert onp.allclose(x[np_idx], x[jax_idx])

That breaks if you change the 11s to 10s, though. @shoyer also pointed out that this example will work if you write Not sure yet if we have a way to mitigate this issue, but I couldn't wait to report this NumPy oddity.
When I try this example, I also see a deprecation warning from NumPy: So hopefully that's a clue that something may be amiss.
Reading that warning again, it makes it sound like this use case will start working for us in the future, when the behavior is removed. One path forward in the short term might be to add some code to this function in multiarray/mapping.c to check for an
Yes, once the deprecation cycle finishes this should start working for JAX.
Just remarking that I just helped an intern debug this issue again, so it's still happening circa July 2020.
I have also just spent some time debugging it. The deprecation warning helped, but it was slightly confusing: rather than telling me that the error had already happened, it told me that the behaviour won't be supported in the future.
Maybe adding this gotcha to sharp bits would help new users? https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html
jax-ml/jax#3821. The idea of the JAX change is in part that DeviceArray.__iter__ should return DeviceArrays. Before #3821, it returned numpy.ndarrays. One main motivation is performance: it avoids a host sync. A secondary motivation is type consistency. However, that caused this line of Flax example code to trigger a NumPy bug, discussed in this thread: jax-ml/jax#620 (comment) Basically, x[i] where x is a numpy.ndarray and i is a JAX DeviceArray _of length 10 or less_ causes NumPy to interpret i as a non-array sequence (e.g. a tuple) rather than as an array, leading to an error like "IndexError: too many indices for array". The workaround employed here is to write x[i, ...] instead of x[i], which bypasses the NumPy bug. PiperOrigin-RevId: 345140147
jax-ml/jax#3821. The idea of the JAX change is in part that DeviceArray.__iter__ should return DeviceArrays. Before #3821, it returned numpy.ndarrays. One main motivation is performance: it avoids a host sync. A secondary motivation is type consistency. However, that caused this line of Flax example code to trigger a NumPy bug, discussed in this thread: jax-ml/jax#620 (comment) Basically, x[i] where x is a numpy.ndarray and i is a JAX DeviceArray _of length 10 or less_ causes NumPy to interpret i as a non-array sequence (e.g. a tuple) rather than as an array, leading to an error like "IndexError: too many indices for array". The workaround employed here is to write x[i, ...] instead of x[i], which bypasses the NumPy bug. PiperOrigin-RevId: 345160314
I sent a PR to fix this in upstream NumPy (numpy/numpy#21029). Let's see if it works!
The upstream NumPy PR was merged! So I'm going to declare this fixed, even though you're going to have to wait for NumPy 1.23 to get it...
…NumPy < 1.23—hence needed to use an unstable release. See jax-ml/jax#620 for more details.
When you try to index a numpy ndarray with a DeviceArray, the numpy array tries to interpret the jax array as a tuple.
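To make the difference concrete, here is a small sketch that uses plain NumPy only (no JAX) to show how differently a tuple and an array behave as an index. The 5×7 shape mirrors the example discussed in this thread; the variable names are illustrative and not taken from the original report.

```python
import numpy as onp

x = onp.arange(35).reshape((5, 7))

# An array of integers selects whole rows: one result row per index value.
rows = x[onp.array([1, 2, 3])]
print(rows.shape)  # (3, 7)

# A tuple is unpacked into one index per axis, so (1, 2, 3) means x[1, 2, 3].
# x has only two axes, so NumPy raises an IndexError.
try:
    x[(1, 2, 3)]
    tuple_error = None
except IndexError as err:
    tuple_error = err
print(type(tuple_error).__name__)  # IndexError
```

When a JAX array is mistaken for a tuple, NumPy follows the second path even though the first was intended.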
Workaround: put jax_idx in a singleton tuple: x[(jax_idx,)]
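As a minimal sketch of this workaround, the snippet below compares three spellings that avoid handing NumPy a bare non-ndarray index: the singleton tuple above, the `x[i, ...]` form mentioned elsewhere in this thread, and an explicit `onp.asarray` conversion (the last one is an addition here, not something proposed in the thread). To stay runnable without JAX, the index is a plain NumPy array; the assumption is that a JAX array would be passed in the same position.

```python
import numpy as onp

x = onp.arange(35).reshape((5, 7))
# Stand-in for jax_idx; assumed to be a 1-D integer array.
idx = onp.array([1, 2, 3] * 10)

expected = x[idx]                  # plain array indexing: selects rows

via_tuple = x[(idx,)]              # singleton tuple: idx indexes axis 0 only
via_ellipsis = x[idx, ...]         # same idea, remaining axes spelled out
via_asarray = x[onp.asarray(idx)]  # convert to numpy.ndarray first

for result in (via_tuple, via_ellipsis, via_asarray):
    assert onp.array_equal(expected, result)
print(expected.shape)  # (30, 7)
```

With a real JAX index, `onp.asarray` would copy the data to the host, so the tuple forms may be preferable when only the indexing semantics need fixing.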
This bug resulted in a confusing situation where my function worked when decorated by jax.jit but had a shape mismatch when called on a numpy array.