
np.ndarray.__getitem__(jnp.ndarray) is incorrect #2040

Closed
trevorcai opened this issue Jan 22, 2020 · 1 comment



trevorcai commented Jan 22, 2020

import jax
import jax.numpy as jnp
import numpy as np

arr = np.zeros((2, 8))
idxs = np.zeros((1, 4), dtype=np.int32)
jax_arr = jax.device_put(arr)
jax_idxs = jax.device_put(idxs)
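print(arr[idxs].shape)          # NumPy array, NumPy index
print(arr[jax_idxs].shape)      # NumPy array, JAX index
print(jax_arr[idxs].shape)      # JAX array, NumPy index
print(jax_arr[jax_idxs].shape)  # JAX array, JAX index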

prints:

(1, 4, 8)
(4, 8)
(1, 4, 8)
(1, 4, 8)

Plain NumPy mishandles a jnp.ndarray used as an index (the second line of output) for some reason, and changing the shape of idxs leads to various errors.

Continuing from above:
- arr[jnp.zeros((2, 4), dtype=jnp.int32)].shape gives (4,)
- arr[jnp.zeros((3, 4), dtype=jnp.int32)].shape raises IndexError: too many indices for array

The other three of the four array/index combinations are all correct for these cases.
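All of the wrong results above are consistent with one hypothesis (not confirmed in this thread): NumPy does not recognize the JAX array as an array index and instead treats it as a sequence, splitting it along its first axis into one index per array dimension, as if arr[tuple(idxs)] had been written. The NumPy-only sketch below reproduces the reported shapes under that assumption; it does not use JAX, and the helper index_as_tuple is hypothetical. If the hypothesis holds, converting the index first, e.g. arr[np.asarray(jax_idxs)], should sidestep the problem.

```python
import numpy as np

# Hypothesis: a JAX index array is split into a tuple of per-axis indices,
# i.e. arr[tuple(idxs)] is evaluated instead of arr[idxs].
arr = np.zeros((2, 8))


def index_as_tuple(a, idxs):
    """Hypothetical helper: index `a` with `idxs` split along its first axis."""
    return a[tuple(idxs)]


# (1, 4) index -> a single index array for axis 0 -> (4, 8), like arr[jax_idxs].
print(index_as_tuple(arr, np.zeros((1, 4), dtype=np.int32)).shape)  # (4, 8)

# (2, 4) index -> one index array per axis of the 2-D array -> (4,).
print(index_as_tuple(arr, np.zeros((2, 4), dtype=np.int32)).shape)  # (4,)

# (3, 4) index -> three indices for a 2-D array -> IndexError.
try:
    index_as_tuple(arr, np.zeros((3, 4), dtype=np.int32))
except IndexError as err:
    print("IndexError:", err)

# A real ndarray index is applied to axis 0 as a whole -> (1, 4, 8).
print(arr[np.zeros((1, 4), dtype=np.int32)].shape)  # (1, 4, 8)
```

The three reported results for a NumPy array indexed by a JAX array, namely (4, 8), (4,) and the IndexError, line up with the three tuple-split cases, which is why the sequence-splitting explanation is plausible.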

trevorcai (Contributor, issue author) commented:

Dupe of #620
