I have found that jax.numpy.indices() returns a read-only NumPy array rather than a DeviceArray.
I found a workaround for the problem, but I doubt this is the intended behavior.
This was found using:
jax 0.1.77
jaxlib 0.1.57
which I know are not the current stable releases, but I am working with a package that has not yet been updated to work with 0.2+.
import jax.numpy as jnp
import numpy as onp

# NumPy returns a writeable numpy.ndarray
a, b = onp.indices((5, 5))
print(type(a), a.flags["WRITEABLE"])
# <class 'numpy.ndarray'> True
a += 5  # No error; works fine

# jax.numpy returns a read-only numpy.ndarray (NOT a DeviceArray)
a, b = jnp.indices((5, 5))
print(type(a), a.flags["WRITEABLE"])
# <class 'numpy.ndarray'> False
a += 5  # Raises "ValueError: output array is read-only"

# Re-casting the array to a DeviceArray avoids the error
a = jnp.array(a)
print(type(a))
# <class 'jaxlib.xla_extension.DeviceArray'>
a += 5  # Now runs without error
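One caveat about the last step of the workaround: JAX arrays are immutable, so `a += 5` on a DeviceArray does not write into the existing buffer. Since JAX arrays do not define an in-place `__iadd__`, Python falls back to `a = a + 5` and binds `a` to a new array. The minimal sketch below (the names `a` and `alias` are illustrative, not from the report) shows that another reference to the original array is left unchanged.

```python
import jax.numpy as jnp

a = jnp.indices((5, 5))[0]   # explicit indexing keeps this a JAX array
alias = a                    # second reference to the same array object
a += 5                       # no in-place write: `a` is rebound to a new array
print(a is alias)            # False
print(int(alias[1, 0]), int(a[1, 0]))  # 1 6
```

So the reporter's final `a += 5` succeeds because it produces a fresh array, not because the DeviceArray became writeable.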
Hi, thanks for the question! It turns out the culprit is not jnp.indices; rather, this is a long-standing bug/wart in JAX (see #1583). jnp.indices returns a DeviceArray, but when you write a, b = jnp.indices(...), tuple unpacking calls the array's __iter__ method, which yields static NumPy arrays.
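To make the distinction concrete, the sketch below compares two ways of pulling rows out of the `jnp.indices` result: tuple unpacking, which goes through `__iter__`, and explicit integer indexing, which goes through `__getitem__`. The variable names are illustrative. What gets reported for the unpacked rows depends on the JAX version: on the jax 0.1.77 / jaxlib 0.1.57 setup from the report they come back as read-only `numpy.ndarray`s, and other releases may behave differently, so the sketch prints the types instead of assuming them.

```python
import jax.numpy as jnp
import numpy as np

idx = jnp.indices((5, 5))        # one array of shape (2, 5, 5), not a tuple
print(type(idx), idx.shape)

# `a, b = idx` is tuple unpacking, which iterates over idx via idx.__iter__()
a, b = idx
# Integer indexing uses __getitem__ instead of __iter__
a_idx, b_idx = idx[0], idx[1]

for name, arr in [("unpacked", a), ("indexed", a_idx)]:
    if isinstance(arr, np.ndarray):
        # Plain NumPy array: report whether it can be modified in place
        print(name, "-> numpy.ndarray, writeable =", arr.flags["WRITEABLE"])
    else:
        print(name, "->", type(arr).__name__)
```

Both routes hold the same values; only the array type, and therefore writeability, can differ.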
There's a fix in the works at #3821 (not sure why it stalled; maybe @mattjj knows?).
Until that is resolved, a workaround is to do something like this:
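A minimal sketch of one such workaround, assuming the goal is simply to avoid `__iter__` (this is not necessarily the exact snippet from the reply): index the result explicitly instead of unpacking it, so both rows stay JAX arrays.

```python
import jax.numpy as jnp

idx = jnp.indices((5, 5))
# Index explicitly instead of `a, b = jnp.indices(...)`, so idx.__iter__ is never called
a, b = idx[0], idx[1]

# JAX arrays are immutable: this builds a new array rather than writing in place
a = a + 5
print(type(a), a.shape)
```

The right-hand side `idx[0], idx[1]` is an ordinary Python tuple, so the unpacking iterates over that tuple rather than over the JAX array. Wrapping the unpacked rows in `jnp.array(...)`, as in the original report, also works but first round-trips through NumPy.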