Unexpected behavior in jax.numpy.array_equal #17621

Closed
jheek opened this issue Sep 15, 2023 · 7 comments
Labels
bug Something isn't working

Comments

@jheek
Contributor

jheek commented Sep 15, 2023

Description

The implementation of array_equal always returns False when an input cannot be cast to an array using jax.numpy.asarray. This leads to surprising equality results when you accidentally pass in an arbitrary object:

import dataclasses
import jax
@dataclasses.dataclass
class Foo:
  v: int

jax.numpy.array_equal(Foo(1), Foo(1)) # ==> False
jax.numpy.array_equal(Foo(1), Foo(2)) # ==> False

Regular NumPy falls back to the object's own equality (__eq__) implementation. I think either this or simply raising an error when the inputs are not castable to arrays would be reasonable behavior.
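To make the reported behavior concrete, here is a minimal sketch of the pattern described above, written with NumPy only so it runs without JAX. The helper names to_numeric_array and lenient_array_equal are hypothetical: to_numeric_array stands in for jax.numpy.asarray by rejecting object-dtype inputs, and the assumption is that array_equal swallows that conversion error and returns False.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass
class Foo:
    v: int


def to_numeric_array(x):
    """Hypothetical stand-in for jax.numpy.asarray: refuse object-dtype inputs."""
    arr = np.asarray(x)
    if arr.dtype == object:
        raise TypeError(f"Value {x!r} with dtype object is not a valid array type")
    return arr


def lenient_array_equal(a1, a2):
    """Sketch of the reported pattern: a failed conversion silently becomes False."""
    try:
        a1, a2 = to_numeric_array(a1), to_numeric_array(a2)
    except TypeError:
        return False  # error is swallowed, so equal objects still compare unequal
    return a1.shape == a2.shape and bool((a1 == a2).all())


print(lenient_array_equal(Foo(1), Foo(1)))  # False
print(lenient_array_equal(Foo(1), Foo(2)))  # False
print(lenient_array_equal([1, 2], [1, 2]))  # True
```

Because the except branch discards the TypeError, the caller cannot tell "not equal" apart from "not comparable", which is what makes the Foo(1) vs Foo(1) result surprising.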

What jax/jaxlib version are you using?

Any

Which accelerator(s) are you using?

Any

Additional system info

No response

NVIDIA GPU info

No response

@jheek jheek added the bug Something isn't working label Sep 15, 2023
kfirdev added a commit to kfirdev/jax that referenced this issue Sep 15, 2023
@jakevdp
Collaborator

jakevdp commented Sep 15, 2023

Thanks for the report! See also #14901

I think the fix here would be to use _check_arraylike on the inputs, which would effectively raise an error if the inputs can't be converted to a JAX array. We haven't done this yet because it's a breaking change that requires a deprecation cycle.
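A rough sketch of what validating the inputs up front could look like, again in NumPy. check_arraylike and strict_array_equal are hypothetical names written in the spirit of JAX's private _check_arraylike helper, not copies of it, and the error wording is illustrative. The assumption here is that only NumPy arrays, NumPy scalars, and Python bool/int/float/complex scalars count as array-like.

```python
import dataclasses

import numpy as np

# Hypothetical whitelist: the types this sketch treats as "array-like".
_ARRAYLIKE_TYPES = (np.ndarray, np.generic, bool, int, float, complex)


@dataclasses.dataclass
class Foo:
    v: int


def check_arraylike(fun_name, *args):
    """Raise TypeError instead of silently returning False for non-array inputs."""
    for pos, arg in enumerate(args):
        if not isinstance(arg, _ARRAYLIKE_TYPES):
            raise TypeError(
                f"{fun_name} requires ndarray or scalar arguments, "
                f"got {type(arg).__name__} at position {pos}."
            )


def strict_array_equal(a1, a2):
    check_arraylike("array_equal", a1, a2)
    a1, a2 = np.asarray(a1), np.asarray(a2)
    return a1.shape == a2.shape and bool((a1 == a2).all())


print(strict_array_equal(np.arange(3), np.arange(3)))  # True
try:
    strict_array_equal(Foo(1), Foo(1))
except TypeError as err:
    print("TypeError:", err)
```

This sketch also rejects Python lists, and it would still let an object-dtype ndarray through; a real implementation would have to decide both cases, which is part of why such a change needs a deprecation cycle.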

I think it would not be correct to fall back to Python equality checks here (if you want that, just use A == B to start with!). The reason this works in NumPy is that NumPy converts such inputs into object arrays. JAX does not support object arrays, so raising an error would be more consistent in this case.
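The NumPy side of this can be checked directly: np.asarray wraps an arbitrary Python object in a 0-d array of dtype object, and elementwise == on such arrays calls the object's own __eq__ (here the one generated by @dataclass). That is why NumPy returns the result the report expected, and why JAX, which has no object dtype, cannot do the same. Foo is the dataclass from the original report.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass
class Foo:
    v: int


a = np.asarray(Foo(1))
print(a.dtype, a.shape)  # object ()

# Elementwise comparison of object arrays dispatches to Foo.__eq__.
print(np.array_equal(Foo(1), Foo(1)))  # True
print(np.array_equal(Foo(1), Foo(2)))  # False
```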

@kfirdev

kfirdev commented Sep 15, 2023

I think it might be better to fall back to Python equality but emit a warning about it, don't you think?

@jakevdp
Collaborator

jakevdp commented Sep 15, 2023

No, I don't think accepting non-array inputs in array_equal is a good idea, for the reasons mentioned in #7737

@jakevdp
Collaborator

jakevdp commented Sep 15, 2023

Falling back to Python equality would also lead to strange consequences; for example, what should the output of this be?

jnp.array_equal(Foo(1), np.arange(4))

If we fall back to Python ==, we get this:

>>> Foo(1) == np.arange(4)
array([False, False, False, False])

when a scalar boolean is expected.
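The shape problem can be reproduced with a hypothetical fallback (the name fallback_array_equal is invented for illustration) that hands non-numeric inputs to Python ==. It returns a length-4 array instead of a single boolean, and coercing that array with bool() raises ValueError, so a caller using the result in an if statement would get an error rather than an answer.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass
class Foo:
    v: int


def fallback_array_equal(a1, a2):
    """Hypothetical: defer to Python == whenever either input is not numeric."""
    arrs = [np.asarray(a1), np.asarray(a2)]
    if any(arr.dtype == object for arr in arrs):
        return a1 == a2  # may be a bool, an array, or whatever __eq__ returns
    return arrs[0].shape == arrs[1].shape and bool((arrs[0] == arrs[1]).all())


result = fallback_array_equal(Foo(1), np.arange(4))
print(type(result).__name__, result.shape)  # ndarray (4,)

try:
    bool(result)  # an if statement would do this implicitly
except ValueError as err:
    print("ValueError:", err)
```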

@kfirdev

kfirdev commented Sep 15, 2023

Yes, you are right, sorry.

@rajasekharporeddy
Contributor

Hi @jheek

This issue has been resolved by PR #18708. I tested the mentioned code on Colab CPU, and it now raises the following TypeError for non-array inputs:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/usr/local/lib/python3.10/dist-packages/jax/_src/dtypes.py in dtype(x, canonicalize)
    658     try:
--> 659       dt = np.result_type(x)
    660     except TypeError as err:

TypeError: Cannot interpret 'Foo(v=1)' as a data type

The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)
TypeError: Cannot determine dtype of Foo(v=1)

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
/usr/local/lib/python3.10/dist-packages/jax/_src/dtypes.py in dtype(x, canonicalize)
    661       raise TypeError(f"Cannot determine dtype of {x}") from err
    662   if dt not in _jax_dtype_set and not issubdtype(dt, extended):
--> 663     raise TypeError(f"Value '{x}' with dtype {dt} is not a valid JAX array "
    664                     "type. Only arrays of numeric types are supported by JAX.")
    665   # TODO(jakevdp): fix return type annotation and remove this ignore.

TypeError: Value 'Foo(v=1)' with dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.

Attaching the gist for reference.

Thank you

@jakevdp
Copy link
Collaborator

jakevdp commented May 28, 2024

Thanks for following up!

@jakevdp jakevdp closed this as completed May 28, 2024