JIT invariance problem for large Python integers as arguments to function #6684
Looks like the problem is here:

```
In [1]: from jax._src.numpy.lax_numpy import _promote_dtypes

In [2]: _promote_dtypes(2 ** 100, 3.)
Out[2]: [DeviceArray(1.2676506e+30, dtype=float32), DeviceArray(3., dtype=float32)]
```
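The eager result above is legitimate: 2**100 is a power of two far below the float32 maximum (about 3.4e38), so float32 stores it exactly, even though it is far outside the int32 range. Here is a minimal stdlib-only sketch that checks both facts. `fits_int32` and `to_float32` are hypothetical helpers, not JAX APIs, and `struct` packing stands in for a float32 cast:

```python
import struct

# Signed 32-bit bounds (JAX's default integer width when x64 is disabled).
INT32_MIN, INT32_MAX = -(2**31), 2**31 - 1


def fits_int32(value: int) -> bool:
    """Return True if a Python int is representable as a signed 32-bit integer."""
    return INT32_MIN <= value <= INT32_MAX


def to_float32(value: int) -> float:
    """Round a Python int to float32 by packing/unpacking a 4-byte IEEE float.

    struct raises OverflowError if the value exceeds the float32 range.
    """
    return struct.unpack("<f", struct.pack("<f", float(value)))[0]


big = 2**100
print(fits_int32(big))              # False: an int32 intermediate would overflow
print(to_float32(big) == 2.0**100)  # True: float32 holds 2**100 exactly
```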
Also note this comes from #6165, which has a test in
I think the root cause is a jit invariance issue, shown here:

```
In [1]: from jax import jit, lax

In [2]: lax.convert_element_type(2 ** 100, 'float32')
Out[2]: DeviceArray(1.2676506e+30, dtype=float32)

In [3]: jit(lax.convert_element_type, static_argnums=1)(2 ** 100, 'float32')
<snip>
OverflowError: Python int 1267650600228229401496703205376 too large to convert to int32
```
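To make the invariance gap concrete, here is a toy pure-Python model. It is an assumption for illustration, not JAX's actual implementation: the eager path converts the Python int straight to float32, while the traced path, as the error above suggests, first canonicalizes the argument to int32. `eager_convert`, `jit_convert`, and `same_outcome` are hypothetical names.

```python
import struct

INT32_MIN, INT32_MAX = -(2**31), 2**31 - 1


def to_float32(value) -> float:
    """Round a Python number to float32 via a 4-byte IEEE pack/unpack."""
    return struct.unpack("<f", struct.pack("<f", float(value)))[0]


def canonicalize_int32(value: int) -> int:
    """Hypothetical stand-in for turning a Python int argument into int32."""
    if not INT32_MIN <= value <= INT32_MAX:
        raise OverflowError(f"Python int {value} too large to convert to int32")
    return value


def eager_convert(value: int) -> float:
    # Eager: the Python int goes directly to the target dtype.
    return to_float32(value)


def jit_convert(value: int) -> float:
    # Traced (toy model): canonicalize the argument to int32 first, then convert.
    return to_float32(canonicalize_int32(value))


def outcome(fn, value):
    """Summarize a call as ('ok', result) or ('error', exception type name)."""
    try:
        return ("ok", fn(value))
    except OverflowError as exc:
        return ("error", type(exc).__name__)


def same_outcome(value: int) -> bool:
    """Jit invariant on `value` means both paths agree (same result or same error)."""
    return outcome(eager_convert, value) == outcome(jit_convert, value)


print(same_outcome(7))       # True
print(same_outcome(2**100))  # False: eager succeeds, the traced path overflows
```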
Matt's thought on this was:
Yes, but that removal (in #6014, right?) caused a regression in downstream user code, specifically in a Haiku test IIRC. I think we can keep this behavior and have jit invariance without too much trouble, but I may be wrong! If so, we'll have to change the downstream user code instead. Either way, we'll end up with jit invariance here.
I think we have three possible approaches to this jit invariance issue in
Thoughts?
I want to try 2! After all, we already want it for other reasons: to make the x64 context manager work correctly with jit. (BTW, see the jax-dev chat for more context on this.)
The backup plan would be 1, but that would require changing some downstream user code. (It might be a minor change, though; it's hard to say without investigating.)
I think this was the test that started failing after #6014, which #6165 was trying to fix. Fixing the test wouldn't be hard, and the error message would probably be clear enough for users who had code written this way, so 1 is a pretty viable option too. I'm not sure how much real user code relies on this now. (I really hate fixing downstream code, which is part of why I want to try approach 2 first!)
@jakevdp, do you want to pursue plan 1 now (with a global presubmit check)? If that works out easily, i.e. there isn't much downstream code to fix, we can unblock jit invariance right away. If it doesn't, at least we'll have more motivation for plan 2. I plan to revise the constant-handling code for the x64 context manager, and I'll see whether plan 2 can easily be folded into that. This way we might not have to block on it. WDYT?
If we're going to pursue plan 2, then I think we can leave this issue open for now and reduce churn. This seems like a pretty low-priority JIT invariance case, with an already good error message and a clear fix (manually casting the input to the desired type).
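The manual-cast workaround can be illustrated with a toy model of argument canonicalization. This is an assumption for illustration, not JAX's code: Python ints map to int32 (range-checked) and Python floats map to float32. Converting the value to `float` on the Python side means no int32 intermediate ever exists. `canonicalize_arg` is a hypothetical helper; in JAX code the analogous move is passing `float(2 ** 100)` instead of `2 ** 100`.

```python
import struct

INT32_MIN, INT32_MAX = -(2**31), 2**31 - 1


def canonicalize_arg(x):
    """Toy model of argument canonicalization with x64 disabled (assumption):
    Python ints become int32 (range-checked), Python floats become float32."""
    if isinstance(x, int):
        if not INT32_MIN <= x <= INT32_MAX:
            raise OverflowError(f"Python int {x} too large to convert to int32")
        return x
    if isinstance(x, float):
        # Round to float32; struct raises OverflowError beyond the float32 range.
        return struct.unpack("<f", struct.pack("<f", x))[0]
    raise TypeError(f"unsupported argument type: {type(x).__name__}")


big = 2**100

# Passing the raw int fails, like the tracebacks in this thread.
try:
    canonicalize_arg(big)
except OverflowError as exc:
    print(exc)

# Workaround: cast on the Python side first; float32 represents 2**100 exactly.
print(canonicalize_arg(float(big)) == 2.0**100)  # True
```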
I think the issue has been fixed:

```
In [1]: import jax, jax.numpy as jnp

In [2]: jnp.multiply(2 ** 100, 3.)
---------------------------------------------------------------------------
OverflowError                             Traceback (most recent call last)
<ipython-input-2-b332690c3005> in <module>
----> 1 jnp.multiply(2 ** 100, 3.)

    [... skipping hidden 6 frame]

~/github/google/jax/jax/_src/dtypes.py in _scalar_type_to_dtype(typ, value)
    149   if typ is int and value is not None:
    150     if value < np.iinfo(dtype).min or value > np.iinfo(dtype).max:
--> 151       raise OverflowError(f"Python int {value} too large to convert to {dtype}")
    152   return dtype
    153

OverflowError: Python int 1267650600228229401496703205376 too large to convert to int32
```
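The check in `_scalar_type_to_dtype` above compares the Python int against the `np.iinfo` bounds of the chosen integer dtype. Below is a minimal standalone re-creation of just that comparison. The name `check_int_fits` is hypothetical, and how JAX selects `dtype` is not modeled. It also shows that 2**100 is out of range for int64, so the x64 default integer type would not accommodate this value either.

```python
import numpy as np


def check_int_fits(value: int, dtype) -> np.dtype:
    """Raise OverflowError if `value` lies outside the range of integer `dtype`.

    Mirrors the comparison in the traceback above; returns the dtype otherwise.
    """
    dtype = np.dtype(dtype)
    info = np.iinfo(dtype)
    if value < info.min or value > info.max:
        raise OverflowError(f"Python int {value} too large to convert to {dtype}")
    return dtype


print(check_int_fits(2**31 - 1, np.int32))  # the largest int32 value fits

for dt in (np.int32, np.int64):
    try:
        check_int_fits(2**100, dt)
    except OverflowError as exc:
        print(exc)  # 2**100 overflows both int32 and int64
```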