Originally posted by DanPuzzuoli October 13, 2023
I'm trying to run a jit compiled gradient and I'm getting the following error:
File "/opt/anaconda3/envs/devEnv311/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 177, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^
File "/opt/anaconda3/envs/devEnv311/lib/python3.11/site-packages/jax/_src/pjit.py", line 256, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
^^^^^^^^^^^^^^^^^^^^
File "/opt/anaconda3/envs/devEnv311/lib/python3.11/site-packages/jax/_src/pjit.py", line 167, in _python_pjit_helper
out_flat = pjit_p.bind(*args_flat, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/anaconda3/envs/devEnv311/lib/python3.11/site-packages/jax/_src/core.py", line 2657, in bind
return self.bind_with_trace(top_trace, args, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/anaconda3/envs/devEnv311/lib/python3.11/site-packages/jax/_src/core.py", line 389, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/anaconda3/envs/devEnv311/lib/python3.11/site-packages/jax/_src/core.py", line 869, in process_primitive
return primitive.impl(*tracers, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/anaconda3/envs/devEnv311/lib/python3.11/site-packages/jax/_src/pjit.py", line 1212, in _pjit_call_impl
return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/anaconda3/envs/devEnv311/lib/python3.11/site-packages/jax/_src/pjit.py", line 1196, in call_impl_cache_miss
out_flat, compiled = _pjit_call_impl_python(
^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/anaconda3/envs/devEnv311/lib/python3.11/site-packages/jax/_src/pjit.py", line 1132, in _pjit_call_impl_python
lowering_parameters=mlir.LoweringParameters()).compile()
^^^^^^^^^
File "/opt/anaconda3/envs/devEnv311/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2276, in compile
executable = UnloadedMeshExecutable.from_hlo(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/anaconda3/envs/devEnv311/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2624, in from_hlo
xla_executable, compile_options = _cached_compilation(
^^^^^^^^^^^^^^^^^^^^
File "/opt/anaconda3/envs/devEnv311/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2531, in _cached_compilation
xla_executable = compiler.compile_or_get_cached(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/anaconda3/envs/devEnv311/lib/python3.11/site-packages/jax/_src/compiler.py", line 294, in compile_or_get_cached
return backend_compile(backend, computation, compile_options,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/anaconda3/envs/devEnv311/lib/python3.11/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/opt/anaconda3/envs/devEnv311/lib/python3.11/site-packages/jax/_src/compiler.py", line 256, in backend_compile
return backend.compile(built_c, compile_options=options)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
IndexError: vector
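The traceback ends in `backend.compile`, which suggests that tracing and lowering finished and the failure happened during XLA compilation. The sketch below uses JAX's ahead-of-time `.lower()` / `.compile()` API to run those two stages separately, so the lowered module can be inspected or attached to a bug report. It assumes a hypothetical `loss` function in place of the reporter's code, which is not shown.

```python
import jax
import jax.numpy as jnp


# Hypothetical loss function; the reporter's actual function is not shown.
def loss(params, x):
    return jnp.sum(jnp.tanh(params["w"] @ x + params["b"]) ** 2)


params = {"w": jnp.ones((3, 4)), "b": jnp.zeros(3)}
x = jnp.arange(4.0)

grad_fn = jax.jit(jax.grad(loss))

# Stage 1: trace and lower to StableHLO. Errors raised here come from
# JAX's own tracing/lowering, not from the XLA compiler.
lowered = grad_fn.lower(params, x)
hlo_text = lowered.as_text()  # the module handed to XLA; useful in a bug report

# Stage 2: XLA backend compilation. In the traceback above, the
# IndexError was raised at this point (inside backend.compile).
compiled = lowered.compile()

grads = compiled(params, x)
print(grads["w"].shape, grads["b"].shape)
```

If stage 1 succeeds and stage 2 raises, the saved `hlo_text` is a self-contained artifact that reproduces the compiler failure without the rest of the program.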
The IndexError is not raised if I don't jit the gradient function, which makes it difficult to track down its cause. I'm still trying to find a minimal example, but wanted to ask here in case anyone has any insight.
On my MacBook this produces:
libc++abi: terminating due to uncaught exception of type std::out_of_range: vector
Abort trap: 6
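Because the gradient reportedly works without jit, one way to narrow down the problem is to run the same gradient on both paths and print the jaxpr that jit stages out for XLA. The following is a minimal sketch, again assuming a hypothetical `loss` function rather than the reporter's code:

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for the reporter's function.
def loss(w, x):
    return jnp.mean(jnp.sin(w * x) ** 2)


w = jnp.array([0.5, -1.0, 2.0])
x = jnp.array([1.0, 2.0, 3.0])

grad_fn = jax.grad(loss)

# Path that works in the report: the gradient called without jit.
eager = grad_fn(w, x)

# The jaxpr that jit stages out; printing it shows which primitives
# end up in the single compiled program.
print(jax.make_jaxpr(grad_fn)(w, x))

# Path that fails in the report: the same gradient under jit.
jitted = jax.jit(grad_fn)(w, x)
print(jnp.max(jnp.abs(eager - jitted)))
```

If the eager call succeeds and the jitted one fails, jitting progressively smaller pieces of the function is a reasonable way to bisect toward the construct that triggers the compiler error.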
The reproduction comes from jax-ml/jax#18103, but I'm having trouble reproducing it in a unit test.
If this code path is triggered, a tuple with the incorrect size is returned.
Fixes jax-ml/jax#18106
PiperOrigin-RevId: 573666720