JIT invariance problem for large Python integers as arguments to function #6684

Closed
hawkinsp opened this issue May 7, 2021 · 13 comments
Labels: bug (Something isn't working)

Comments

@hawkinsp
Collaborator

hawkinsp commented May 7, 2021

In [1]: import jax, jax.numpy as jnp

In [2]: jnp.multiply(2 ** 100, 3.)  # doesn't crash
   ...:
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Out[2]: DeviceArray(3.8029518e+30, dtype=float32)

In [3]: jax.jit(jnp.multiply)(2 ** 100, 3.)
...
OverflowError                             Traceback (most recent call last)
<ipython-input-3-2aed3283764c> in <module>
----> 1 jax.jit(jnp.multiply)(2 ** 100, 3.)

~/p/jax/jax/_src/dtypes.py in _scalar_type_to_dtype(typ, value)
    125   if typ is int and value is not None:
    126     if value < np.iinfo(dtype).min or value > np.iinfo(dtype).max:
--> 127       raise OverflowError(f"Python int {value} too large to convert to {dtype}")
    128   return dtype
    129

OverflowError: Python int 1267650600228229401496703205376 too large to convert to int32
hawkinsp added the bug label on May 7, 2021
@jakevdp
Collaborator

jakevdp commented May 7, 2021

Looks like the problem is in _promote_dtypes:

In [1]: from jax._src.numpy.lax_numpy import _promote_dtypes

In [2]: _promote_dtypes(2 ** 100, 3.)
Out[2]: [DeviceArray(1.2676506e+30, dtype=float32), DeviceArray(3., dtype=float32)]

@hawkinsp
Collaborator Author

hawkinsp commented May 7, 2021

Also note that this comes from #6165, which has a test in api_test.py.

@jakevdp
Collaborator

jakevdp commented May 7, 2021

I think the root cause is a jit invariance issue in lax.convert_element_type:

In [1]: from jax import jit, lax

In [2]: lax.convert_element_type(2 ** 100, 'float32')
Out[2]: DeviceArray(1.2676506e+30, dtype=float32)

In [3]: jit(lax.convert_element_type, static_argnums=1)(2 ** 100, 'float32')
<snip>
OverflowError: Python int 1267650600228229401496703205376 too large to convert to int32
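
One way to state the property at stake: a function is jit invariant for some arguments if calling it eagerly and calling its jitted version either return equal results or fail in the same way. The sketch below checks that with two pure-Python stand-ins rather than real JAX calls. The names `outcome`, `is_invariant`, `eager_convert`, and `jitted_convert` are hypothetical, and the stand-ins only imitate the behavior reported in this thread.

```python
INT32_MIN, INT32_MAX = -2**31, 2**31 - 1


def outcome(fn, *args):
    """Run fn and summarize what happened as a comparable tuple."""
    try:
        return ("ok", fn(*args))
    except Exception as exc:
        return ("error", type(exc).__name__)


def is_invariant(eager_fn, jitted_fn, *args):
    """True when both callables return equal values or raise the same error type."""
    return outcome(eager_fn, *args) == outcome(jitted_fn, *args)


def eager_convert(value):
    # Hypothetical stand-in for the eager path: converts straight to float.
    return float(value)


def jitted_convert(value):
    # Hypothetical stand-in for the jitted path: the Python int is
    # range-checked against int32 before any conversion to float happens.
    if not INT32_MIN <= value <= INT32_MAX:
        raise OverflowError(f"Python int {value} too large to convert to int32")
    return float(value)


small_ok = is_invariant(eager_convert, jitted_convert, 7)
large_ok = is_invariant(eager_convert, jitted_convert, 2 ** 100)
```

With these stand-ins the two paths agree for small integers and disagree for 2 ** 100, which is the shape of the mismatch shown above.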

@hawkinsp
Collaborator Author

hawkinsp commented May 7, 2021

Matt's thought on this was: when we device_put values for executing a jitted computation, we should look at the corresponding aval in the jaxpr, rather than only at the type of the value we're putting and the current value of the x64 flag.
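
A rough sketch of that idea in plain Python, not JAX internals: the first helper picks a dtype from the Python type of the argument alone, while the second uses the dtype that the traced computation expects for that argument. The helpers `put_by_python_type` and `put_by_expected_dtype`, and the string dtype names, are hypothetical stand-ins for illustration, and the sketch ignores the x64 flag entirely.

```python
INT32_MIN, INT32_MAX = -2**31, 2**31 - 1


def _check_int32(value):
    if not INT32_MIN <= value <= INT32_MAX:
        raise OverflowError(f"Python int {value} too large to convert to int32")


def put_by_python_type(value):
    """Pick the dtype from the argument's Python type only."""
    if isinstance(value, int):
        _check_int32(value)
        return ("int32", value)
    return ("float32", float(value))


def put_by_expected_dtype(value, expected_dtype):
    """Pick the dtype recorded for this argument by the traced computation."""
    if expected_dtype == "float32":
        return ("float32", float(value))
    if expected_dtype == "int32":
        _check_int32(value)
        return ("int32", int(value))
    raise ValueError(f"unsupported dtype in this sketch: {expected_dtype}")


# The computation expects float32 here, so the large int never hits the
# int32 range check.
converted = put_by_expected_dtype(2 ** 100, "float32")
```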

@jakevdp
Collaborator

jakevdp commented May 7, 2021

Looking at #6165 now... I wish I'd seen that earlier. I think I previously took out the specialization on Python scalar input that #6165 added back in, specifically to prevent this kind of issue.

@mattjj
Collaborator

mattjj commented May 7, 2021

Yes, but that removal (in #6014, right?) caused a regression in downstream user code, specifically in a Haiku test IIRC.

I think we can keep this behavior and have jit invariance without too much trouble. But I may be wrong! If that's wrong, we'll have to change the downstream user code instead. In either case, we'll end up with jit invariance here.

@jakevdp
Collaborator

jakevdp commented May 7, 2021

I think we have three possible approaches to this jit invariance issue in lax.convert_element_type:

  1. Revert #6165 ("fix convert_element_type on large Python int inputs") so that non-jitted code fails like jitted code
  2. Entirely redesign how constants are lowered during tracing so that they are converted to the type eventually used in the jaxpr, rather than to the type corresponding to the variable the user passed
  3. Accept and document this jit invariance breakage

Thoughts?

@mattjj
Collaborator

mattjj commented May 7, 2021

I want to try 2! After all, we already want it for other reasons: to make the x64 context manager work correctly with jit. (Btw see the jax-dev chat for more context on this.)

@mattjj
Collaborator

mattjj commented May 7, 2021

The backup plan would be 1, but that'd require changing some downstream user code. (It might be a minor change though; hard to say without investigating it.)

@mattjj
Collaborator

mattjj commented May 7, 2021

I think this was the test that started failing after #6014, which #6165 was trying to fix. Fixing the test wouldn't be hard, and maybe the error message would be pretty clear for users who had code written this way... so 1 is a pretty viable option too. I'm not sure how much real user code relies on this now. (I really hate fixing downstream code, which is part of why I want to try approach 2 first!)

@mattjj
Collaborator

mattjj commented May 7, 2021

@jakevdp do you want to pursue plan 1 now (with a global presubmit check)? Then if that works out easily, i.e. there isn't much downstream code to fix, we can unblock jit invariance right away. If it doesn't work easily, at least we'll have more motivation for plan 2.

I plan to revise constant handling stuff for the x64 context manager, and I'll see if plan 2 can be folded into that easily. But this way we might not have to block on it.

WDYT?

@jakevdp
Collaborator

jakevdp commented May 7, 2021

If we're going to pursue plan 2, then I think we can leave this issue open for now and reduce churn. This seems like a pretty low-priority JIT invariance case, with an error message that is already good and a pretty clear fix (manually casting the input to the desired type).
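
The manual cast can be done on the Python side before the value reaches the jitted function. A minimal sketch, assuming the desired type is a float: the `struct` round trip only illustrates that 2 ** 100 fits in a 32-bit float, and `big`, `as_float`, `exact`, and `as_float32` are illustrative names.

```python
import struct

big = 2 ** 100

# Cast in Python first, so the argument is a float rather than a Python int
# that would be range-checked against int32.
as_float = float(big)

# 2 ** 100 is a power of two, so the conversion to float is exact.
exact = as_float == 2.0 ** 100

# Round-trip through a 4-byte IEEE float to confirm the value is within
# float32 range (struct.pack raises OverflowError if it is not).
as_float32 = struct.unpack("f", struct.pack("f", as_float))[0]
```

Under that assumption, passing `as_float` instead of `2 ** 100` to the jitted function, e.g. `jax.jit(jnp.multiply)(as_float, 3.)`, should sidestep the integer range check, since the argument is no longer a Python int.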

@jakevdp
Collaborator

jakevdp commented Jun 29, 2022

I think the issue has been fixed: the non-jitted call now raises the same OverflowError as the jitted one:

In [1]: import jax, jax.numpy as jnp

In [2]: jnp.multiply(2 ** 100, 3.)
---------------------------------------------------------------------------
OverflowError                             Traceback (most recent call last)
<ipython-input-2-b332690c3005> in <module>
----> 1 jnp.multiply(2 ** 100, 3.)

    [... skipping hidden 6 frame]

~/github/google/jax/jax/_src/dtypes.py in _scalar_type_to_dtype(typ, value)
    149   if typ is int and value is not None:
    150     if value < np.iinfo(dtype).min or value > np.iinfo(dtype).max:
--> 151       raise OverflowError(f"Python int {value} too large to convert to {dtype}")
    152   return dtype
    153 

OverflowError: Python int 1267650600228229401496703205376 too large to convert to int32

jakevdp closed this as completed on Jun 29, 2022