Arrays having dtype float0 are broken with the Array API #20620
Comments
Thanks! It's not clear to me how to best handle this... I think @mattjj has been considering changing the representation of integer gradients to use extended dtypes rather than being marked by `float0`. That's probably the best way forward if we are going to fix this.
That sounds great. I just hope that whatever representation you choose derives from
I will note that I get pretty scared by the idea of changing I've had to add careful handling for this in a few places, so I worry that changing this (and breaking the existing code that handles it) will lead to difficult-to-debug errors for any downstream users. It's such an edge case. (Although FWIW I do also dislike the current regime, which isn't even self-consistent inside JAX! E.g. needing to return bool/int tangents when working with custom JVPs. If we are going to have a compatibility break, can we do it all at once?)
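For context on the edge case discussed here: the `jax.grad` documentation states that with `allow_int=True` the gradient of an integer input has the trivial dtype `float0`, so code that post-processes gradient pytrees often has to special-case those leaves. Below is a minimal sketch of that kind of handling. `is_float0` and `drop_float0` are hypothetical helper names, not JAX APIs, and replacing such leaves with `None` is only one possible policy.

```python
import jax
import jax.numpy as jnp


def is_float0(x) -> bool:
    """Hypothetical helper: True if `x` has JAX's trivial tangent dtype, float0."""
    dtype = getattr(x, "dtype", None)
    return dtype is not None and dtype == jax.dtypes.float0


def drop_float0(tree):
    """Hypothetical helper: replace float0 leaves with None, keeping the tree shape."""
    return jax.tree_util.tree_map(lambda x: None if is_float0(x) else x, tree)


def loss(params):
    w, n = params  # w: float weight, n: integer count
    return w * n.astype(jnp.float32)


params = (jnp.float32(2.0), jnp.int32(3))
# allow_int=True lets grad accept the integer leaf; its gradient has dtype float0.
grads = jax.grad(loss, allow_int=True)(params)
cleaned = drop_float0(grads)  # (gradient w.r.t. w, None)
```

Since JAX treats `None` as an empty subtree, later single-tree `tree_map` calls over `cleaned` simply skip the former `float0` leaves.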
@patrick-kidger Just curious, but are you using the Array API yet?
Not yet!
FYI: I modified array-api-compat to work around this problem for now, but it would still be better to smooth out the quirks on the Jax side when there's time.
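A workaround of this kind (like the modified `get_namespace` mentioned in the issue below) can be sketched without array-api-compat itself. `namespace_of` is a hypothetical stand-in, not array-api-compat's `get_namespace` and not the actual patch. It assumes only two libraries, NumPy and JAX, treats any array whose dtype is `jax.dtypes.float0` as belonging to `jax.numpy`, and raises `TypeError` when the inputs come from more than one namespace.

```python
import jax
import jax.numpy as jnp
import numpy as np


def _belongs_to_jax(x) -> bool:
    """True for JAX arrays and for float0 ("Z") arrays, even when stored as np.ndarray."""
    if isinstance(x, jax.Array):
        return True
    dtype = getattr(x, "dtype", None)
    return dtype is not None and dtype == jax.dtypes.float0


def namespace_of(*arrays):
    """Hypothetical lookup: return the single array namespace shared by `arrays`."""
    namespaces = set()
    for x in arrays:
        if _belongs_to_jax(x):  # checked first, so float0 ndarrays never map to NumPy
            namespaces.add(jnp)
        elif isinstance(x, np.ndarray):
            namespaces.add(np)
        else:
            raise TypeError(f"unsupported array type: {type(x).__name__}")
    if len(namespaces) != 1:
        raise TypeError(f"expected one namespace, found {len(namespaces)}")
    return namespaces.pop()


# A Z-array: a plain NumPy ndarray whose dtype is float0.
z = np.zeros((3,), dtype=jax.dtypes.float0)
a = jnp.ones(3)
xp = namespace_of(z, a)  # resolves to jax.numpy instead of raising
```

With this rule, a Z-array mixed with ordinary JAX arrays resolves to `jax.numpy`; a real implementation would also have to cover the other libraries array-api-compat supports.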
Description
For all practical purposes, NumPy arrays having dtype `float0` (let's call them Z-arrays) are, in fact, Jax arrays. But these arrays are treated as NumPy arrays by `array_api_compat.get_namespace`, leading to crashes when getting the namespace of multiple arrays (for example, a Z-array together with an ordinary Jax array). Therefore, something appears to be fundamentally broken.

Background
Z-arrays originate in various ways. E.g., `jax.grad` with `allow_int=True` produces a Z-array as the gradient with respect to an integer input.

Does it really make sense to consider an array of dtype `float0` to be a NumPy array? The NumPy namespace doesn't even know anything about `float0`. It seems to me that an array with a given library's dtypes should be considered to belong to that library. Thus, I expect:

Also, it is very odd that `jnp.zeros_like` applied to a Z-array (purportedly a NumPy array according to `get_namespace`) returns a Z-array, that is, a NumPy array! An Array API function belonging to one namespace should probably never return an array belonging to another namespace?

Solutions
Some possible solutions:

- `isinstance(z, jax.Array)`. This would fix the isinstance problem above and the crashing.
- `array_api_compat.get_namespace` to consider Z-arrays to belong to the Jax namespace? This is the easier fix, which should at least resolve the crashing.

In my code, the easiest solution was to have a modified `get_namespace`, but I would prefer one of the above solutions so that I don't run into this problem with other libraries.

System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.25
jaxlib: 0.4.25
numpy: 1.26.4
python: 3.11.8 (main, Feb 22 2024, 17:25:49) [GCC 11.4.0]