JAX elements transposed relative to NumPy #8899
What versions of jax and jaxlib are you using? With the most recent jax and jaxlib releases, I find that the jax result matches the numpy result.
The numpy version might be relevant as well, because the first thing matplotlib does is check
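Because the behavior depends on the versions of jax, jaxlib, numpy and matplotlib together, it helps to report all of them at once. Below is a minimal sketch using only the standard library (`importlib.metadata`, Python 3.8+). The helper name `report_versions` is hypothetical; it reads installed package metadata instead of importing the packages, so it also works when one of them is missing.

```python
from importlib import metadata


def report_versions(packages=("jax", "jaxlib", "numpy", "matplotlib")):
    """Hypothetical helper: map each package name to its installed version.

    Returns None for packages that are not installed, instead of raising.
    """
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # package not installed in this environment
    return versions


print(report_versions())
```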
I ran the above code using the CPU backend and Furthermore, I tested the same code snippet on a different computer with
I've narrowed it down to a matplotlib issue. In matplotlib 3.2.2 and older, the plot shows up correctly. In matplotlib 3.3.0 and newer, the plot is incorrect. You might look through the matplotlib changelog to figure out what's changed. In the meantime, you should know that in general it is not 100% safe to pass JAX arrays to packages that expect numpy arrays as input, because those packages do not always do correct input validation (in this case, matplotlib fails to call
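A minimal sketch of the defensive conversion this suggests: coerce the data into a plain NumPy array with `np.asarray` before handing it to matplotlib. The helper `to_numpy_1d` is hypothetical. `np.asarray` also accepts JAX arrays because they implement the `__array__` protocol (which copies the data to host memory). To stay runnable without JAX, the example uses a list of zero-dimensional NumPy arrays as a stand-in for the problematic input.

```python
import numpy as np


def to_numpy_1d(values):
    """Hypothetical helper: turn an array-like (NumPy array, JAX array,
    list of 0-d arrays, ...) into one flat NumPy array."""
    arr = np.asarray(values)  # for JAX arrays this goes through __array__
    return arr.reshape(-1)    # flatten so a plotting library sees a single 1D dataset


# Stand-in input: a list of zero-dimensional arrays.
samples = [np.array(v) for v in np.linspace(0.0, 1.0, 5)]
flat = to_numpy_1d(samples)
print(type(flat).__name__, flat.shape)  # ndarray (5,)
```

With JAX, the call would then be `plt.hist(to_numpy_1d(x))` for a `jax.numpy` array `x`, so matplotlib only ever sees a NumPy array.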
I'm going to close this, because I don't think it is a JAX issue.
I am pretty sure that this changed at some point when updating jax or jaxlib, and I never had a matplotlib version older than 3.4.1 installed on my system.
I can reproduce the "correct" histogram with
Sure, binary compatibility with numpy is not a sensible goal. However, matplotlib is certainly not just any arbitrary library. Furthermore, the input validation you are referencing is not meant to catch JAX arrays, i.e., calling
That is exactly what I am currently doing, but I thought JAX might treat this as a regression, considering that it worked in the past.
I think the relevant change is #8043: starting in JAX 0.2.26, iterating over a 1D `DeviceArray` yields zero-dimensional `DeviceArray`s. Here is the difference between JAX 0.2.25 and JAX 0.2.26:

```python
from matplotlib import cbook
import numpy as np
import jax.numpy as jnp
import jax

print(jax.__version__)
# 0.2.26

x = jnp.arange(2)
x_np = np.arange(2)

print(list(x))
# [DeviceArray(0, dtype=int32), DeviceArray(1, dtype=int32)]

# Utility used by hist():
print(cbook._reshape_2D(x, 'x'))
# [array([0], dtype=int32), array([1], dtype=int32)]

print(cbook._reshape_2D(x_np, 'x'))
# [array([0, 1], dtype=int32)]
```

```python
from matplotlib import cbook
import numpy as np
import jax.numpy as jnp
import jax

print(jax.__version__)
# 0.2.25

x = jnp.arange(2)
x_np = np.arange(2)

print(list(x))
# [0, 1]

# Utility used by hist():
print(cbook._reshape_2D(x, 'x'))
# [array([0, 1], dtype=int32)]

print(cbook._reshape_2D(x_np, 'x'))
# [array([0, 1])]
```

I think where this falls afoul of matplotlib's assumptions is in this line: https://github.com/matplotlib/matplotlib/blob/c2a5e22baf6eeb4c5e49028d8ef10adf5a6f4f32/lib/matplotlib/cbook/__init__.py#L1392-L1393 JAX arrays are iterable, and numpy scalars are not.

For what it's worth, you can cause the same failure mode without JAX by passing a list of zero-dimensional numpy arrays (with matplotlib version 3.4.1):

```python
import matplotlib.pyplot as plt
import numpy as np

fig, axs = plt.subplots(1, 2)
axs.flat[0].hist(np.linspace(0., 1.))                          # one dataset: correct
axs.flat[1].hist([np.array(i) for i in np.linspace(0., 1.)])   # list of 0-d arrays: wrong
```

It looks like a bug in how matplotlib treats iterables of zero-dimensional arrays, and not a bug in JAX. In any case, my recommendation still stands: you should be careful passing JAX arrays to third-party packages which are built and tested assuming numpy array inputs.
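To illustrate why zero-dimensional arrays can trip up an "are the elements iterable?" check, here is a NumPy-only sketch contrasting two hypothetical heuristics: one based on `isinstance(x, collections.abc.Iterable)`, which only checks whether the type defines `__iter__`, and one that actually calls `iter(x)`. It is an assumption that the matplotlib line linked above behaves like the first heuristic; both function names are illustrative and are not matplotlib's API.

```python
import collections.abc

import numpy as np


def has_iterable_elements_abc(X):
    """Hypothetical heuristic: any element whose type defines __iter__
    (strings excluded) marks X as a list of separate datasets."""
    return any(
        isinstance(x, collections.abc.Iterable) and not isinstance(x, str)
        for x in X
    )


def has_iterable_elements_probe(X):
    """Hypothetical alternative: actually try iter() on each element."""
    for x in X:
        if isinstance(x, str):
            continue
        try:
            iter(x)
        except TypeError:  # e.g. NumPy scalars, or 0-d arrays
            continue
        return True
    return False


scalars = list(np.linspace(0.0, 1.0, 3))                  # NumPy scalars (np.float64)
zero_d = [np.array(v) for v in np.linspace(0.0, 1.0, 3)]  # 0-d arrays

print(has_iterable_elements_abc(scalars), has_iterable_elements_abc(zero_d))      # False True
print(has_iterable_elements_probe(scalars), has_iterable_elements_probe(zero_d))  # False False
```

A 0-d `ndarray` defines `__iter__` but raises `TypeError` when actually iterated, so only the probing heuristic treats both inputs as a single flat dataset.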
Thank you for the detailed explanation. I see that returning 0D DeviceArrays when iterating over a 1D DeviceArray is desirable, but it "breaks" some edge cases when one relies on the numpy behavior of iterating over a 1D array yielding numpy scalars rather than arrays.
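A small NumPy-only sketch of the distinction discussed here: iterating over a 1D NumPy array yields NumPy scalars (instances of `np.generic`), `tolist()` yields native Python numbers, and wrapping each element with `np.asarray` produces 0-d arrays, which roughly emulates what JAX 0.2.26 yields per element (JAX itself is not used, so this is an approximation). Indexing a 0-d array with `()` recovers a NumPy scalar.

```python
import numpy as np

a = np.arange(3)

elems = list(a)                           # NumPy scalars, e.g. np.int64 (platform dependent)
natives = a.tolist()                      # native Python ints
zero_d = [np.asarray(e) for e in elems]   # 0-d arrays, emulating per-element DeviceArrays
unwrapped = [z[()] for z in zero_d]       # indexing a 0-d array with () gives a NumPy scalar

print(all(isinstance(e, np.generic) for e in elems))      # True
print(all(type(n) is int for n in natives))               # True
print([np.ndim(z) for z in zero_d])                       # [0, 0, 0]
print(all(isinstance(u, np.generic) for u in unwrapped))  # True
```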
This issue is now fixed in matplotlib by matplotlib/matplotlib#22018.
For some usages of JAX, the elements of an array seem to be transposed relative to NumPy, leading to unexpected results. This is the case, for example, for matplotlib histograms.
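One way to see why the symptom looks like a transpose: a 1D input should become a single dataset of N values, but when each element is treated as its own dataset, `hist()` effectively receives N datasets of one value each. The NumPy-only sketch below is an illustration of the two layouts, not matplotlib's code; the variable names are hypothetical.

```python
import numpy as np

x = np.arange(4)

# Intended layout: one dataset holding all N values, shape (1, N).
expected = np.array([x])

# Misread layout: every element flattened into its own one-value dataset,
# shape (N, 1).
misread = np.array([np.reshape(np.asarray(v), -1) for v in x])

print(expected.shape, misread.shape)         # (1, 4) (4, 1)
print(np.array_equal(misread, expected.T))   # True
```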