
JAX elements transposed relative to NumPy #8899

Closed
Edenhofer opened this issue Dec 10, 2021 · 11 comments
Labels
needs info More information is required to diagnose & prioritize the issue.

Comments

@Edenhofer
Contributor

For some usages of JAX, the elements of an array seem to be transposed relative to NumPy, leading to unexpected results. Matplotlib histograms are one example.

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

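fig, axs = plt.subplots(1, 3)
# Reference: a 1D NumPy array is plotted as a single dataset.
axs.flat[0].hist(np.linspace(0., 1.))
axs.flat[0].set_title("Expected Result (NumPy)")
axs.flat[1].hist(jnp.linspace(0., 1.))
axs.flat[1].set_title("JAX Result")
# A (1, N) NumPy array is treated as N one-element datasets, which reproduces the JAX panel.
axs.flat[2].hist(np.linspace(0., 1.).reshape(1, -1))
axs.flat[2].set_title("JAX Result w/ NumPy")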
fig.tight_layout()
plt.show()

[Image: JAX_transpose_bug, the three histograms produced by the snippet above]
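As a rough illustration of why the third panel matches the JAX output: the matplotlib `hist` documentation describes multiple datasets as either a list of arrays or a 2D array in which each column is a dataset. An array of shape `(1, N)` therefore becomes N datasets of one value each. The sketch below mimics that grouping with plain NumPy; `split_into_datasets` is a hypothetical helper written for illustration, not matplotlib's actual code (which lives in the private `cbook._reshape_2D`).

```python
import numpy as np


def split_into_datasets(x):
    """Hypothetical sketch of how hist() groups its input:
    a 1D array is one dataset; a 2D array gives one dataset per column."""
    x = np.asarray(x)
    if x.ndim == 1:
        return [x]
    if x.ndim == 2:
        # Each column of a 2D array is treated as a separate dataset.
        return [x[:, j] for j in range(x.shape[1])]
    raise ValueError("expected 1D or 2D input")


data = np.linspace(0., 1., 5)
print(len(split_into_datasets(data)))                 # 1 dataset holding 5 values
print(len(split_into_datasets(data.reshape(1, -1))))  # 5 datasets holding 1 value each
```

Five one-element datasets are drawn as five differently colored bars, which is the "transposed" look in the JAX panel.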

@Edenhofer Edenhofer added the bug Something isn't working label Dec 10, 2021
@jakevdp
Collaborator

jakevdp commented Dec 10, 2021

What versions of jax and jaxlib are you using?

With the most recent jax and jaxlib releases, I find that the jax result matches the numpy result.

@jakevdp
Collaborator

jakevdp commented Dec 10, 2021

The numpy version might be relevant as well, because the first thing matplotlib does is check np.isscalar(x) and reshape the input if it returns True: https://github.com/matplotlib/matplotlib/blob/5bb1449ced7653ba5357c43017ba6c8a04e1f21a/lib/matplotlib/axes/_axes.py#L6538
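For reference, `np.isscalar` returns `True` for NumPy and Python scalars but `False` for zero-dimensional arrays, so it separates exactly the two kinds of objects discussed in this thread. A small check with NumPy alone:

```python
import numpy as np

print(np.isscalar(np.float64(0.5)))      # True: NumPy scalar
print(np.isscalar(0.5))                  # True: Python float
print(np.isscalar(np.array(0.5)))        # False: zero-dimensional ndarray
print(np.isscalar(np.linspace(0., 1.)))  # False: 1D array
```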

@jakevdp jakevdp added needs info More information is required to diagnose & prioritize the issue. and removed bug Something isn't working labels Dec 10, 2021
@Edenhofer
Contributor Author

I ran the above code using the CPU backend and jax==0.2.26, but I also tested the most recent commit 83174dc14. I have jaxlib==0.1.75, numpy==1.21.3 and matplotlib==3.5.0 installed.

Furthermore, I tested the same code snippet on a different computer with jax==0.2.26, jaxlib==0.1.75, numpy==1.21.2 and matplotlib==3.4.3.

@jakevdp
Collaborator

jakevdp commented Dec 10, 2021

I've narrowed it down to a matplotlib issue. In matplotlib 3.2.2 and older, the plot shows up correctly. In matplotlib 3.3.0 and newer, the plot is incorrect. You might look through the matplotlib changelog to figure out what's changed.

In the meantime, you should know that in general it is not 100% safe to pass JAX arrays to packages that expect numpy arrays as input, because those packages do not always do correct input validation (in this case, matplotlib fails to call np.asarray() before applying numpy functions like np.isscalar to the inputs). The best thing to do is call np.asarray yourself.
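Here is a minimal sketch of that workaround, assuming the data is converted once before it is handed to matplotlib. Plain NumPy is used so the snippet runs without JAX; with JAX installed, `np.asarray(jnp.linspace(0., 1., 5))` performs the same conversion. After conversion, iterating over the array yields NumPy scalars rather than zero-dimensional arrays.

```python
import numpy as np

# With JAX: values = np.asarray(jnp.linspace(0., 1., 5))
values = np.asarray(np.linspace(0., 1., 5))

first = next(iter(values))
print(type(first))         # <class 'numpy.float64'>
print(np.isscalar(first))  # True
```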

@jakevdp
Collaborator

jakevdp commented Dec 10, 2021

I'm going to close this, because I don't think it is a JAX issue.

@jakevdp jakevdp closed this as completed Dec 10, 2021
@Edenhofer
Contributor Author

I am pretty sure that this changed at some point when updating jax or jaxlib and I never had a matplotlib version older than 3.4.1 installed on my system.

@Edenhofer
Copy link
Contributor Author

I can reproduce the "correct" histogram with jax==0.2.20, jaxlib==0.1.71 and matplotlib==3.5.0.

@Edenhofer
Copy link
Contributor Author

> In the meantime, you should know that in general it is not 100% safe to pass JAX arrays to packages that expect numpy arrays as input, because those packages do not always do correct input validation (in this case, matplotlib fails to call np.asarray() before applying numpy functions like np.isscalar to the inputs).

Sure, binary compatibility with numpy is not a sensible goal. However, matplotlib is certainly not just any arbitrary library. Furthermore, the input validation you are referencing is not meant to catch JAX arrays, i.e. calling np.asarray on the inputs will most likely not be the solution here: np.isscalar works on many more objects than np.asarray, see here.

> The best thing to do is call np.asarray yourself.

That is exactly what I am currently doing, but I thought JAX might treat this as a regression, considering that it worked in the past.

@jakevdp
Collaborator

jakevdp commented Dec 10, 2021

I think the relevant change is #8043: starting in JAX 0.2.26, iter(device_array) returns zero-dimensional jax arrays rather than numpy scalars.

Here is the difference between JAX 0.2.25 and JAX 0.2.26:

from matplotlib import cbook
import numpy as np
import jax.numpy as jnp
import jax

print(jax.__version__)
# 0.2.26

x = jnp.arange(2)
x_np = np.arange(2)
print(list(x))
# [DeviceArray(0, dtype=int32), DeviceArray(1, dtype=int32)]

# Utility used by hist():
print(cbook._reshape_2D(x, 'x'))
# [array([0], dtype=int32), array([1], dtype=int32)]

print(cbook._reshape_2D(x_np, 'x'))
# [array([0, 1], dtype=int32)]

from matplotlib import cbook
import numpy as np
import jax.numpy as jnp
import jax

print(jax.__version__)
# 0.2.25

x = jnp.arange(2)
x_np = np.arange(2)
print(list(x))
# [0, 1]

# Utility used by hist():
print(cbook._reshape_2D(x, 'x'))
# [array([0, 1], dtype=int32)]

print(cbook._reshape_2D(x_np, 'x'))
# [array([0, 1])]
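The same distinction can be reproduced with NumPy alone. Iterating over a NumPy array yields NumPy scalars (what JAX returned before 0.2.26), whereas the list comprehension below builds zero-dimensional arrays (what iterating over a JAX array yields from 0.2.26 on). Both have `ndim == 0`, but only the latter is an `ndarray`:

```python
import numpy as np

x = np.arange(2)

scalars = list(x)                  # NumPy integer scalars
zero_d = [np.array(v) for v in x]  # zero-dimensional ndarrays

print([np.ndim(v) for v in scalars], [np.ndim(v) for v in zero_d])  # [0, 0] [0, 0]
print(isinstance(scalars[0], np.ndarray))  # False
print(isinstance(zero_d[0], np.ndarray))   # True
```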

I think where this falls afoul of matplotlib's assumptions is in these lines: https://github.com/matplotlib/matplotlib/blob/c2a5e22baf6eeb4c5e49028d8ef10adf5a6f4f32/lib/matplotlib/cbook/__init__.py#L1392-L1393

JAX arrays are iterable, and numpy scalars are not.
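One way to see that difference, assuming matplotlib's check amounts to asking whether each element is iterable (the exact test in the linked lines may differ): a zero-dimensional ndarray defines `__iter__` and therefore registers as a `collections.abc.Iterable`, even though actually iterating over it raises `TypeError`. A NumPy scalar does not define `__iter__` at all.

```python
import collections.abc

import numpy as np

zero_d = np.array(0.5)
scalar = np.float64(0.5)

print(isinstance(zero_d, collections.abc.Iterable))  # True: ndarray defines __iter__
print(isinstance(scalar, collections.abc.Iterable))  # False

try:
    iter(zero_d)
except TypeError as err:
    # Iterating over a 0-d array only fails at runtime.
    print("TypeError:", err)
```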

For what it's worth: you can cause the same failure mode without JAX by passing a list of zero-dimensional numpy arrays (with matplotlib version 3.4.1):

import matplotlib.pyplot as plt
import numpy as np

fig, axs = plt.subplots(1, 2)
axs.flat[0].hist(np.linspace(0., 1.))
axs.flat[1].hist([np.array(i) for i in np.linspace(0., 1.)])

[Image: the two histograms produced by the snippet above]

It looks like a bug in how matplotlib treats iterables of zero-dimensional arrays, and not a bug in JAX. In any case, my recommendation still stands: you should be careful passing JAX arrays to third-party packages which are built and tested assuming numpy array inputs.
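If the data is already a list of zero-dimensional arrays, as in the reproduction above, a minimal caller-side fix is to combine it into a single 1D array before plotting. `np.asarray` does this because every element has the same (empty) shape:

```python
import numpy as np

pieces = [np.array(i) for i in np.linspace(0., 1.)]  # list of 0-d arrays, as in the reproduction
data = np.asarray(pieces)                             # a single 1D array
print(data.shape)  # (50,)
# Passing `data` to hist() should then give one dataset, like the NumPy panel.
```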

@Edenhofer
Contributor Author

Thank you for the detailed explanation. I see that returning 0D DeviceArrays when iterating over a 1D DeviceArray is desirable, but it "breaks" some edge cases that rely on the NumPy behavior of iteration over a 1D array yielding scalars rather than zero-dimensional arrays.
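On terminology: NumPy's own iteration yields NumPy scalar types, not Python-native values. `tolist()`, or `.item()` on a single element, is what produces native Python values. The three forms side by side:

```python
import numpy as np

x = np.arange(3)

print(type(next(iter(x))).__name__)       # a NumPy integer type such as int64 (platform-dependent)
print(type(x.tolist()[0]).__name__)       # int
print(type(np.array(1).item()).__name__)  # int
```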

@jakevdp
Collaborator

jakevdp commented Dec 20, 2021

This issue is now fixed in matplotlib matplotlib/matplotlib#22018
