Arrays having dtype float0 are broken with the Array API #20620

Open
NeilGirdhar opened this issue Apr 7, 2024 · 6 comments
@NeilGirdhar
Contributor

NeilGirdhar commented Apr 7, 2024

Description

For all practical purposes, NumPy arrays with dtype float0 (let's call them Z-arrays) are JAX arrays. However, array_api_compat.get_namespace treats them as NumPy arrays, which leads to crashes when getting the namespace of multiple arrays, since a Z-array passed alongside a JAX array looks like a mix of two libraries. Something appears to be fundamentally broken.

Background

Z-arrays originate in various ways. E.g.,

  • When differentiating with respect to integer inputs (e.g., grad with allow_int=True), the gradient is trivially zero, and grad produces a Z-array for it.
  • Some cotangent values are required to be passed in as Z-arrays.
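As a minimal sketch of the first case, the snippet below differentiates a toy function with respect to an integer input. The function `f` and the inputs `n` and `x` are made up for illustration; only `jax.grad` with `allow_int=True` and `jax.float0` come from JAX itself.

```python
import jax
import jax.numpy as jnp
import numpy as np


def f(n, x):
    # n is integer-valued, x is floating-point; the result is a float scalar.
    return jnp.sum(x) * jnp.sum(n + 0.0)


n = np.arange(3, dtype=np.int32)
x = jnp.ones(4, dtype=jnp.float32)

# allow_int=True permits differentiating with respect to the integer input.
grad_n, grad_x = jax.grad(f, argnums=(0, 1), allow_int=True)(n, x)

print(type(grad_n).__name__, grad_n.dtype)  # inspect the integer input's gradient
print(type(grad_x).__name__, grad_x.dtype)  # inspect the float input's gradient
```

The gradient for `n` carries dtype float0 and is not a jax.Array, while the gradient for `x` is an ordinary JAX array.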

Does it really make sense to consider an array of dtype float0 to be a NumPy array? The NumPy namespace doesn't even know anything about float0. It seems to me that an array with a given library's dtypes should be considered to belong to that library. Thus, I expect:

import jax
import jax.numpy as jnp
import numpy as np

# Each comment gives the expected result; the two marked lines behave differently today.
isinstance(np.zeros(10, np.float32), np.ndarray)    # True
isinstance(np.zeros(10, jax.float0), np.ndarray)    # False expected, but actually True
isinstance(jnp.zeros(10, jnp.float32), np.ndarray)  # False
isinstance(np.zeros(10, np.float32), jax.Array)     # False
isinstance(np.zeros(10, jax.float0), jax.Array)     # True expected, but actually False
isinstance(jnp.zeros(10, jnp.float32), jax.Array)   # True

Also, it is very odd that jnp.zeros_like applied to a Z-array (purportedly a NumPy array according to get_namespace) returns a Z-array, which is again a NumPy array! An Array API function belonging to one namespace should probably never return an array belonging to another namespace.

Solutions

Some possible solutions:

  • Reconsider issue #4433 ("jnp.zeros does not support float0 as a dtype") and make a Z-array a JAX array of some new type for which isinstance(z, jax.Array) is true. This would fix both the isinstance problem above and the crashing.
  • Fix array_api_compat.get_namespace to consider Z-arrays to belong to the JAX namespace. This is the easier fix, which should at least resolve the crashing.
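To make the second option concrete, here is a rough sketch of a namespace lookup that assigns Z-arrays to JAX. The helpers `is_z_array` and `namespace_of` are hypothetical names invented for this example; this is not how array_api_compat is implemented, just an illustration of the dispatch rule being proposed.

```python
import jax
import jax.numpy as jnp
import numpy as np


def is_z_array(x) -> bool:
    """Return True for a NumPy array whose dtype is jax.float0."""
    return isinstance(x, np.ndarray) and x.dtype == jax.float0


def namespace_of(*arrays):
    """Hypothetical stand-in for get_namespace that assigns Z-arrays to JAX."""
    namespaces = set()
    for x in arrays:
        if isinstance(x, jax.Array) or is_z_array(x):
            namespaces.add(jnp)
        elif isinstance(x, np.ndarray):
            namespaces.add(np)
        else:
            raise TypeError(f"unsupported array type: {type(x).__name__}")
    if len(namespaces) != 1:
        raise TypeError("arrays belong to more than one namespace")
    return namespaces.pop()


z = np.zeros(3, dtype=jax.float0)   # a Z-array
a = jnp.ones(3, dtype=jnp.float32)  # an ordinary JAX array
xp = namespace_of(z, a)             # no mixed-namespace error
```

Checking the dtype before the generic np.ndarray test is the key design choice: without that ordering, the Z-array would be classified as NumPy and the mixed call would fail.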

In my code, the easiest solution was to have a modified get_namespace, but I would prefer one of the above solutions so that I don't run into this problem with other libraries.

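from dataclasses import fields
from typing import Any

import jax.numpy as jnp
from array_api_compat import get_namespace
from jax import float0

def fix_zero(x: Any) -> Any:
    # Replace a Z-array with a JAX array of zeros of the same shape.
    if x.dtype == float0:
        return jnp.zeros(x.shape)
    return x

def get_namespace_fixed(self, *x):
    # Method of a dataclass whose fields are arrays.
    values = [fix_zero(getattr(self, field.name))
              for field in fields(self)] + list(x)
    return get_namespace(*values)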

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.25
jaxlib: 0.4.25
numpy: 1.26.4
python: 3.11.8 (main, Feb 22 2024, 17:25:49) [GCC 11.4.0]

@NeilGirdhar NeilGirdhar added the bug Something isn't working label Apr 7, 2024
@NeilGirdhar NeilGirdhar changed the title from "Gradient of jax arrays produces numpy arrays" to "Arrays having dtype float0 are broken with the Array API" Apr 7, 2024
@jakevdp jakevdp self-assigned this Apr 8, 2024
@jakevdp
Collaborator

jakevdp commented Apr 8, 2024

Thanks! It's not clear to me how to best handle this... I think @mattjj has been considering changing the representation of integer gradients to use extended dtypes rather than being marked by float0. That's probably the best way forward if we are going to fix this.

@NeilGirdhar
Contributor Author

NeilGirdhar commented Apr 8, 2024

That sounds great. I just hope that whatever representation you choose derives from jax.Array, and is recognized by get_namespace as a JAX array 😄

@patrick-kidger
Collaborator

patrick-kidger commented Apr 9, 2024

I will note that I get pretty scared by the idea of changing float0.

I've had to add careful handling for this in a few places, so I worry that changing this -- and breaking the existing code that handles it -- will lead to difficult-to-debug errors for any downstream users. It's such an edge case.

(Although FWIW I do also dislike the current regime, which isn't even self-consistent inside JAX! E.g. needing to return bool/int tangents when working with custom JVPs. If we are going to have a compatibility break can we do it all at once?)

@NeilGirdhar
Contributor Author

NeilGirdhar commented Apr 10, 2024

@patrick-kidger Just curious, but are you using the Array API yet?

@patrick-kidger
Collaborator

Not yet!

NeilGirdhar added a commit to NeilGirdhar/array-api-compat that referenced this issue Apr 15, 2024
@NeilGirdhar
Contributor Author

FYI: I modified array-api-compat to work around this problem for now, but it would still be better to smooth out the quirks on the JAX side when there's time.
