
jnp.indices returns numpy.ndarray read-only arrays #7437

Closed
LouisDesdoigts opened this issue Aug 2, 2021 · 2 comments
Labels
duplicate (This issue or pull request already exists)

Comments

LouisDesdoigts commented Aug 2, 2021

Hey Jax team!

I have found that jax.numpy.indices() returns a non-writeable NumPy array rather than a DeviceArray.
I found a workaround for the problem, but I doubt this is the intended behavior.

This was found using:
jax 0.1.77
jaxlib 0.1.57
which I know are not the current stable releases, but I am working with a package that has not been patched to work with 0.2+.

```python
import jax.numpy as jnp
import numpy as onp

# NumPy returns a writeable numpy.ndarray
a, b = onp.indices((5, 5))
print(type(a), a.flags["WRITEABLE"])
# <class 'numpy.ndarray'> True

a += 5
# No error; works fine

# jax.numpy returns a read-only numpy.ndarray (NOT a DeviceArray)
a, b = jnp.indices((5, 5))
print(type(a), a.flags["WRITEABLE"])
# <class 'numpy.ndarray'> False

a += 5
# Raises "ValueError: output array is read-only"

# Converting the array to a DeviceArray avoids the error
a = jnp.array(a)
print(type(a))
# <class 'jaxlib.xla_extension.DeviceArray'>
a += 5
# No error: JAX arrays are immutable, so this builds a new array and rebinds `a`
```
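The read-only failure itself is ordinary NumPy behavior and can be reproduced without JAX. This is a minimal sketch, assuming the only relevant property of the arrays produced by unpacking jnp.indices is that their WRITEABLE flag is off. It also shows that a NumPy .copy() is writeable again, which is an alternative to jnp.array(a) when a NumPy array is what you actually want.

```python
import numpy as np

# Start from ordinary, writeable index grids.
a, b = np.indices((5, 5))

# Assumption: mimic what unpacking jnp.indices produced in jax 0.1.77
# by switching off the WRITEABLE flag on `a`.
a.flags.writeable = False

message = ""
try:
    a += 5  # in-place add writes into `a`, which NumPy forbids here
except ValueError as err:
    message = str(err)

print(message)  # NumPy's read-only error message

# A NumPy copy owns its data and is writeable again.
a_copy = a.copy()
a_copy += 5
print(a_copy.flags.writeable, a_copy[4, 4])  # True 9
```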
LouisDesdoigts added the bug (Something isn't working) label on Aug 2, 2021
jakevdp (Collaborator) commented Aug 2, 2021

Hi, thanks for the question! It turns out the culprit is not jnp.indices; rather, this is a long-standing bug/wart in JAX (see #1583). jnp.indices returns a DeviceArray, but when you write a, b = jnp.indices(...), tuple unpacking calls the array's __iter__ method, which outputs static (read-only) NumPy arrays.
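Why unpacking and indexing behave differently can be sketched in plain Python. ToyDeviceArray below is a hypothetical class, not JAX's implementation: its __iter__ imitates the behavior described above by yielding read-only NumPy arrays, while __getitem__ keeps results wrapped. Tuple unpacking goes through __iter__, so only that path produces the read-only arrays.

```python
import numpy as np


class ToyDeviceArray:
    """Hypothetical stand-in for a JAX DeviceArray (not JAX's real class).

    __iter__ imitates the old behavior by yielding plain, read-only NumPy
    arrays; __getitem__ keeps results wrapped in ToyDeviceArray.
    """

    def __init__(self, data):
        self._data = np.asarray(data)
        self.iter_calls = 0  # counts how often __iter__ is invoked

    @staticmethod
    def _read_only(row):
        out = np.array(row)          # fresh NumPy copy of the row
        out.flags.writeable = False  # mark it read-only
        return out

    def __iter__(self):
        self.iter_calls += 1
        return (self._read_only(row) for row in self._data)

    def __getitem__(self, idx):
        # Indexing returns another ToyDeviceArray, not a NumPy array.
        return ToyDeviceArray(self._data[idx])


ind = ToyDeviceArray(np.indices((5, 5)))

a, b = ind   # tuple unpacking calls ind.__iter__()
a2 = ind[0]  # explicit indexing calls ind.__getitem__(0) instead

print(ind.iter_calls, type(a).__name__, a.flags.writeable)  # 1 ndarray False
print(type(a2).__name__)  # ToyDeviceArray
```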

There's a fix in the works at #3821 (not sure why it stalled, maybe @mattjj knows?).

Until that is resolved, a workaround is to do something like this:

```python
import jax.numpy as jnp

ind = jnp.indices((5, 5))
a = ind[0]  # indexing, unlike unpacking, does not go through __iter__
b = ind[1]
```
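If more than two pieces are needed, the same indexing-only workaround can be wrapped in a small helper. split_leading_axis is a hypothetical name; it only uses arr.shape and arr[i] and iterates over range() rather than over arr itself, so arr.__iter__ is never called. The usage below runs on a NumPy array so it works without JAX; passing jnp.indices((5, 5)) instead is the intended case, assumed but not verified against jax 0.1.77.

```python
import numpy as np


def split_leading_axis(arr):
    """Return (arr[0], arr[1], ...) using indexing only.

    Hypothetical helper: it loops over range(), not over `arr`, so
    arr.__iter__ is never called.
    """
    return tuple(arr[i] for i in range(arr.shape[0]))


# Usage with NumPy; with JAX you would pass jnp.indices((5, 5)) instead.
ind = np.indices((5, 5))
a, b = split_leading_axis(ind)
print(a.shape, int(a[3, 0]), int(b[0, 3]))  # (5, 5) 3 3
```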

jakevdp added the duplicate (This issue or pull request already exists) label and removed the bug (Something isn't working) label on Aug 2, 2021
jakevdp self-assigned this on Aug 2, 2021
jakevdp (Collaborator) commented Aug 3, 2021

I'm going to close this; we can continue to track the issue in #1583.

jakevdp closed this as completed on Aug 3, 2021