Unexpected behavior in jax.numpy.array_equal #17621
Comments
Thanks for the report! See also #14901. I think the fix here would be to use I think it would not be correct to fall back to Python equality checks here (if you want that, just use |
I think it might be better to fall back to Python equality but emit a warning about it, don't you think?
No, I don't think accepting non-array inputs in |
Falling back to Python equality would also lead to strange consequences; for example, what should the output of this be?
    jnp.array_equal(Foo(1), np.arange(4))
If we fall back to Python equality, the result is
    >>> Foo(1) == np.arange(4)
    array([False, False, False, False])
when a scalar boolean is expected.
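The point above can be checked with plain NumPy. This is a minimal sketch, assuming `Foo` is a simple dataclass with a single field `v` (the thread only shows its repr, `Foo(v=1)`, so the definition here is a guess): comparing it against an array broadcasts elementwise instead of producing a single boolean.

```python
from dataclasses import dataclass

import numpy as np


# Hypothetical stand-in for the Foo class from the report; the real
# definition is not shown in the thread.
@dataclass
class Foo:
    v: int


# Foo.__eq__ returns NotImplemented for a non-Foo operand, so Python
# defers to ndarray's reflected comparison, which broadcasts elementwise.
result = Foo(1) == np.arange(4)

print(type(result).__name__, result.shape)
```

Because the result is an array rather than a `bool`, a Python-equality fallback would not give `array_equal` a well-defined scalar return value.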
Yes, you are right, sorry.
Hi @jheek, this issue has been resolved by PR #18708. I tested the mentioned code on Colab CPU and it now throws the following error:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/usr/local/lib/python3.10/dist-packages/jax/_src/dtypes.py in dtype(x, canonicalize)
658 try:
--> 659 dt = np.result_type(x)
660 except TypeError as err:
TypeError: Cannot interpret 'Foo(v=1)' as a data type
The above exception was the direct cause of the following exception:
TypeError Traceback (most recent call last)
12 frames
TypeError: Cannot determine dtype of Foo(v=1)
During handling of the above exception, another exception occurred:
TypeError Traceback (most recent call last)
/usr/local/lib/python3.10/dist-packages/jax/_src/dtypes.py in dtype(x, canonicalize)
661 raise TypeError(f"Cannot determine dtype of {x}") from err
662 if dt not in _jax_dtype_set and not issubdtype(dt, extended):
--> 663 raise TypeError(f"Value '{x}' with dtype {dt} is not a valid JAX array "
664 "type. Only arrays of numeric types are supported by JAX.")
665 # TODO(jakevdp): fix return type annotation and remove this ignore.
TypeError: Value 'Foo(v=1)' with dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.
Attaching the gist for reference. Thank you.
Thanks for following up!
Description
The implementation of array_equal always returns False when an input cannot be cast to an array using jax.numpy.asarray. This results in surprising equality checks when you accidentally feed an object:
Normal NumPy will revert to the equality implementation of the object. I think this, or simply throwing an error if the inputs are not castable to arrays, would be reasonable behavior.
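The "throw an error" option can be sketched with plain NumPy. `strict_array_equal` is a hypothetical helper written for illustration only, and `Foo` is again an assumed dataclass; neither is part of JAX or NumPy, and this is not how the eventual fix was implemented. The idea is to reject any input that converts only to an object-dtype array, rather than silently returning a value.

```python
from dataclasses import dataclass

import numpy as np


# Hypothetical stand-in for the object passed by accident in the report.
@dataclass
class Foo:
    v: int


def strict_array_equal(a, b) -> bool:
    """Compare two array-likes, raising if either is not a numeric array.

    Illustrative helper only; not a JAX or NumPy API.
    """
    a_arr, b_arr = np.asarray(a), np.asarray(b)
    for original, arr in ((a, a_arr), (b, b_arr)):
        # Arbitrary Python objects convert to dtype=object arrays.
        if arr.dtype == object:
            raise TypeError(f"Cannot interpret {original!r} as a numeric array")
    return bool(np.array_equal(a_arr, b_arr))


print(strict_array_equal([1, 2, 3], np.array([1, 2, 3])))

try:
    strict_array_equal(Foo(1), Foo(1))
except TypeError as err:
    print(f"TypeError: {err}")
```

Raising early keeps the return type a plain `bool` and surfaces the accidental object input instead of hiding it behind a `False`.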
What jax/jaxlib version are you using?
Any
Which accelerator(s) are you using?
Any
Additional system info
No response
NVIDIA GPU info
No response