Indexing numpy array with DeviceArray: index interpreted as tuple #620

Closed
joschu opened this issue Apr 17, 2019 · 10 comments
Labels
bug Something isn't working

Comments

@joschu
Contributor

joschu commented Apr 17, 2019

When you try to index a NumPy ndarray with a DeviceArray, NumPy interprets the JAX array as a tuple of per-axis indices instead of as an index array.

import numpy as onp
import jax.numpy as np
x = onp.zeros((5,7))
np_idx = onp.array([1,2,3])
jax_idx = np.array([1,2,3])
x[np_idx]
x[jax_idx] # <- raises IndexError (too many indices for array) on NumPy < 1.23

Workaround: put jax_idx in a singleton tuple, i.e. x[(jax_idx,)].

This bug resulted in a confusing situation where my function worked when decorated by jax.jit but had a shape mismatch when called on a numpy array.
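The workaround generalizes: any spelling that places the index inside an explicit tuple skips NumPy's guess about non-tuple sequences, because each tuple element is then converted to an index array on its own. A minimal sketch, assuming a 2-D `x` as in the report; it uses `jax.numpy` when JAX is installed and otherwise falls back to a plain ndarray placeholder, and the `rows_*` names are only illustrative:

```python
import numpy as onp

try:
    import jax.numpy as jnp  # use a real JAX array when JAX is available
    jax_idx = jnp.array([1, 2, 3])
except ImportError:
    jax_idx = onp.array([1, 2, 3])  # placeholder when JAX is not installed

x = onp.arange(35).reshape((5, 7))

# Each form hands NumPy an explicit tuple, so jax_idx is converted to an
# integer index array for axis 0 instead of being unpacked per axis.
rows_tuple = x[(jax_idx,)]
rows_ellipsis = x[jax_idx, ...]
rows_slice = x[jax_idx, :]

# Alternatively, convert the index to a host ndarray before indexing.
rows_host = x[onp.asarray(jax_idx)]
```

On NumPy releases that include the upstream fix, plain `x[jax_idx]` should also work, but the explicit-tuple forms behave the same on old and new versions.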

@mattjj mattjj added the bug Something isn't working label Apr 17, 2019
@mattjj
Collaborator

mattjj commented Apr 17, 2019

This is related to #553, though in that case the user was only calling jax.numpy (but the surprising behavior cropped up because jax.numpy.arange actually just calls onp.arange, producing a plain ndarray).

To fix this, we'll either need to get DeviceArrays to look like regular ndarrays well enough that NumPy's indexing treats them like regular ndarrays, or we'll need to set up our own overrides, perhaps using ideas like #611.

@mattjj
Collaborator

mattjj commented Apr 18, 2019

This is a great bug.

I asked our friendly neighborhood NumPy expert @shoyer about it, and after we poked around the NumPy code a bit, he noticed something awesome: this example would have worked if your index array had 32 or more elements in it. Indeed:

import numpy as onp
import jax.numpy as np

x = onp.arange(35).reshape((5,7))  # changed
np_idx = onp.array([1,2,3] * 11)   # changed
jax_idx = np.array([1,2,3] * 11)   # changed

assert onp.allclose(x[np_idx], x[jax_idx])

That breaks if you change the 11s to 10s (30 elements), though.

@shoyer also pointed out that this example will work if you write x[jax_idx, :], i.e. you include the colons.

Not sure yet if we have a way to mitigate this issue, but I couldn't wait to report this NumPy oddity.
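To make the length cutoff and the `x[jax_idx, :]` observation concrete, here is a simplified Python paraphrase of the legacy heuristic in NumPy's multiarray/mapping.c. It is an approximation, not NumPy's actual code. It assumes the cutoff is NumPy 1.x's NPY_MAXDIMS (32) and models a DeviceArray as a list of 0-d ndarrays, since indexing a DeviceArray yields array-like elements rather than Python ints. `legacy_treats_as_tuple` and `LEGACY_MAXDIMS` are hypothetical names.

```python
import numpy as onp

# Assumed to mirror NPY_MAXDIMS in NumPy 1.x.
LEGACY_MAXDIMS = 32


def _is_sequence_like(obj) -> bool:
    # Rough Python analogue of the C-level PySequence_Check.
    return hasattr(type(obj), "__getitem__") and not isinstance(obj, dict)


def legacy_treats_as_tuple(index) -> bool:
    """Approximate whether pre-1.23 NumPy reads x[index] as x[tuple(index)]."""
    if isinstance(index, tuple):
        return True  # already a tuple of per-axis indices, e.g. x[idx, :]
    if isinstance(index, (onp.ndarray, str)) or not _is_sequence_like(index):
        return False  # a single index: array, int, slice, None, ...
    try:
        n = len(index)
    except TypeError:
        return False
    if n >= LEGACY_MAXDIMS:
        return False  # long sequences are always converted to an index array
    for i in range(n):
        item = index[i]
        if (isinstance(item, (onp.ndarray, slice)) or item is Ellipsis
                or item is None or _is_sequence_like(item)):
            return True  # any array-like/slice element triggers tuple semantics
    return False
```

Under this model, a plain list of Python ints stays an index array at any length, while a short sequence of array-like elements (like a small DeviceArray) is unpacked as a tuple.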

@shoyer
Collaborator

shoyer commented Apr 18, 2019

When I try this example, I also see a deprecation warning from NumPy:
/Users/shoyer/miniconda3/envs/jax-py36/bin/ipython:7: FutureWarning: Using a non-tuple sequence for multidimensional indexing is deprecated; use arr[tuple(seq)] instead of arr[seq]. In the future this will be interpreted as an array index, arr[np.array(seq)], which will result either in an error or a different result.

So hopefully that's a clue that something may be amiss.
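Because the reinterpretation does not always raise (a short index into a higher-dimensional array can silently select different elements, which matches the shape mismatch in the original report), it can help to promote that FutureWarning to an error while debugging. A minimal sketch using only the standard `warnings` module; `strict_getitem` is a hypothetical helper name:

```python
import warnings

import numpy as onp


def strict_getitem(x, index):
    """Return x[index], raising instead of warning on FutureWarning."""
    with warnings.catch_warnings():
        # Escalate FutureWarning (such as the non-tuple-sequence deprecation)
        # to an exception so the bad indexing call fails at its source.
        warnings.simplefilter("error", FutureWarning)
        return x[index]


x = onp.arange(35).reshape((5, 7))
rows = strict_getitem(x, onp.array([1, 2, 3]))  # ordinary index: no warning
```

The same effect is available process-wide through Python's `-W error::FutureWarning` command-line option.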

@mattjj
Collaborator

mattjj commented Apr 18, 2019

Reading that warning again, it sounds like this use case will start working for us in the future, once the deprecated behavior is removed.

One path forward in the short term might be to add some code to this function in multiarray/mapping.c to check for an __array__ method (i.e. call PyArray_FromArrayAttr) on the indexing object (like our DeviceArray exposes) and call it. That is, maybe we can make this NumPy function respect that style of duck typing of ndarrays, not just proper subclassing. (EDIT: also apparently suggested in this comment.)
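The proposal targets NumPy's C code, but the same idea (trust `__array__` before guessing) can be approximated at the call site in Python. A minimal sketch under that assumption: `as_index` and `FakeDeviceArray` are hypothetical names, and `FakeDeviceArray` mimics only the DeviceArray traits that matter here, namely an `__array__` method and a `__getitem__` that returns array-likes rather than Python ints.

```python
import numpy as onp


def as_index(index):
    """Convert a non-tuple index that implements __array__ into an ndarray."""
    # Tuples are left alone: NumPy already converts each element separately.
    if not isinstance(index, tuple) and hasattr(index, "__array__"):
        return onp.asarray(index)
    return index


class FakeDeviceArray:
    """Hypothetical stand-in for a JAX DeviceArray (duck-typed via __array__)."""

    def __init__(self, data):
        self._data = onp.asarray(data)

    def __array__(self, dtype=None, copy=None):
        return self._data if dtype is None else self._data.astype(dtype)

    def __len__(self):
        return len(self._data)

    def __getitem__(self, i):
        # Like a DeviceArray, indexing yields another array-like object.
        return FakeDeviceArray(self._data[i])


x = onp.arange(35).reshape((5, 7))
idx = FakeDeviceArray([1, 2, 3])
rows = x[as_index(idx)]  # always an index array, never unpacked as a tuple
```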

@shoyer
Collaborator

shoyer commented Apr 18, 2019

Yes, once the deprecation cycle finishes this should start working for JAX.

@levskaya
Collaborator

Just remarking that I helped an intern debug this issue again, so it's still happening circa July 2020.

@nsavinov

I have also just spent some time debugging it.

The deprecation warning helped, but it was slightly confusing: rather than telling me the error had already happened, it said that the behaviour won't be supported in the future.

@nsavinov

Maybe adding this gotcha to the Sharp Bits notebook would help new users? https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html

copybara-service bot pushed a commit to google/flax that referenced this issue Dec 2, 2020
jax-ml/jax#3821.

The idea of the JAX change is in part that DeviceArray.__iter__ should return
DeviceArrays. Before #3821, it returned numpy.ndarrays. One main motivation is
performance: it avoids a host sync. A secondary motivation is type consistency.

However, that caused this line of Flax example code to trigger a NumPy bug,
discussed in this thread:
jax-ml/jax#620 (comment)
Basically, x[i] where x is a numpy.ndarray and i is a JAX DeviceArray _of
length 10 or less_ causes NumPy to interpret i as a non-array sequence (e.g. a
tuple) rather than as an array, leading to an error like "IndexError: too many
indices for array". The workaround employed here is to write x[i, ...] instead
of x[i], which bypasses the NumPy bug.

PiperOrigin-RevId: 345140147
copybara-service bot pushed a commit to google/flax that referenced this issue Dec 2, 2020
PiperOrigin-RevId: 345160314
@hawkinsp
Collaborator

I sent a PR to fix this in upstream NumPy (numpy/numpy#21029). Let's see if it works!

@hawkinsp
Collaborator

The upstream NumPy PR was merged! So I'm going to declare this fixed, even though you're going to have to wait for NumPy 1.23 to get it...
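For code that must still run on NumPy releases older than 1.23 (the cutoff given in the comment above), one option is to check the installed version and keep the explicit-tuple spelling when the fix is absent. A minimal sketch; `numpy_has_index_fix` is a hypothetical helper, and its version parsing is deliberately naive (the third-party `packaging.version.Version` is a more robust choice if available):

```python
import numpy as onp


def numpy_has_index_fix(version: str = onp.__version__) -> bool:
    """True if `version` is NumPy 1.23 or newer."""
    # Naive parse: only the leading "major.minor" fields are inspected, so
    # suffixes such as "1.23.0rc1" or "2.0.0.dev0" are ignored.
    major, minor = (int(part) for part in version.split(".")[:2])
    return (major, minor) >= (1, 23)
```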

callumtilbury added a commit to uoe-agents/revisiting-maddpg that referenced this issue Jun 6, 2022
…NumPy < 1.23—hence needed to use an unstable release. See jax-ml/jax#620 for more details.
6 participants